diff --git a/.pylintrc b/.pylintrc index 6f4dcbf6..260d890f 100644 --- a/.pylintrc +++ b/.pylintrc @@ -121,7 +121,7 @@ enable = no-init, abstract-method, invalid-overridden-method, - arguments-differ, + # arguments-differ, signature-differs, bad-staticmethod-argument, useless-super-delegation, diff --git a/README.md b/README.md index 34e0e223..9b6c5eea 100644 --- a/README.md +++ b/README.md @@ -5,17 +5,17 @@ [![OS - Windows | Linux](https://img.shields.io/badge/OS-windows%20%7C%20linux-blue)](https://github.com/onnx/turnkeyml/blob/main/docs/install.md "Check out our instructions") [![Made with Python](https://img.shields.io/badge/Python-3.8,3.10-blue?logo=python&logoColor=white)](https://github.com/onnx/turnkeyml/blob/main/docs/install.md "Check out our instructions") -We are on a mission to make it easy to use the most important tools in the ONNX ecosystem. TurnkeyML accomplishes this by providing no-code CLIs and low-code APIs for both general ONNX workflows as well as LLMs. +We are on a mission to make it easy to use the most important tools in the ONNX ecosystem. TurnkeyML accomplishes this by providing no-code CLIs and low-code APIs for both general ONNX workflows with `turnkey` as well as LLMs with `lemonade`. -| [**Turnkey LLM**](https://github.com/onnx/turnkeyml/tree/main/src/turnkeyml/llm) | [**Turnkey Classic**](https://github.com/onnx/turnkeyml/blob/main/docs/classic_getting_started.md) | +| [**Lemonade**](https://github.com/onnx/turnkeyml/tree/main/src/turnkeyml/llm) | [**Turnkey**](https://github.com/onnx/turnkeyml/blob/main/docs/classic_getting_started.md) | |:----------------------------------------------: |:-----------------------------------------------------------------: | -| Serve and benchmark LLMs on CPU, GPU, and NPU.
[Click here to get started with turnkey-llm.](https://github.com/onnx/turnkeyml/tree/main/src/turnkeyml/llm) | Export and optimize ONNX models for CNNs, Transformers, and GNNs.
[Click here to get started with turnkey classic.](https://github.com/onnx/turnkeyml/blob/main/docs/classic_getting_started.md) | +| Serve and benchmark LLMs on CPU, GPU, and NPU.
[Click here to get started with `lemonade`.](https://github.com/onnx/turnkeyml/blob/main/docs/lemonade_getting_started.md) | Export and optimize ONNX models for CNNs and Transformers.
[Click here to get started with `turnkey`.](https://github.com/onnx/turnkeyml/blob/main/docs/classic_getting_started.md) | | | | ## How It Works -The `turnkey` (classic) and `turnkey-llm` CLIs provide a set of `Tools` that users can invoke in a `Sequence`. The first `Tool` takes the input (`-i`), performs some action, and passes its state to the next `Tool` in the `Sequence`. +The `turnkey` (CNNs and transformers) and `lemonade` (LLMs) CLIs provide a set of `Tools` that users can invoke in a `Sequence`. The first `Tool` takes the input (`-i`), performs some action, and passes its state to the next `Tool` in the `Sequence`. You can read the `Sequence` out like a sentence. For example, the demo command above was: @@ -51,3 +51,4 @@ This project is licensed under the [Apache 2.0 License](https://github.com/onnx/ ## Attribution TurnkeyML used code from other open source projects as a starting point (see [NOTICE.md](NOTICE.md)). Thank you Philip Colangelo, Derek Elkins, Jeremy Fowers, Dan Gard, Victoria Godsoe, Mark Heaps, Daniel Holanda, Brian Kurtz, Mariah Larwood, Philip Lassen, Andrew Ling, Adrian Macias, Gary Malik, Sarah Massengill, Ashwin Murthy, Hatice Ozen, Tim Sears, Sean Settle, Krishna Sivakumar, Aviv Weinstein, Xueli Xao, Bill Xing, and Lev Zlotnik for your contributions to that work. + diff --git a/docs/humaneval_accuracy.md b/docs/humaneval_accuracy.md new file mode 100644 index 00000000..815baaee --- /dev/null +++ b/docs/humaneval_accuracy.md @@ -0,0 +1,108 @@ +# Using the HumanEval accuracy test tools + +The HumanEval benchmark is a code generation and functional correctness evaluation framework designed to assess language models' ability to generate Python code. It consists of 164 handwritten programming problems, each containing a function signature, docstring, body, and several unit tests. This benchmark focuses on evaluating a model's capability to generate functionally correct code that passes the test cases, making it particularly useful for assessing code generation capabilities. + +This tool provides an automated way to evaluate language models on the HumanEval benchmark. It handles the process of downloading the dataset, generating code completions, executing them in a secure environment, and calculating pass@k metrics. + +## Dataset + +The HumanEval dataset is automatically downloaded from [OpenAI's human-eval repository](https://github.com/openai/human-eval) when you first run the benchmark. The dataset contains programming problems that test various aspects of Python programming, including: + +- Basic programming operations +- String manipulation +- Mathematical computations +- List operations +- Algorithm implementation +- Data structure manipulation + +## Running the Benchmark + +```bash +lemonade -i meta-llama/Llama-3.2-1B oga-load --device igpu --dtype int4 accuracy-humaneval --k-samples 1 --first-n-samples 5 --timeout 30.0 +``` + +### Optional arguments: + +`--k-samples`: Number of completions to generate per prompt (default: 1). This parameter determines the k in pass@k metrics. For example: +- `--k-samples 1`: Calculates pass@1 (single attempt per problem) +- `--k-samples 10`: Calculates pass@10 (ten attempts per problem) +- `--k-samples 100`: Calculates pass@100 (hundred attempts per problem) + +Higher k values provide more robust evaluation but take longer to run. + +`--first-n-samples`: Evaluate only the first N problems from the dataset (default: entire dataset). Useful for quick testing or when you want to evaluate a subset of problems. 
+ +`--timeout`: Maximum time in seconds allowed for each test case execution (default: 30.0). This prevents infinite loops or long-running code from blocking the evaluation. + +`--data-dir`: Custom directory for storing the HumanEval dataset (default: "/data/humaneval"). + +## How It Works + +1. **Dataset Preparation:** + - On first run, the tool downloads the HumanEval dataset (HumanEval.jsonl.gz) + - The dataset contains function signatures, docstrings, and test cases + - Each problem is structured to test specific programming capabilities + - You can evaluate only the first N problems using `--first-n-samples` + +2. **Code Generation:** + - For each programming problem, the model is provided with a prompt containing: + - Function signature (e.g., `def sort_numbers(numbers):`) + - Docstring describing the function's purpose and requirements + - The model generates k code completions for the function body (controlled by `--k-samples`) + - These k samples are used to calculate the pass@k metric + +3. **Secure Execution:** + - Generated code is executed in a secure sandbox environment maintained by OpenAI's human-eval library. For your awareness, OpenAI's policy is to disable code execution by default, however lemonade enables code execution by default by automatically setting the environment variable `HF_ALLOW_CODE_EVAL=1`. OpenAI provides the following code execution protections: + - **Process Isolation**: Each code sample runs in a separate process to prevent interference + - **Resource Limits**: + - CPU time limit (controlled by `--timeout`) + - Memory usage restrictions + - Maximum output size restrictions + - **Restricted Access**: + - No network access + - No file system access outside test directory + - No subprocess creation + - No system calls + - **Module Restrictions**: + - Only allows importing standard Python libraries needed for testing + - Blocks potentially dangerous modules (os, sys, subprocess, etc.) + These security measures are implemented through: + - Python's built-in `resource` module for resource limits + - AST (Abstract Syntax Tree) analysis for code validation + - Process-level isolation using `multiprocessing` + - Custom import hooks to restrict module access + +4. **Evaluation Metrics:** + - **pass@k**: Percentage of problems solved with k attempts + - pass@1: Success rate with single attempt + - pass@10: Success rate within 10 attempts + - pass@100: Success rate within 100 attempts + - A problem is considered solved if all test cases pass + - Results are normalized to percentages + +5. **Output Files:** + The tool generates several output files in the results directory: + - `evaluation_results.csv`: Contains prompts, completions, and expected answers + - `humaneval_predictions.jsonl`: Raw model predictions in JSONL format + - `humaneval_predictions.jsonl_results.jsonl`: Detailed evaluation results + +## Example Results Format + +The evaluation produces metrics in the following format: +```json +{ + "pass@1": 0.25, // 25% success rate with 1 attempt + "pass@10": 0.45, // 45% success rate within 10 attempts + "pass@100": 0.65 // 65% success rate within 100 attempts +} +``` + +## Limitations + +1. **Resource Requirements**: Generating multiple samples per problem (high k values) can be computationally intensive and time-consuming. +2. **Memory Usage**: Large language models may require significant memory, especially when generating multiple samples. + +## References + +1. [Evaluating Large Language Models Trained on Code](https://arxiv.org/abs/2107.03374) +2. 
[OpenAI HumanEval Repository](https://github.com/openai/human-eval) \ No newline at end of file diff --git a/src/turnkeyml/llm/README.md b/docs/lemonade_getting_started.md similarity index 98% rename from src/turnkeyml/llm/README.md rename to docs/lemonade_getting_started.md index af9cf1a2..4501d6c6 100644 --- a/src/turnkeyml/llm/README.md +++ b/docs/lemonade_getting_started.md @@ -79,8 +79,8 @@ Note that the `llm-prompt`, `accuracy-mmlu`, and `serve` tools can all be used w Lemonade is also available via API. Here's a quick example of how to benchmark an LLM: ```python -import turnkeyml.llm.tools.torch_llm as tl -import turnkeyml.llm.tools.chat as cl +import lemonade.tools.torch_llm as tl +import lemonade.tools.chat as cl from turnkeyml.state import State state = State(cache_dir="cache", build_name="test") diff --git a/examples/llm/leap_basic.py b/examples/llm/leap_basic.py new file mode 100644 index 00000000..418cecc8 --- /dev/null +++ b/examples/llm/leap_basic.py @@ -0,0 +1,18 @@ +""" +This example demonstrates how to use the LEAP API to load a model for +inference on CPU using the hf-cpu recipe, and then use it to generate +the response to a prompt. + +If you have a discrete GPU, you can try that by changing the recipe +to hf-dgpu. Note: make sure to have torch+cuda installed when trying +hf-dgpu. +""" + +from lemonade import leap + +model, tokenizer = leap.from_pretrained("facebook/opt-125m", recipe="hf-cpu") + +input_ids = tokenizer("This is my prompt", return_tensors="pt").input_ids +response = model.generate(input_ids, max_new_tokens=30) + +print(tokenizer.decode(response[0])) diff --git a/examples/llm/leap_ryzenai_npu.py b/examples/llm/leap_ryzenai_npu.py new file mode 100644 index 00000000..30ee0222 --- /dev/null +++ b/examples/llm/leap_ryzenai_npu.py @@ -0,0 +1,21 @@ +""" +This example demonstrates how to use the LEAP API to load a model for +inference on a Ryzen AI NPU using the ryzenai-npu recipe, +and then use it to generate the response to a prompt. + +Note that this example will only run if the Ryzen AI NPU Private recipe is installed. +See genai/docs/ryzenai_npu.md for instructions. + +You can try the same model on CPU by changing the recipe to "hf-cpu". +""" + +from lemonade import leap + +model, tokenizer = leap.from_pretrained( + "meta-llama/Llama-2-7b-chat-hf", recipe="ryzenai-npu" +) + +input_ids = tokenizer("This is my prompt", return_tensors="pt").input_ids +response = model.generate(input_ids, max_new_tokens=30) + +print(tokenizer.decode(response[0])) diff --git a/examples/llm/leap_streaming.py b/examples/llm/leap_streaming.py new file mode 100644 index 00000000..e2951dbd --- /dev/null +++ b/examples/llm/leap_streaming.py @@ -0,0 +1,38 @@ +""" +This example demonstrates how to use the LEAP API to load a model for +inference on CPU using the hf-cpu recipe, and then use a thread to +generate a streaming response to a prompt. + +Note: this approach only works with recipes that support TextIteratorStreamer, +i.e., huggingface-based recipes such as hf-cpu and ryzenai-npu. 
+""" + +from thread import Thread +from transformers import TextIteratorStreamer +from lemonade import leap + +# Replace the recipe with "ryzenai-npu" to run on the RyzenAI NPU +model, tokenizer = leap.from_pretrained( + "meta-llama/Llama-2-7b-chat-hf", recipe="hf-cpu" +) + +input_ids = tokenizer("This is my prompt", return_tensors="pt").input_ids + +streamer = TextIteratorStreamer( + tokenizer, + skip_prompt=True, +) +generation_kwargs = { + "input_ids": input_ids, + "streamer": streamer, + "max_new_tokens": 30, +} + +thread = Thread(target=model.generate, kwargs=generation_kwargs) +thread.start() + +# Generate the response using streaming +for new_text in streamer: + print(new_text) + +thread.join() diff --git a/examples/llm/turnkey_llm.ipynb b/examples/llm/turnkey_llm.ipynb index 07d7e329..ffe6d90c 100644 --- a/examples/llm/turnkey_llm.ipynb +++ b/examples/llm/turnkey_llm.ipynb @@ -85,7 +85,7 @@ "outputs": [], "source": [ "# Import the turnkey APIs\n", - "from turnkeyml.llm import leap\n", + "from lemonade import leap\n", "\n", "# Load the model on to RyzenAI NPU\n", "# NOTE: this takes a couple of minutes, but after you've done it once\n", @@ -133,7 +133,7 @@ "outputs": [], "source": [ "# Import the turnkey APIs\n", - "from turnkeyml.llm import leap\n", + "from lemonade import leap\n", "\n", "# Load the model on iGPU\n", "igpu_model, igpu_tokenizer = leap.from_pretrained(\n", diff --git a/examples/readme.md b/examples/readme.md index 09dd7fbb..b9dcb05a 100644 --- a/examples/readme.md +++ b/examples/readme.md @@ -3,3 +3,4 @@ This directory contains examples to help you learn how to use the tools. The examples are split up into two sub-directories: 1. `examples/cli`: a tutorial series for the `turnkey` CLI. This is the recommended starting point. 1. `examples/api`: scripts that demonstrate how to use the `turnkey.evaluate_files()` API. +1. `examples/llm`: scripts that demonstrate the `lemonade` CLI for LLMs. 
diff --git a/setup.py b/setup.py index 7654314e..199bc3bf 100644 --- a/setup.py +++ b/setup.py @@ -3,11 +3,11 @@ with open("src/turnkeyml/version.py", encoding="utf-8") as fp: version = fp.read().split('"')[1] + setup( name="turnkeyml", version=version, description="TurnkeyML Tools and Models", - author="Jeremy Fowers, Daniel Holanda, Ramakrishnan Sivakumar, Victoria Godsoe", author_email="turnkeyml@amd.com", package_dir={"": "src", "turnkeyml_models": "models"}, packages=[ @@ -17,10 +17,10 @@ "turnkeyml.sequence", "turnkeyml.cli", "turnkeyml.common", - "turnkeyml.llm", - "turnkeyml.llm.tools", - "turnkeyml.llm.tools.ort_genai", - "turnkeyml.llm.tools.ryzenai_npu", + "lemonade", + "lemonade.tools", + "lemonade.tools.ort_genai", + "lemonade.tools.ryzenai_npu", "turnkeyml_models", "turnkeyml_models.graph_convolutions", "turnkeyml_models.selftest", @@ -46,6 +46,7 @@ "psutil", "wmi", "pytz", + "tqdm", # Conditional dependencies for ONNXRuntime backends "onnxruntime >=1.10.1;platform_system=='Linux' and extra != 'llm-oga-cuda'", "onnxruntime-directml >=1.19.0;platform_system=='Windows' and extra != 'llm-oga-cuda'", @@ -53,70 +54,53 @@ ], extras_require={ "llm": [ - "tqdm", "torch>=2.0.0", "transformers", "accelerate", "py-cpuinfo", "sentencepiece", "datasets", + # Install human-eval from a forked repo with Windows support until the + # PR (https://github.com/openai/human-eval/pull/53) is merged + "human-eval @ git+https://github.com/ramkrishna2910/human-eval.git", "fastapi", "uvicorn[standard]", ], - "llm-oga-dml": [ + "llm-oga-igpu": [ "onnxruntime-genai-directml==0.4.0", - "tqdm", "torch>=2.0.0,<2.4", "transformers<4.45.0", - "accelerate", - "py-cpuinfo", - "sentencepiece", - "datasets", - "fastapi", - "uvicorn[standard]", + "turnkeyml[llm]", ], "llm-oga-cuda": [ "onnxruntime-genai-cuda==0.4.0", - "tqdm", "torch>=2.0.0,<2.4", "transformers<4.45.0", - "accelerate", - "py-cpuinfo", - "sentencepiece", - "datasets", - "fastapi", - "uvicorn[standard]", + "turnkeyml[llm]", ], "llm-oga-npu": [ - "transformers", - "torch", "onnx==1.16.0", "onnxruntime==1.18.0", "numpy==1.26.4", - "tqdm", - "accelerate", - "py-cpuinfo", - "sentencepiece", - "datasets", - "fastapi", - "uvicorn[standard]", + "turnkeyml[llm]", ], "llm-oga-hybrid": [ - "transformers", - "torch", "onnx==1.16.1", "numpy==1.26.4", - "datasets", - "fastapi", - "uvicorn[standard]", + "turnkeyml[llm]", + ], + "cuda": [ + "torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp310-cp310-win_amd64.whl", + "torchvision @ https://download.pytorch.org/whl/cu118/torchvision-0.18.1%2Bcu118-cp310-cp310-win_amd64.whl", + "torchaudio @ https://download.pytorch.org/whl/cu118/torchaudio-2.3.1%2Bcu118-cp310-cp310-win_amd64.whl", ], }, classifiers=[], entry_points={ "console_scripts": [ "turnkey=turnkeyml:turnkeycli", - "turnkey-llm=turnkeyml.llm:lemonadecli", - "lemonade=turnkeyml.llm:lemonadecli", + "turnkey-llm=lemonade:lemonadecli", + "lemonade=lemonade:lemonadecli", ] }, python_requires=">=3.8, <3.12", diff --git a/src/turnkeyml/llm/__init__.py b/src/lemonade/__init__.py similarity index 100% rename from src/turnkeyml/llm/__init__.py rename to src/lemonade/__init__.py diff --git a/src/turnkeyml/llm/cache.py b/src/lemonade/cache.py similarity index 100% rename from src/turnkeyml/llm/cache.py rename to src/lemonade/cache.py diff --git a/src/turnkeyml/llm/cli.py b/src/lemonade/cli.py similarity index 84% rename from src/turnkeyml/llm/cli.py rename to src/lemonade/cli.py index e8b695fb..5417673e 100644 --- a/src/turnkeyml/llm/cli.py +++ 
b/src/lemonade/cli.py @@ -7,20 +7,21 @@ from turnkeyml.tools.report import Report from turnkeyml.state import State -from turnkeyml.llm.tools.huggingface_load import ( +from lemonade.tools.huggingface_load import ( HuggingfaceLoad, AdaptHuggingface, ) -from turnkeyml.llm.tools.huggingface_bench import HuggingfaceBench -from turnkeyml.llm.tools.ort_genai.oga_bench import OgaBench +from lemonade.tools.huggingface_bench import HuggingfaceBench +from lemonade.tools.ort_genai.oga_bench import OgaBench -from turnkeyml.llm.tools.llamacpp import LoadLlamaCpp +from lemonade.tools.llamacpp import LoadLlamaCpp -import turnkeyml.llm.cache as cache -from turnkeyml.llm.tools.mmlu import AccuracyMMLU -from turnkeyml.llm.tools.perplexity import AccuracyPerplexity -from turnkeyml.llm.tools.chat import LLMPrompt, Serve +import lemonade.cache as cache +from lemonade.tools.mmlu import AccuracyMMLU +from lemonade.tools.humaneval import AccuracyHumaneval +from lemonade.tools.perplexity import AccuracyPerplexity +from lemonade.tools.chat import LLMPrompt, Serve def main(): @@ -30,6 +31,7 @@ def main(): HuggingfaceLoad, LoadLlamaCpp, AccuracyMMLU, + AccuracyHumaneval, AccuracyPerplexity, LLMPrompt, AdaptHuggingface, @@ -45,7 +47,7 @@ def main(): # Import onnxruntime-genai recipes try: - from turnkeyml.llm.tools.ort_genai.oga import OgaLoad + from lemonade.tools.ort_genai.oga import OgaLoad tools = tools + [OgaLoad] @@ -54,7 +56,7 @@ def main(): # Import RyzenAI NPU modules only if RyzenAI NPU is installed try: - from turnkeyml.llm.tools.ryzenai_npu.ryzenai_npu import RyzenAINPULoad + from lemonade.tools.ryzenai_npu.ryzenai_npu import RyzenAINPULoad tools = tools + [RyzenAINPULoad] except ModuleNotFoundError: diff --git a/src/turnkeyml/llm/leap.py b/src/lemonade/leap.py similarity index 91% rename from src/turnkeyml/llm/leap.py rename to src/lemonade/leap.py index 75475dc1..da9342f9 100644 --- a/src/turnkeyml/llm/leap.py +++ b/src/lemonade/leap.py @@ -3,8 +3,8 @@ from typing import Tuple, Dict from turnkeyml.state import State import turnkeyml.common.printing as printing -import turnkeyml.llm.cache as cache -from turnkeyml.llm.tools.adapter import ModelAdapter, TokenizerAdapter +import lemonade.cache as cache +from lemonade.tools.adapter import ModelAdapter, TokenizerAdapter class NotSupported(Exception): @@ -78,7 +78,7 @@ def from_pretrained( # Huggingface supports all checkpoints, so there is nothing to check for import torch - from turnkeyml.llm.tools.huggingface_load import HuggingfaceLoad + from lemonade.tools.huggingface_load import HuggingfaceLoad state = _make_state(recipe, checkpoint) @@ -94,7 +94,7 @@ def from_pretrained( # Huggingface Transformers recipe for discrete GPU (Nvidia, Instinct, Radeon) import torch - from turnkeyml.llm.tools.huggingface_load import HuggingfaceLoad + from lemonade.tools.huggingface_load import HuggingfaceLoad state = _make_state(recipe, checkpoint) @@ -111,7 +111,7 @@ def from_pretrained( return state.model, tokenizer elif recipe == "oga-dml-igpu": - import turnkeyml.llm.tools.ort_genai.oga as oga + import lemonade.tools.ort_genai.oga as oga state = _make_state(recipe, checkpoint) @@ -134,7 +134,7 @@ def from_pretrained( ): _raise_not_supported(recipe, checkpoint) - import turnkeyml.llm.tools.ryzenai_npu.ryzenai_npu as ryzenai_npu + import lemonade.tools.ryzenai_npu.ryzenai_npu as ryzenai_npu state = _make_state(recipe, checkpoint) diff --git a/src/turnkeyml/llm/tools/__init__.py b/src/lemonade/tools/__init__.py similarity index 100% rename from 
src/turnkeyml/llm/tools/__init__.py rename to src/lemonade/tools/__init__.py diff --git a/src/turnkeyml/llm/tools/adapter.py b/src/lemonade/tools/adapter.py similarity index 100% rename from src/turnkeyml/llm/tools/adapter.py rename to src/lemonade/tools/adapter.py diff --git a/src/turnkeyml/llm/tools/chat.py b/src/lemonade/tools/chat.py similarity index 99% rename from src/turnkeyml/llm/tools/chat.py rename to src/lemonade/tools/chat.py index 8daec102..44c05031 100644 --- a/src/turnkeyml/llm/tools/chat.py +++ b/src/lemonade/tools/chat.py @@ -11,7 +11,7 @@ import uvicorn from turnkeyml.state import State from turnkeyml.tools import Tool -from turnkeyml.llm.tools.adapter import ModelAdapter, TokenizerAdapter +from lemonade.tools.adapter import ModelAdapter, TokenizerAdapter DEFAULT_GENERATE_PARAMS = { "do_sample": True, @@ -301,7 +301,7 @@ async def stream_response(websocket: WebSocket): # Set up the generation parameters if isinstance(model, ModelAdapter) and model.type == "ort-genai": # Onnxruntime-genai models - import turnkeyml.llm.tools.ort_genai.oga as oga + import lemonade.tools.ort_genai.oga as oga streamer = oga.OrtGenaiStreamer(tokenizer) diff --git a/src/turnkeyml/llm/tools/huggingface_bench.py b/src/lemonade/tools/huggingface_bench.py similarity index 98% rename from src/turnkeyml/llm/tools/huggingface_bench.py rename to src/lemonade/tools/huggingface_bench.py index 4cee8542..1add8b3a 100644 --- a/src/turnkeyml/llm/tools/huggingface_bench.py +++ b/src/lemonade/tools/huggingface_bench.py @@ -8,8 +8,8 @@ import tqdm from turnkeyml.state import State from turnkeyml.tools import Tool -from turnkeyml.llm.cache import Keys -import turnkeyml.llm.tools.ort_genai.oga_bench as general +from lemonade.cache import Keys +import lemonade.tools.ort_genai.oga_bench as general def benchmark_huggingface_llm( @@ -85,6 +85,9 @@ def benchmark_huggingface_llm( if token_len >= target_output_tokens: per_iteration_result.append((latency, token_len)) + if not per_iteration_result: + raise general.not_enough_tokens(target_output_tokens) + return per_iteration_result diff --git a/src/turnkeyml/llm/tools/huggingface_load.py b/src/lemonade/tools/huggingface_load.py similarity index 98% rename from src/turnkeyml/llm/tools/huggingface_load.py rename to src/lemonade/tools/huggingface_load.py index ba46f214..8cbec6ea 100644 --- a/src/turnkeyml/llm/tools/huggingface_load.py +++ b/src/lemonade/tools/huggingface_load.py @@ -6,8 +6,8 @@ from turnkeyml.state import State import turnkeyml.common.status as status from turnkeyml.tools import Tool, FirstTool -from turnkeyml.llm.tools.adapter import ModelAdapter -from turnkeyml.llm.cache import Keys +from lemonade.tools.adapter import ModelAdapter +from lemonade.cache import Keys # Command line interfaces for tools will use string inputs for data # types, however the internal tool logic will need to know the actual diff --git a/src/lemonade/tools/humaneval.py b/src/lemonade/tools/humaneval.py new file mode 100644 index 00000000..c9433a39 --- /dev/null +++ b/src/lemonade/tools/humaneval.py @@ -0,0 +1,254 @@ +import argparse +import os +import csv +from typing import Dict, Optional, Any +import requests +from human_eval.data import write_jsonl, read_problems +from human_eval.evaluation import evaluate_functional_correctness + +from turnkeyml.state import State +from turnkeyml.tools import Tool +import turnkeyml.common.printing as printing +import turnkeyml.common.build as build + + +class AccuracyHumaneval(Tool): + """ + HumanEval accuracy measurement tool. 
+ + This tool evaluates language models on the HumanEval dataset, which consists of + Python programming problems. It measures the model's ability to: + 1. Generate functionally correct code completions + 2. Pass unit tests for each programming problem + + Metrics: + - pass@1: Percentage of problems solved with 1 generation attempt + - pass@10: Percentage of problems solved within 10 generation attempts + - pass@100: Percentage of problems solved within 100 generation attempts + + See docs/humaneval_accuracy.md for more details + """ + + unique_name = "accuracy-humaneval" + DATASET = "https://github.com/openai/human-eval/blob/master/data/HumanEval.jsonl.gz?raw=true" + TOTAL_PROBLEMS = 164 # Total number of problems in the HumanEval dataset + + def __init__(self): + super().__init__(monitor_message="Measuring accuracy with HumanEval") + self.status_stats = [] + # Enable code evaluation for HumanEval + os.environ["HF_ALLOW_CODE_EVAL"] = "1" + + @staticmethod + def parser(add_help: bool = True) -> argparse.ArgumentParser: + parser = __class__.helpful_parser( + short_description="Run accuracy benchmark using HumanEval dataset", + add_help=add_help, + ) + parser.add_argument( + "--k-samples", + type=int, + default=1, + help="Number of completions to generate per prompt for pass@k calculation" + " (default: %(default)s)", + ) + parser.add_argument( + "--first-n-samples", + type=int, + default=AccuracyHumaneval.TOTAL_PROBLEMS, + help=f"Evaluate only the first N problems from the dataset (default: " + f"%(default)s, evaluates all {AccuracyHumaneval.TOTAL_PROBLEMS} problems)", + ) + parser.add_argument( + "--timeout", + type=float, + default=30.0, + help="Timeout in seconds for each test case (default: %(default)s)", + ) + parser.add_argument( + "--data-dir", + type=str, + default=None, + help="Custom directory for dataset storage (default: %(default)s, " + "uses /data/humaneval)", + ) + return parser + + def run( + self, + state: State, + data_dir: Optional[str] = None, + k_samples: int = 1, + first_n_samples: Optional[int] = TOTAL_PROBLEMS, + timeout: float = 30.0, + ) -> State: + """ + Run HumanEval evaluation on the model. 
+ + Args: + state: Current state containing model and tokenizer + data_dir: Optional custom directory for dataset storage + k_samples: Number of completions to generate per prompt for pass@k calculation + first_n_samples: Number of first N problems to evaluate + timeout: Timeout in seconds for each test case + + Returns: + Updated state with evaluation results + """ + # Validate required state components + if not hasattr(state, "model") or not hasattr(state, "tokenizer"): + raise ValueError("State must contain both 'model' and 'tokenizer'") + + # Setup directories + data_dir_to_use = data_dir or os.path.join(state.cache_dir, "data", "humaneval") + data_path = os.path.join(data_dir_to_use, "HumanEval.jsonl.gz") + model_results_dir = os.path.join( + build.output_dir(state.cache_dir, state.build_name), "humaneval" + ) + os.makedirs(model_results_dir, exist_ok=True) + + # Download dataset if needed + self._download_dataset(data_path) + + # Run evaluation + results = self._evaluate_model( + state.model, + state.tokenizer, + data_path, + k_samples, + timeout, + model_results_dir, + first_n_samples, + ) + + # Save metrics + self._save_metrics(state, results) + + return state + + def _download_dataset(self, output_path: str) -> None: + """Download HumanEval dataset if not already present.""" + if os.path.exists(output_path): + printing.log_info(f"Dataset already exists at: {output_path}") + return + + os.makedirs(os.path.dirname(output_path), exist_ok=True) + response = requests.get(self.DATASET, stream=True) + + if response.status_code == 200: + with open(output_path, "wb") as file: + for chunk in response.iter_content(chunk_size=8192): + file.write(chunk) + printing.log_info(f"Dataset downloaded successfully to: {output_path}") + else: + raise RuntimeError( + f"Failed to download dataset. Status code: {response.status_code}" + ) + + def _evaluate_model( + self, + model: Any, + tokenizer: Any, + data_path: str, + k_samples: int, + timeout: float, + results_dir: str, + first_n_samples: Optional[int] = TOTAL_PROBLEMS, + ) -> Dict[str, float]: + """ + Evaluate model on HumanEval dataset. 
+ + Args: + model: The language model to evaluate + tokenizer: The tokenizer for the model + data_path: Path to the HumanEval dataset + k_samples: Number of completions per prompt for pass@k calculation + timeout: Test case timeout in seconds + results_dir: Directory to save results + first_n_samples: Number of first N problems to evaluate + + Returns: + Dictionary containing evaluation metrics + """ + dataset = read_problems(data_path) + + # Limit to first N problems + dataset_keys = list(dataset.keys())[:first_n_samples] + ignore_incomplete = True + + samples = [] + + # Update Tool progress monitor + self.set_percent_progress(0.0) + questions_completed = 0 + number_of_questions = first_n_samples * k_samples + + # Save completions and expected answers + csv_path = os.path.join(results_dir, "evaluation_results.csv") + with open( + csv_path, mode="w", newline="", encoding="utf-8", errors="replace" + ) as file: + writer = csv.writer(file) + writer.writerow(["Prompt", "Completion", "Expected Answer"]) + + for task_id in dataset_keys: + try: + for _ in range(k_samples): + prompt = dataset[task_id]["prompt"] + expected = dataset[task_id]["canonical_solution"] + + # Generate completion + input_ids = tokenizer(prompt, return_tensors="pt").input_ids + completion = model.generate( + input_ids, + max_new_tokens=512, + do_sample=False, + ) + completion_text = tokenizer.decode( + completion[0], skip_special_tokens=True + ) + + # Save results + samples.append( + {"task_id": task_id, "completion": completion_text} + ) + writer.writerow([prompt, completion_text, expected]) + + # Update progress monitor after completing all samples for a question + questions_completed = questions_completed + 1 + percent_completed = ( + questions_completed / number_of_questions * 100 + ) + self.set_percent_progress(percent_completed) + + # pylint: disable=W0718 + except Exception as e: + printing.log_info(f"Error processing task {task_id}: {str(e)}") + continue + + # Save predictions and evaluate + pred_path = os.path.join(results_dir, "humaneval_predictions.jsonl") + write_jsonl(pred_path, samples) + printing.log_info(f"Results saved in: {results_dir}") + + # Run functional correctness evaluation + k_values = [k_samples] + results = evaluate_functional_correctness( + pred_path, + k_values, + n_workers=1, + timeout=timeout, + problem_file=data_path, + ignore_incomplete=ignore_incomplete, + ) + return results + + def _save_metrics(self, state: State, results: Dict[str, float]) -> None: + """Save evaluation metrics to state.""" + for metric, value in results.items(): + metric_name = f"humaneval_{metric}" + state.save_stat( + metric_name, float(value) * 100 if value is not None else None + ) + state.save_stat(f"{metric_name}_units", "%") + self.status_stats.append(metric_name) diff --git a/src/turnkeyml/llm/tools/llamacpp.py b/src/lemonade/tools/llamacpp.py similarity index 100% rename from src/turnkeyml/llm/tools/llamacpp.py rename to src/lemonade/tools/llamacpp.py diff --git a/src/turnkeyml/llm/tools/mmlu.py b/src/lemonade/tools/mmlu.py similarity index 77% rename from src/turnkeyml/llm/tools/mmlu.py rename to src/lemonade/tools/mmlu.py index 26946d11..33abfcb4 100644 --- a/src/turnkeyml/llm/tools/mmlu.py +++ b/src/lemonade/tools/mmlu.py @@ -18,6 +18,16 @@ dataset_url = "https://people.eecs.berkeley.edu/~hendrycks/data.tar" +def min_handle_none(*args: int): + """ + Returns the minimum of the arguments. If one of the arguments is none, + it doesn't count towards the min. 
+ """ + + filter_out_none = (value for value in args if value is not None) + return min(filter_out_none) + + class AccuracyMMLU(Tool): """ See docs/mmlu_accuracy.md for more details @@ -95,15 +105,35 @@ def run( if tests is not None: unsupported_tests = set(tests) - set(tests_to_run) if unsupported_tests: - printing.log_warning( - "Warning: Unsupported tests specified and will be ignored:" - + f"{', '.join(unsupported_tests)}" + raise ValueError( + f"Invalid test names provided: {', '.join(unsupported_tests)}. " + f"Valid tests are: {', '.join(tests_to_run)}" ) tests_to_run = [test for test in tests if test in tests_to_run] tokenizer = state.tokenizer model = state.model + # Update Tool progress monitor + self.set_percent_progress(0.0) + number_of_questions = float( + sum( + [ + min_handle_none( + len( + _safe_read_csv( + os.path.join(dataset_dir, "test", f"{subject}_test.csv") + ) + ), + max_evals, + ) + for subject in tests_to_run + ] + ) + ) + + questions_completed = 0 + summary_data = [] for subject in tqdm.tqdm(tests_to_run): dev_df = _safe_read_csv( @@ -113,9 +143,40 @@ def run( os.path.join(dataset_dir, "test", f"{subject}_test.csv") ) - detailed_results, acc = _eval_model( - ntrain, max_evals, subject, model, tokenizer, dev_df, test_df - ) + # Evaluate the model on the test data for a given subject + detailed_results = [] + + for i in range(min_handle_none(test_df.shape[0], max_evals)): + prompt = _gen_prompt(dev_df, subject, ntrain) + _format_example( + test_df, i, include_answer=False + ) + input_ids = tokenizer(prompt, return_tensors="pt").input_ids + + response_text = _generate_response(tokenizer, model, input_ids) + try: + pred_label = response_text[-1].upper() + # Handle models generating empty outputs + except IndexError: + pred_label = "-" + + label = test_df.iloc[i, -1].strip().upper() + detailed_results.append( + { + "Question": test_df.iloc[i, 0], + "Prompt": prompt, + "Correct Answer": label, + "Generated Answer": pred_label, + "Correct": pred_label == label, + } + ) + + # Update progress monitor + questions_completed = questions_completed + 1 + percent_completed = questions_completed / number_of_questions * 100 + self.set_percent_progress(percent_completed) + + acc = np.mean([res["Correct"] for res in detailed_results]) + subject_results_df = pd.DataFrame(detailed_results) subject_csv_path = os.path.join( model_results_dir, f"{subject}_detailed_results.csv" @@ -254,37 +315,3 @@ def download_and_extract_dataset(data_cache_dir: str, dataset_url: str): # MMLU data is stored in data.tar/data return os.path.join(data_cache_dir, "data") - - -def _eval_model(ntrain, max_evals, subject, model, tokenizer, dev_df, test_df): - """Evaluates the model on the test data for a given subject.""" - detailed_results = [] - - for i in range(test_df.shape[0]): - prompt = _gen_prompt(dev_df, subject, ntrain) + _format_example( - test_df, i, include_answer=False - ) - input_ids = tokenizer(prompt, return_tensors="pt").input_ids - - response_text = _generate_response(tokenizer, model, input_ids) - try: - pred_label = response_text[-1].upper() - # Handle models generating empty outputs - except IndexError: - pred_label = "-" - - label = test_df.iloc[i, -1].strip().upper() - detailed_results.append( - { - "Question": test_df.iloc[i, 0], - "Prompt": prompt, - "Correct Answer": label, - "Generated Answer": pred_label, - "Correct": pred_label == label, - } - ) - if max_evals is not None and i >= max_evals - 1: - break - - acc = np.mean([res["Correct"] for res in detailed_results]) - return 
detailed_results, acc diff --git a/src/turnkeyml/llm/tools/ort_genai/__init__.py b/src/lemonade/tools/ort_genai/__init__.py similarity index 100% rename from src/turnkeyml/llm/tools/ort_genai/__init__.py rename to src/lemonade/tools/ort_genai/__init__.py diff --git a/src/turnkeyml/llm/tools/ort_genai/oga.py b/src/lemonade/tools/ort_genai/oga.py similarity index 99% rename from src/turnkeyml/llm/tools/ort_genai/oga.py rename to src/lemonade/tools/ort_genai/oga.py index cf176c57..890e181e 100644 --- a/src/turnkeyml/llm/tools/ort_genai/oga.py +++ b/src/lemonade/tools/ort_genai/oga.py @@ -22,12 +22,12 @@ from turnkeyml.tools import FirstTool import turnkeyml.common.status as status import turnkeyml.common.printing as printing -from turnkeyml.llm.tools.adapter import ( +from lemonade.tools.adapter import ( ModelAdapter, TokenizerAdapter, PassthroughTokenizerResult, ) -from turnkeyml.llm.cache import Keys +from lemonade.cache import Keys # ONNX Runtime GenAI models will be cached in this subfolder of the lemonade cache folder oga_models_path = "oga_models" @@ -216,7 +216,7 @@ class OgaLoad(FirstTool): Input: path to a checkpoint. Supported choices for cpu and igpu from HF model repository: LLM models on Huggingface supported by model_builder. See documentation - (https://github.com/aigdat/genai/blob/main/docs/ort_genai_igpu.md) for supported models. + (https://github.com/onnx/turnkeyml/blob/main/docs/ort_genai_igpu.md) for supported models. Supported choices for npu from HF model repository: Models on Hugging Face that follow the "amd/**-onnx-ryzen-strix" pattern Local models for cpu, igpu, or npu: diff --git a/src/turnkeyml/llm/tools/ort_genai/oga_bench.py b/src/lemonade/tools/ort_genai/oga_bench.py similarity index 80% rename from src/turnkeyml/llm/tools/ort_genai/oga_bench.py rename to src/lemonade/tools/ort_genai/oga_bench.py index fed7f8c2..0ae29756 100644 --- a/src/turnkeyml/llm/tools/ort_genai/oga_bench.py +++ b/src/lemonade/tools/ort_genai/oga_bench.py @@ -4,8 +4,8 @@ import tqdm from turnkeyml.state import State from turnkeyml.tools import Tool -from turnkeyml.llm.cache import Keys -from turnkeyml.llm.tools.adapter import ModelAdapter, TokenizerAdapter +from lemonade.cache import Keys +from lemonade.tools.adapter import ModelAdapter, TokenizerAdapter default_iterations = 10 default_warmup_runs = 5 @@ -14,6 +14,29 @@ default_output_tokens = 5 +def not_enough_tokens(output_tokens: int): + """ + Raise an exception that explains why a benchmark did not produce any results + """ + + raise ValueError( + "Your model was benchmarked, however none of the benchmarking " + "iterations produced the requested amount of output tokens " + f"(currently {output_tokens}), so " + "the results have been discarded. You have the following options " + "to solve this: \n" + "1. Use the -p option to change the prompt to something that will " + "produce more output tokens. For example, 'The extremely long " + "story of my life, told in excruciating details is:' " + "is an example of a prompt that will result in a lot of output. \n" + "2. Set a lower value for --output-tokens to make it more likely " + "that the model will produce enough. \n" + "3. Set more verbose hyperparameters. \n" + "4. Run more benchmarking iterations, to improve the chance of " + "getting at least one with enough output tokens. \n" + ) + + class OgaBench(Tool): """ Benchmark any model that adheres to the ModelAdapter interface. 
@@ -144,6 +167,9 @@ def run( per_iteration_time_to_first_token.append(model.time_to_first_token) per_iteration_tokens_per_second.append(model.tokens_per_second) + if not per_iteration_time_to_first_token or not per_iteration_tokens_per_second: + raise not_enough_tokens(output_tokens) + mean_time_to_first_token = statistics.mean(per_iteration_time_to_first_token) prefill_tokens_per_second = input_ids_len / mean_time_to_first_token token_generation_tokens_per_second = statistics.mean( diff --git a/src/turnkeyml/llm/tools/perplexity.py b/src/lemonade/tools/perplexity.py similarity index 100% rename from src/turnkeyml/llm/tools/perplexity.py rename to src/lemonade/tools/perplexity.py diff --git a/src/turnkeyml/llm/tools/ryzenai_npu/__init__.py b/src/lemonade/tools/ryzenai_npu/__init__.py similarity index 100% rename from src/turnkeyml/llm/tools/ryzenai_npu/__init__.py rename to src/lemonade/tools/ryzenai_npu/__init__.py diff --git a/src/turnkeyml/llm/tools/ryzenai_npu/ryzenai_npu.py b/src/lemonade/tools/ryzenai_npu/ryzenai_npu.py similarity index 98% rename from src/turnkeyml/llm/tools/ryzenai_npu/ryzenai_npu.py rename to src/lemonade/tools/ryzenai_npu/ryzenai_npu.py index 491b5687..1d94d44a 100644 --- a/src/turnkeyml/llm/tools/ryzenai_npu/ryzenai_npu.py +++ b/src/lemonade/tools/ryzenai_npu/ryzenai_npu.py @@ -12,8 +12,8 @@ from modeling_phi3 import Phi3ForCausalLM from turnkeyml.state import State from turnkeyml.tools import FirstTool -from turnkeyml.llm.tools.adapter import ModelAdapter -from turnkeyml.llm.cache import Keys +from lemonade.tools.adapter import ModelAdapter +from lemonade.cache import Keys npu_root_dir = os.path.dirname(__file__) quantized_models_path = os.path.join(npu_root_dir, "quantized_models") diff --git a/src/turnkeyml/version.py b/src/turnkeyml/version.py index 1e92209b..ba7be38e 100644 --- a/src/turnkeyml/version.py +++ b/src/turnkeyml/version.py @@ -1 +1 @@ -__version__ = "4.0.11" +__version__ = "5.0.0" diff --git a/test/llm_api.py b/test/llm_api.py index fe7cdbef..8f7c6397 100644 --- a/test/llm_api.py +++ b/test/llm_api.py @@ -5,11 +5,12 @@ from turnkeyml.state import State import turnkeyml.common.filesystem as fs import turnkeyml.common.test_helpers as common -from turnkeyml.llm.tools.huggingface_load import HuggingfaceLoad -from turnkeyml.llm.tools.huggingface_bench import HuggingfaceBench -from turnkeyml.llm.tools.mmlu import AccuracyMMLU -from turnkeyml.llm.tools.chat import LLMPrompt -from turnkeyml.llm.cache import Keys +from lemonade.tools.huggingface_load import HuggingfaceLoad +from lemonade.tools.huggingface_bench import HuggingfaceBench +from lemonade.tools.mmlu import AccuracyMMLU +from lemonade.tools.humaneval import AccuracyHumaneval +from lemonade.tools.chat import LLMPrompt +from lemonade.cache import Keys ci_mode = os.getenv("LEMONADE_CI_MODE", False) @@ -64,6 +65,31 @@ def test_002_accuracy_mmlu(self): stats = fs.Stats(state.cache_dir, state.build_name).stats assert stats[f"mmlu_{subject[0]}_accuracy"] > 0 + def test_003_accuracy_humaneval(self): + """Test HumanEval benchmarking with known model""" + checkpoint = "facebook/opt-125m" + + state = State( + cache_dir=cache_dir, + build_name="test", + ) + + # Enable code evaluation for HumanEval + os.environ["HF_ALLOW_CODE_EVAL"] = "1" + + state = HuggingfaceLoad().run(state, input=checkpoint) + state = AccuracyHumaneval().run( + state, + first_n_samples=1, # Test only one problem for speed + k_samples=1, # Single attempt per problem + timeout=30.0 + ) + + # Verify results + stats = 
fs.Stats(state.cache_dir, state.build_name).stats + assert "humaneval_pass@1" in stats, "HumanEval pass@1 metric not found" + assert isinstance(stats["humaneval_pass@1"], (int, float)), "HumanEval pass@1 metric should be numeric" + def test_001_huggingface_bench(self): # Benchmark OPT checkpoint = "facebook/opt-125m"