From 11926e2575002ebab4532735066fa43bf8eb9284 Mon Sep 17 00:00:00 2001 From: Graham Neubig Date: Thu, 7 Sep 2023 17:28:31 -0400 Subject: [PATCH 1/7] Refactor the dataset generator --- prompt2model/dataset_generator/base.py | 17 +- .../dataset_generator/prompt_based.py | 418 ++--- prompt2model/prompt_parser/instr_parser.py | 2 +- prompt2model/utils/api_tools.py | 15 +- test_helpers/__init__.py | 4 - test_helpers/dataset_tools.py | 31 - tests/dataset_generator_test.py | 948 ++++++++++++ tests/dataset_generator_with_filter_test.py | 1376 ----------------- .../dataset_generator_without_filter_test.py | 1036 ------------- tests/dataset_processor_test.py | 123 +- 10 files changed, 1120 insertions(+), 2850 deletions(-) delete mode 100644 test_helpers/dataset_tools.py create mode 100644 tests/dataset_generator_test.py delete mode 100644 tests/dataset_generator_with_filter_test.py delete mode 100644 tests/dataset_generator_without_filter_test.py diff --git a/prompt2model/dataset_generator/base.py b/prompt2model/dataset_generator/base.py index 0ae9c1703..3f80b47b9 100644 --- a/prompt2model/dataset_generator/base.py +++ b/prompt2model/dataset_generator/base.py @@ -4,7 +4,6 @@ from abc import ABC, abstractmethod from enum import Enum -from pathlib import Path import datasets @@ -26,14 +25,14 @@ class DatasetGenerator(ABC): def generate_dataset_split( self, prompt_spec: PromptSpec, - expected_num_examples: int, + num_examples: int, split: DatasetSplit, ) -> datasets.Dataset: """Generate data for a single named split of data. Args: prompt_spec: A prompt spec (containing a system description). - expected_num_examples: Expected number of examples in split. + num_examples: Expected number of examples in split. split: Name of dataset split to generate. Returns: @@ -43,14 +42,13 @@ def generate_dataset_split( def generate_dataset_dict( self, prompt_spec: PromptSpec, - expected_num_examples: dict[DatasetSplit, int], - output_dir: str | None = None, + num_examples: dict[DatasetSplit, int], ) -> datasets.DatasetDict: """Generate full dataset splits (e.g. train/dev/test) from a prompt. Args: prompt_spec: A prompt specification. - expected_num_examples: Expected number of + num_examples: Expected number of examples per split (train/val/test). 
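
# Illustrative sketch (reviewer note, not applied by this patch): calling the
# refactored `generate_dataset_dict`, which now takes `num_examples` and no
# longer accepts `output_dir`, so saving is the caller's job. The generator,
# prompt spec, and example counts below are illustrative assumptions; a real
# run also needs valid API credentials configured.
from prompt2model.dataset_generator.base import DatasetSplit
from prompt2model.dataset_generator.prompt_based import PromptBasedDatasetGenerator
from prompt2model.prompt_parser import MockPromptSpec, TaskType

generator = PromptBasedDatasetGenerator(filter_duplicated_examples=True)
dataset_dict = generator.generate_dataset_dict(
    prompt_spec=MockPromptSpec(TaskType.TEXT_GENERATION),
    num_examples={
        DatasetSplit.TRAIN: 100,
        DatasetSplit.VAL: 20,
        DatasetSplit.TEST: 20,
    },
)
dataset_dict.save_to_disk("generated_dataset")  # previously done via output_dir
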
Returns: @@ -59,13 +57,8 @@ def generate_dataset_dict( dataset_dict = datasets.DatasetDict( { split.value: self.generate_dataset_split(prompt_spec, num, split=split) - for split, num in expected_num_examples.items() + for split, num in num_examples.items() } ) - if output_dir: - save_dir = Path(output_dir) - save_dir.mkdir(parents=True, exist_ok=True) - dataset_dict.save_to_disk(str(save_dir)) - return dataset_dict diff --git a/prompt2model/dataset_generator/prompt_based.py b/prompt2model/dataset_generator/prompt_based.py index 42d6b568a..793a791a9 100644 --- a/prompt2model/dataset_generator/prompt_based.py +++ b/prompt2model/dataset_generator/prompt_based.py @@ -8,7 +8,6 @@ import random from collections import Counter, defaultdict from dataclasses import dataclass -from pathlib import Path import nest_asyncio import openai @@ -37,6 +36,14 @@ class Example: input_col: str output_col: str + def __eq__(self, other) -> bool: + """Example equality.""" + return self.input_col == other.input_col and self.output_col == other.output_col + + def __lt__(self, other) -> bool: + """Example less than.""" + return self.input_col < other.input_col or self.output_col < other.output_col + class PromptBasedDatasetGenerator(DatasetGenerator): """A abstract class for NLP dataset generation using a prompted API.""" @@ -74,7 +81,6 @@ def __init__( requests_per_minute: The maximum number of requests per minute. filter_duplicated_examples: If True, filters duplicated examples, using the most-frequent output for each input. - cache_root: The root directory for caching generated examples. Raises: ValueError: If the 'max_api_calls' value is not greater than 0. @@ -117,13 +123,13 @@ def __init__( self.responses_per_request = responses_per_request self.requests_per_minute = requests_per_minute self.filter_duplicated_examples = filter_duplicated_examples - self.cache_root = Path(cache_root) def construct_prompt( self, instruction: str, few_shot_example_string: str, generated_examples: list[Example], + context_cutoff: int = 3500, ) -> str: """Generates a prompt string. @@ -133,53 +139,33 @@ def construct_prompt( generation, it defaults to using the few_shot_example_string. The function uses different prompt templates based on the number of selected - examples from the generated dataset. If the total length of the prompt exceeds - 3500 tokens, repeat the prompt generation process to generate a shorter one. + examples from the generated dataset. Args: instruction: The natural language instruction for the prompt. few_shot_example_string: A string representing the few-shot examples parsed from the user's prompt, which quality is higher than the - genrated examples. + generated examples. generated_examples: A list of currently generated examples. + context_cutoff: If the total length of the prompt exceeds this value, + repeat the prompt generation process to generate a shorter one. Returns: The generated prompt string. """ - # The random_example_string is a string, which contains several random - # few-shot examples as demonstrations for the DatasetGenerator. If - # generated_examples is empty, then the random_example_string - # is the few-shot examples parsed from the user's prompt. while True: + # Choose a few examples to add to the prompt if examples exist if len(generated_examples) == 0: low_quality_example_string = "N/A\n" - # Create default low_quality_example_string if generated_examples - # is empty. few_shot_example_string is the high-quality few-shot - # examples parsed from the user's prompt. 
But if user does not - # provideany examples in the input prompt, the few_shot_example_string - # will be "N/A"/""/None. random_selected_generated_example_num = 0 - # random_selected_generated_example_num is the number of selected - # random examples from generated_examples that will be used to - # create the low_quality_example_string. If generated_examples - # is empty, then random_selected_generated_example_num is 0. else: - # If generated_examples is not empty, low_quality_example_string - # is sveral random generated examples from generated_examples. - low_quality_example_string = "" random_selected_generated_example_num = random.randint( 1, min(len(generated_examples), 10) ) - # random_selected_generated_example_num is the number of selected - # random examples from generated_examples that will - # be concatenated to create low_quality_example_string. random_examples = random.sample( generated_examples, random_selected_generated_example_num ) - # If generated_examples is not empty, then choose several - # random examples from generated_examples to construct - # new low_quality_example_string. for example in random_examples: low_quality_example_string += ( f'input="{example.input_col}"\noutput="{example.output_col}"\n' @@ -198,10 +184,7 @@ def construct_prompt( high_quality_example_string=few_shot_example_string, template_type=template_type, ) - # The max content length of gpt-3.5-turbo is 4097, so if the - # generated prompt is longer than 3500, then the prompt - # should be regenerated. - if count_tokens_from_string(prompt) < 3500: + if count_tokens_from_string(prompt) < context_cutoff: return prompt else: orginal_input_string = ( @@ -209,69 +192,17 @@ def construct_prompt( if few_shot_example_string else instruction ) - if count_tokens_from_string(orginal_input_string) > 3500: + if count_tokens_from_string(orginal_input_string) > context_cutoff: logger.warning( - "The original input prompt is too long. Consider writing a shorter prompt." # noqa E501 + "The original input prompt is too long. " + "Consider writing a shorter prompt." ) continue - def construct_input_output_map( + def apply_multi_vote_filtering( self, generated_examples: list[Example], - ) -> dict[str, Counter]: - """Constructs a dictionary mapping inputs to `Counter` objects of outputs. - - Args: - generated_examples: A list of currently generated examples. - - Ideally, each input should have a unique output (one-to-one mapping). - However, language models may occasionally generate different outputs - for identical inputs. For instance, given the input “What is the biggest - city in China?”, it might produce different but correct outputs such as - “Shanghai” and “The biggest city in China is Shanghai”. At other times, - it may produce incorrect variations. For the input “What is the Chemical - symbol of gold?”, the outputs might be “Au”, “Au”, and “AU”, where the - last one is wrong due to capital letters. - - To address this, PromptBasedDatasetGenerator uses a two-step multi-vote - filtering mechanism. This function represents the first step, creating a - dictionary to map inputs to a `Counter` of their outputs. - - The function iterates over all the examples, building a dictionary where - inputs serve as keys and `Counter` objects as values. The `Counter` - tracks the frequency of each output for a specific input. 
- - For example: - input: ["apple", "banana", "apple", "orange", "apple"] - output: ["A", "B", "A", "O", "D"] - - Then input_output_map value is: - { - "apple": Counter({"A": 2, "D": 1}), - "banana": Counter({"B": 1}), - "orange": Counter({"O": 1}) - } - """ - input_output_map: dict[str, Counter] = defaultdict(Counter) - - # Iterate through the examples and construct the mapping. - for example in generated_examples: - input_str = example.input_col - output_str = example.output_col - - # Increment the count of the output for the specific input. - input_output_map[input_str][output_str] += 1 - - # Ensure that the generated_examples list is not empty - # and the map is constructed correctly. - if len(generated_examples) != 0 and input_output_map is None: - raise ValueError("input_output_map is not correctly constructed.") - - return input_output_map - - def apply_multi_vote_to_construct_generated_dataset( - self, input_output_map: dict[str, Counter] - ) -> Dataset: + ) -> list[Example]: """Multi-vote to construct generated_dataset from input_output_map. Args: @@ -290,35 +221,20 @@ def apply_multi_vote_to_construct_generated_dataset( highest frequency, it selects the one that comes first in lexicographical (alphabetical) order. - Example: - Suppose input_output_map is: - { - "apple": Counter({"A": 2, "D": 2}), - "banana": Counter({"B": 2, "C": 1}), - "orange": Counter({"O": 1}) - } - - The function will produce generated_dataset: - { - "input_col": ["apple", "banana", "orange"], - "output_col": ["A", "B", "O"] - } - - Note: When generated_examples is empty, both input_output_map - and generated_dataset will be empty. - Returns: Currently generated dataset with multi-vote filtering applied. """ - # Ensure that multi-vote filtering is enabled. - if not self.filter_duplicated_examples: - raise ValueError("Multi-vote filtering is not enabled.") + filtered_examples = [] + + input_output_map: dict[str, Counter] = defaultdict(Counter) - filtered_inputs = [] - filtered_outputs = [] + for ex in generated_examples: + input_output_map[ex.input_col][ex.output_col] += 1 + + if len(generated_examples) != 0 and input_output_map is None: + raise ValueError("input_output_map is not correctly constructed.") for input_str, output_counter in input_output_map.items(): - # Find the most frequent output count. most_common_count = output_counter.most_common(1)[0][1] # Get all the outputs that have the most common count. @@ -335,73 +251,10 @@ def apply_multi_vote_to_construct_generated_dataset( most_frequent_outputs.sort(key=len) final_output = most_frequent_outputs[0] - filtered_inputs.append(input_str) - filtered_outputs.append(final_output) - - # Note that when `generated_examples` is empty, - # `input_output_map` is None, and `generated_dataset` - # will also be empty. - generated_dataset = Dataset.from_dict( - {"input_col": filtered_inputs, "output_col": filtered_outputs} - ) - return generated_dataset - - def create_all_examples_dataset_and_generated_dataset( - self, generated_examples: list[Example] - ) -> Dataset: - """Converts generated_examples into generated_dataset. - - Args: - generated_examples: A list of currently generated examples. - - Depending on the value of self.filter_duplicated_examples, the function either - constructs a mapping for input-output pairs followed by multi-vote filtering - to create a Dataset or directly converts the generated examples into a Dataset. 
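
# Illustrative sketch (reviewer note, not applied by this patch): a worked
# example of the multi-vote rule in `apply_multi_vote_filtering` above. The
# most frequent output wins for each input; frequency ties go to the shortest
# output via `sort(key=len)`. The question/answer strings below are invented
# purely for illustration.
from collections import Counter, defaultdict

examples = [
    ("What is the chemical symbol of gold?", "Au"),
    ("What is the chemical symbol of gold?", "Au"),
    ("What is the chemical symbol of gold?", "AU"),
    ("What is the biggest city in China?", "Shanghai"),
    ("What is the biggest city in China?", "The biggest city in China is Shanghai"),
]

input_output_map = defaultdict(Counter)
for input_col, output_col in examples:
    input_output_map[input_col][output_col] += 1

filtered = []
for input_str, output_counter in input_output_map.items():
    most_common_count = output_counter.most_common(1)[0][1]
    most_frequent_outputs = [
        output
        for output, count in output_counter.items()
        if count == most_common_count
    ]
    most_frequent_outputs.sort(key=len)
    filtered.append((input_str, most_frequent_outputs[0]))

# "Au" beats "AU" two votes to one; the 1-1 tie for the second input goes to
# the shorter answer "Shanghai".
print(filtered)
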
- - The function also verifies the presence of data in the input-output map - and the generated dataset if there are any generated examples and - self.filter_duplicated_examples is True. - - Lastly, the function stores all generated examples, irrespective of the value - of self.filter_duplicated_examples, into a Dataset on the disk. - - Returns: - A dataset of all the generated examples and the currently generated - dataset. If filter_duplicated_examples is True, multi-vote filtering is - performed. Else, the generated examples are directly converted into - a Dataset. - """ - # Convert all generated examples into a Dataset. - all_generated_examples_dataset = Dataset.from_dict( - { - "input_col": [example.input_col for example in generated_examples], - "output_col": [example.output_col for example in generated_examples], - } - ) - - if self.filter_duplicated_examples: - # When filtering duplicated examples is - # enabled, perform multi-vote filtering. - input_output_map = self.construct_input_output_map(generated_examples) - generated_dataset = self.apply_multi_vote_to_construct_generated_dataset( - input_output_map - ) - - if len(generated_examples) != 0 and input_output_map is None: - raise ValueError("The input-output map is not correctly constructed.") - else: - # When filtering duplicated examples is not enabled, - # use all_generated_examples_dataset directly. - generated_dataset = all_generated_examples_dataset + filtered_examples.append(Example(input_str, final_output)) + return filtered_examples - if len(generated_examples) != 0 and len(generated_dataset) == 0: - raise ValueError("The generated dataset is not correctly constructed.") - - return all_generated_examples_dataset, generated_dataset - - def compute_batch_size( - self, expected_num_examples: int, generated_dataset: Dataset - ) -> int: + def compute_batch_size(self, num_examples: int, generated_dataset_size: int) -> int: """Computes the batch size for API calls in a batch. The batch size is determined based on the remaining number of examples to be @@ -409,10 +262,8 @@ def compute_batch_size( the maximum limit of API calls if it is set. Args: - expected_num_examples: The total number of examples expected to be - generated for the current dataset split. Note that if max_api_calls is not - set, the actual number of generated examples can be slightly higher due - to each API call returning `responses_per_request` examples. + num_examples: The total number of examples expected to be + generated for the current dataset split. generated_dataset: Currently generated dataset. Returns: @@ -427,10 +278,7 @@ def compute_batch_size( batch_size = min( self.max_batch_size, math.ceil( - ( - (expected_num_examples - len(generated_dataset)) - / self.responses_per_request - ) + ((num_examples - generated_dataset_size) / self.responses_per_request) ), max_api_calls, ) @@ -439,15 +287,15 @@ def compute_batch_size( raise ValueError("Batch size must be greater than 0.") return batch_size - def extract_responses( + def extract_and_append_responses( self, completions: list[openai.Completion], generated_examples: list[Example] - ) -> list[Example]: + ) -> None: """Extracts the generated sample and annotation from an API response. Args: completions: A list of Completion objects returned by the API. - Each API call returns a number of completion objects equivalent to - `responses_per_request`. The default `responses_per_request` = 5. + Each API call returns a number of completion objects equivalent to + `responses_per_request`. 
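
# Illustrative sketch (reviewer note, not applied by this patch): the
# batch-size arithmetic in `compute_batch_size` above with concrete numbers.
# The defaults assumed here (max_batch_size=5, responses_per_request=5) are the
# ones exercised by the tests added later in this patch; `sketch_batch_size` is
# a standalone stand-in, not part of the library.
import math


def sketch_batch_size(
    num_examples,
    generated_dataset_size,
    responses_per_request=5,
    max_batch_size=5,
    max_api_calls=None,
    api_call_counter=0,
):
    remaining_calls = (
        max_api_calls - api_call_counter if max_api_calls else math.inf
    )
    return min(
        max_batch_size,
        math.ceil((num_examples - generated_dataset_size) / responses_per_request),
        remaining_calls,
    )


assert sketch_batch_size(125, 110) == 3  # ceil(15 / 5) requests still needed
assert sketch_batch_size(125, 50) == 5  # capped by max_batch_size
assert sketch_batch_size(125, 110, max_api_calls=28, api_call_counter=26) == 2
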
The default `responses_per_request` = 5. generated_examples: Currently generated examples of DatasetGenerator. This function iterates through the provided completions, attempting to @@ -459,44 +307,6 @@ def extract_responses( with `input_col` and `output_col` fields, representing the generated example and label strings respectively. The `example` is then added to generated_examples. - - Note: The function process `batch_size * responses_per_request` - responses at a time. - - Example: - Given a list of two completion objects: [completion_1, completion_2], - where: - completion_1.choices = [ - {"message": {"content": '{"input": "1", "output": "a"}'}}, - {"message": {"content": '{"input": "1", "output": "b"}'}}, - {"message": {"content": '{"input": "1", "output": "a"}'}}, - ] - completion_2.choices = [ - {"message": {"content": '{"input": "1", "output": "c"}'}}, - {"message": {"content": '{"input": "2", "output": "a"}'}}, - {"message": {"content": '{"input": "2", "output": "b"}'}}, - ] - - The function will create 'example' namedtuples: - Example(input_col="1", output_col="a") - Example(input_col="1", output_col="b") - Example(input_col="1", output_col="a") - Example(input_col="1", output_col="c") - Example(input_col="2", output_col="a") - Example(input_col="2", output_col="b") - - It will then append them to generated_examples. - - Returns: - A list of `Example` objects. - Each API call will return `responses_per_request` completion objects. - If the response is a valid JSON object, create a namedtuple called - `example` and append it to generated_examples. `example` consists - of `input_col` and`output_col`, where: - - input_col is the generated example string extracted from the response. - - output_col is the generated label string extracted from the response. - If the response is not a valid JSON object, discard it. - There is responses_per_request * len(completions) responses at a time. """ for completion in completions: try: @@ -532,12 +342,11 @@ def extract_responses( f"Error happened when parsing API completion: {completion}" ) continue - return generated_examples async def generate_responses( self, chat_api: APIAgent, - generated_dataset: Dataset, + generated_dataset_size: int, expected_num_examples: int, prompts: list[str], ) -> list[openai.Completion]: @@ -563,9 +372,12 @@ async def generate_responses( """ # Calculate the dynamic temperature based # on the size of the generated dataset - dynamic_temperature = (self.max_temperature - self.initial_temperature) * len( - generated_dataset - ) / expected_num_examples + self.initial_temperature + dynamic_temperature = ( + (self.max_temperature - self.initial_temperature) + * generated_dataset_size + / expected_num_examples + + self.initial_temperature + ) # Ensure the dynamic temperature is within the range [0, 2.0] clipped_temperature = max(0.0, min(2.0, dynamic_temperature)) @@ -580,12 +392,12 @@ async def generate_responses( def generate_dataset_split( self, prompt_spec: PromptSpec, - expected_num_examples: int, - split: DatasetSplit, + num_examples: int, + split: DatasetSplit = DatasetSplit.TRAIN, ) -> Dataset: - """Generates a dataset split using GPT-3.5. + """Generates a dataset split using API-based LMs. - This method iteratively makes API calls to GPT-3.5 to generate a dataset split. + This method iteratively makes API calls to generate a dataset split. Each API call yields a batch of responses. From these responses, new examples are extracted and added to 'generated_examples'. 
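
# Illustrative sketch (reviewer note, not applied by this patch): the dynamic
# temperature computed in `generate_responses` above rises linearly from
# `initial_temperature` to `max_temperature` as the split fills up, then is
# clipped to [0, 2.0]. The endpoint values 0.5 and 1.4 are arbitrary choices
# for illustration only.
initial_temperature, max_temperature = 0.5, 1.4
expected_num_examples = 100

for generated_dataset_size in (0, 50, 100):
    dynamic_temperature = (
        (max_temperature - initial_temperature)
        * generated_dataset_size
        / expected_num_examples
        + initial_temperature
    )
    clipped_temperature = max(0.0, min(2.0, dynamic_temperature))
    print(generated_dataset_size, round(clipped_temperature, 2))
# Prints 0 -> 0.5, 50 -> 0.95, 100 -> 1.4: early batches stay conservative,
# later batches sample more diversely.
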
The process continues until the desired number of examples is reached, or the maximum limit on API @@ -594,97 +406,63 @@ def generate_dataset_split( Args: prompt_spec: PromptParser to be used for generating examples. - expected_num_examples: The number of examples expected to be - generated. If the maximum limit on API calls is not set, the actual - number of generated examples can be slightly higher due to each - API call returning `responses_per_request` examples. - split: The dataset split (e.g., train, validation, test) for which the - examples are being generated. + num_examples: The number of examples to be generated. Returns: The generated dataset split. """ - # Refresh the relevant data structures for the new split. - self.cache_root.mkdir(parents=True, exist_ok=True) - examples_cache_path = Path( - self.cache_root / f"generated_examples_{split.value}" - ) - dataset_cache_path = Path(self.cache_root / f"generated_dataset_{split.value}") - - if examples_cache_path.exists(): - # If cache exists, load generated examples from disk. - logger.info(f"Loading cache from {str(examples_cache_path)}.") - all_generated_examples_dataset = Dataset.load_from_disk(examples_cache_path) - generated_examples = [ - Example(input_col=ex["input_col"], output_col=ex["output_col"]) - for ex in all_generated_examples_dataset - ] - else: - # Initialize data structures for a new split. - generated_examples = [] + all_generated_examples: list[Example] = [] + generated_examples: list[Example] = [] - pbar = tqdm(total=expected_num_examples, desc="Generating examples") + pbar = tqdm(total=num_examples, desc="Generating examples") chat_api = APIAgent() - while True: - # Each API call will return `responses_per_request` completion - # objects. The upper bound of the length of the generated dataset - # is expected_num_examples + responses_per_request. - try: - # Convert the generated examples into a - # Dataset and update the progress bar. - ( - all_generated_examples_dataset, - generated_dataset, - ) = self.create_all_examples_dataset_and_generated_dataset( - generated_examples - ) - all_generated_examples_dataset.save_to_disk(examples_cache_path) - generated_dataset.save_to_disk(dataset_cache_path) - pbar.update(len(generated_dataset)) - - if self.max_api_calls and self.api_call_counter >= self.max_api_calls: - logger.warning("Maximum number of API calls reached.") - break - elif len(generated_dataset) >= expected_num_examples: - break - else: - # Compute the batch size for the next API call. - batch_size = self.compute_batch_size( - expected_num_examples, generated_dataset - ) - self.api_call_counter += batch_size - - # Generate prompts for the batch call. 
- prompts = [ - self.construct_prompt( - instruction=prompt_spec.instruction, - few_shot_example_string=prompt_spec.examples, - generated_examples=generated_examples - if not self.filter_duplicated_examples - else [ - Example(each["input_col"], each["output_col"]) - for each in generated_dataset - ], - ) - for _ in range(batch_size) - ] + while len(generated_examples) < num_examples: + if self.max_api_calls and self.api_call_counter >= self.max_api_calls: + logger.warning("Maximum number of API calls reached.") + break - loop = asyncio.get_event_loop() - responses = loop.run_until_complete( - self.generate_responses( - chat_api=chat_api, - generated_dataset=generated_dataset, - expected_num_examples=expected_num_examples, - prompts=prompts, - ) - ) + batch_size = self.compute_batch_size(num_examples, len(generated_examples)) + self.api_call_counter += batch_size + + # Generate prompts for the batch call. + prompts = [ + self.construct_prompt( + instruction=prompt_spec.instruction, + few_shot_example_string=prompt_spec.examples, + generated_examples=generated_examples, + ) + for _ in range(batch_size) + ] - # Extract the responses and add new examples to the dataset. - generated_examples = self.extract_responses( - responses, generated_examples + try: + loop = asyncio.get_event_loop() + responses = loop.run_until_complete( + self.generate_responses( + chat_api=chat_api, + generated_dataset_size=len(generated_examples), + expected_num_examples=num_examples, + prompts=prompts, ) + ) except API_ERRORS as e: - # Handle API errors and adjust the API call counter. - self.api_call_counter = handle_api_error(e, self.api_call_counter) - return generated_dataset + handle_api_error(e) + + # Extract the responses and add new examples to the dataset. + self.extract_and_append_responses(responses, all_generated_examples) + generated_examples = ( + self.apply_multi_vote_filtering(all_generated_examples) + if self.filter_duplicated_examples + else all_generated_examples + ) + if len(generated_examples) >= num_examples: + generated_examples = generated_examples[:num_examples] + + pbar.update(len(generated_examples)) + + return Dataset.from_dict( + { + "input_col": [ex.input_col for ex in generated_examples], + "output_col": [ex.output_col for ex in generated_examples], + } + ) diff --git a/prompt2model/prompt_parser/instr_parser.py b/prompt2model/prompt_parser/instr_parser.py index fc9815ed1..be941b652 100644 --- a/prompt2model/prompt_parser/instr_parser.py +++ b/prompt2model/prompt_parser/instr_parser.py @@ -112,7 +112,7 @@ def parse_from_prompt(self, prompt: str) -> None: ) return None except API_ERRORS as e: - self.api_call_counter = handle_api_error(e, self.api_call_counter) + handle_api_error(e) if self.max_api_calls and self.api_call_counter >= self.max_api_calls: logger.error("Maximum number of API calls reached.") raise ValueError("Maximum number of API calls reached.") from e diff --git a/prompt2model/utils/api_tools.py b/prompt2model/utils/api_tools.py index 1b3baaec4..afc51fcaa 100644 --- a/prompt2model/utils/api_tools.py +++ b/prompt2model/utils/api_tools.py @@ -173,18 +173,16 @@ async def _throttled_completion_acreate( return responses -def handle_api_error(e, api_call_counter): +def handle_api_error(e) -> None: """Handle OpenAI errors or related errors that the API may raise. Args: e: The error to handle. This could be an OpenAI error or a related non-fatal error, such as JSONDecodeError or AssertionError. - api_call_counter: The number of API calls made so far. 
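
# Illustrative sketch (reviewer note, not applied by this patch): the retry
# pattern implied by the new `handle_api_error` contract. It no longer returns
# an updated counter; callers track their own count and simply retry, since the
# helper re-raises anything outside API_ERRORS and sleeps on rate limits. The
# import path and the `call_api` placeholder are assumptions for illustration.
from prompt2model.utils.api_tools import API_ERRORS, handle_api_error


def call_api():
    """Hypothetical stand-in for a real completion request."""
    return None


api_call_counter = 0
response = None
while response is None and api_call_counter < 3:
    api_call_counter += 1
    try:
        response = call_api()
    except API_ERRORS as e:
        handle_api_error(e)  # logs, sleeps on rate limits, re-raises fatal errors
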
- - Returns: - The api_call_counter (if no error was raised), else raise the error. """ logging.error(e) + if not isinstance(e, API_ERRORS): + raise e if isinstance( e, (openai.error.APIError, openai.error.Timeout, openai.error.RateLimitError), @@ -192,13 +190,6 @@ def handle_api_error(e, api_call_counter): # For these errors, OpenAI recommends waiting before retrying. time.sleep(1) - if isinstance(e, API_ERRORS): - # For these errors, we can increment a counter and retry the API call. - return api_call_counter - else: - # For all other errors, immediately throw an exception. - raise e - def count_tokens_from_string(string: str, encoding_name: str = "cl100k_base") -> int: """Handle count the tokens in a string with OpenAI's tokenizer. diff --git a/test_helpers/__init__.py b/test_helpers/__init__.py index 44f394058..73fd56172 100644 --- a/test_helpers/__init__.py +++ b/test_helpers/__init__.py @@ -1,8 +1,4 @@ """Import mock classes used in unit tests.""" -from test_helpers.dataset_tools import ( - are_dataset_dicts_identical, - are_datasets_identical, -) from test_helpers.mock_api import ( MockBatchDifferentCompletions, MockCompletion, diff --git a/test_helpers/dataset_tools.py b/test_helpers/dataset_tools.py deleted file mode 100644 index ef0743800..000000000 --- a/test_helpers/dataset_tools.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Tools for testing if two Datasets or DatasetDicts are identical.""" - - -from __future__ import annotations # noqa FI58 - -import datasets - - -def are_datasets_identical( - dataset1: datasets.Dataset, dataset2: datasets.Dataset -) -> bool: - """Check if two datasets are identical in terms of instance values and order.""" - if len(dataset1) != len(dataset2): - return False - - return all( - instance1 == instance2 for instance1, instance2 in zip(dataset1, dataset2) - ) - - -def are_dataset_dicts_identical( - dataset_dict1: datasets.DatasetDict, dataset_dict2: datasets.DatasetDict -) -> bool: - """Check if two DatasetDict objects are identical.""" - if set(dataset_dict1.keys()) != set(dataset_dict2.keys()): - return False - - return all( - are_datasets_identical(dataset_dict1[split_name], dataset_dict2[split_name]) - for split_name in dataset_dict1.keys() - ) diff --git a/tests/dataset_generator_test.py b/tests/dataset_generator_test.py new file mode 100644 index 000000000..b15990778 --- /dev/null +++ b/tests/dataset_generator_test.py @@ -0,0 +1,948 @@ +"""Testing DatasetGenerator through PromptBasedDatasetGenerator.""" + +import logging +import os +import tempfile +from functools import partial +from unittest.mock import patch + +import datasets +import pytest +from datasets import Dataset + +from prompt2model.dataset_generator.base import DatasetSplit +from prompt2model.dataset_generator.prompt_based import ( + Example, + PromptBasedDatasetGenerator, +) +from prompt2model.prompt_parser import MockPromptSpec, TaskType +from test_helpers import ( + MockCompletion, + UnknownGpt3Exception, + mock_batch_api_response_identical_completions, +) +from test_helpers.mock_api import MockBatchDifferentCompletions + +logger = logging.getLogger("DatasetGenerator") + +MOCK_CLASSIFICATION_EXAMPLE = partial( + mock_batch_api_response_identical_completions, + content='{"input": "This is a great movie!", "output": "1"}', +) +MOCK_WRONG_KEY_EXAMPLE = partial( + mock_batch_api_response_identical_completions, + content='{"input": "This is a great movie!", "label": "1"}', +) +MOCK_INVALID_JSON = partial( + mock_batch_api_response_identical_completions, + content='{"input": "This is a 
great movie!", "output": "1}', +) + +MOCK_CLASSIFICATION_EXAMPLE = partial( + mock_batch_api_response_identical_completions, + content='{"input": "This is a great movie!", "output": "1"}', +) +MOCK_WRONG_KEY_EXAMPLE = partial( + mock_batch_api_response_identical_completions, + content='{"input": "This is a great movie!", "label": "1"}', +) +MOCK_INVALID_JSON = partial( + mock_batch_api_response_identical_completions, + content='{"input": "This is a great movie!", "output": "1}', +) + +MOCK_CLASSIFICATION_EXAMPLE = partial( + mock_batch_api_response_identical_completions, + content='{"input": "This is a great movie!", "output": "1"}', +) +MOCK_WRONG_KEY_EXAMPLE = partial( + mock_batch_api_response_identical_completions, + content='{"input": "This is a great movie!", "label": "1"}', +) +MOCK_INVALID_JSON = partial( + mock_batch_api_response_identical_completions, + content='{"input": "This is a great movie!", "output": "1}', +) + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MOCK_CLASSIFICATION_EXAMPLE, +) +def test_generate_dataset(mocked_generate_example): + """Test the `generate_dataset_split()` function of `PromptBasedDatasetGenerator`.""" + os.environ["OPENAI_API_KEY"] = "fake_api_key" + dataset_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=False) + prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) + split = DatasetSplit.TRAIN + num_examples = 29 + # If num_examples >= max_api_calls, the returned dataset's + # length will be less than or equal to max_api_calls. + dataset = dataset_generator.generate_dataset_split(prompt_spec, num_examples, split) + # Since each API call would return one completion object with 5 responses + # and some of the responses are invalid JSON objects, the upper bound of + # the length of the dataset is num_examples + 5, where 5 is the + # default number of responses per API call. + assert len(dataset) < num_examples + 5 + expected_columns = {"input_col", "output_col"} + assert set(dataset.column_names) == expected_columns + return dataset + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MOCK_CLASSIFICATION_EXAMPLE, +) +def test_generate_dataset_dict(mocked_generate_example): + """Test the `generate_dataset_dict()` function of `PromptBasedDatasetGenerator`.""" + os.environ["OPENAI_API_KEY"] = "fake_api_key" + dataset_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=False) + prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) + num_examples = { + DatasetSplit.TRAIN: 50, + DatasetSplit.VAL: 24, + DatasetSplit.TEST: 26, + } + dataset_dict = dataset_generator.generate_dataset_dict( + prompt_spec=prompt_spec, + num_examples=num_examples, + ) + + assert set(dataset_dict.keys()) == {"train", "val", "test"} + for split, num in num_examples.items(): + # As explained previously, the upper bound of the length of + # generated dataset is num_examples + 5, where + # 5 is the default number of responses per API call. 
+ assert len(dataset_dict[split.value]) < num + 5 + expected_columns = {"input_col", "output_col"} + for dataset in dataset_dict.values(): + assert set(dataset.column_names) == expected_columns + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MOCK_CLASSIFICATION_EXAMPLE, +) +def test_generator_without_filter(mocked_generate_example): + """Unlimited dataset generation using the PromptBasedDatasetGenerator.""" + os.environ["OPENAI_API_KEY"] = "fake_api_key" + dataset_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=False) + dataset = dataset_generator.generate_dataset_split( + MockPromptSpec(TaskType.TEXT_GENERATION), 29, DatasetSplit.TRAIN + ) + assert len(dataset) == 29 + # The default responses_per_request is 5. So each API call will return + # 5 responses, i.e. 5 choices in openai.Completion.choices. + # Each API call will return 5 responses, and each response is a valid JSON. + # So the unlimited_dataset_generator will call the API 6 times. + assert dataset_generator.api_call_counter == 6 + # The default batch_size is 5. So generate_batch_completion + # will be called 2 times with first batch_size = 5 and second batch_size = 1. + assert mocked_generate_example.call_count == 2 + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MOCK_CLASSIFICATION_EXAMPLE, +) +def test_generator_without_filter_dict(mocked_generate_example): + """Test generation of a dataset dict.""" + dataset_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=False) + + prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) + num_examples = { + DatasetSplit.TRAIN: 50, + DatasetSplit.VAL: 24, + DatasetSplit.TEST: 26, + } + + dataset_dict = dataset_generator.generate_dataset_dict( + prompt_spec=prompt_spec, + num_examples=num_examples, + ) + + assert set(dataset_dict.keys()) == {"train", "val", "test"} + for split, num in num_examples.items(): + # As explained previously, the upper bound of the length of + # generated dataset is num_examples + 5, where + # 5 is the default number of responses per API call. + assert len(dataset_dict[split.value]) < num + 5 + expected_columns = {"input_col", "output_col"} + for dataset in dataset_dict.values(): + assert set(dataset.column_names) == expected_columns + + # Each API call returns five responses. So the dataset_generator will + # call the API (50 // 5 + 24 // 5 + 1 + 26 // 5 + 1) = 21 times. + assert dataset_generator.api_call_counter == (50 // 5 + 24 // 5 + 1 + 26 // 5 + 1) + # The default batch_size is 5. So generate_batch_completion + # will be called 2 times for 50 examples in the train split, + # 1 time for 24 examples in the validation split, + # and 2 times for 26 examples in the test split. + assert mocked_generate_example.call_count == 2 + 1 + 2 + + # Each API call returns 5 responses, and each response is a valid JSON. + # So the dataset_dict will contain (50, 25, 30) examples. 
+ assert len(dataset_dict["train"]) == 50 + assert len(dataset_dict["val"]) == 24 + assert len(dataset_dict["test"]) == 26 + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MOCK_CLASSIFICATION_EXAMPLE, +) +def test_generator_max_api_calls(mocked_generate_example): + """Test generation when num_examples >= max_api_calls.""" + dataset_generator = PromptBasedDatasetGenerator( + max_api_calls=3, filter_duplicated_examples=False + ) + dataset = dataset_generator.generate_dataset_split( + MockPromptSpec(TaskType.TEXT_GENERATION), 29, DatasetSplit.TRAIN + ) + # The max_api_calls is 3. So the limited_dataset_generator calls the + # API 3 times. Each API call returns 5 responses. So the + # limited_dataset_generator will have 3 * 5 = 15 examples. + assert len(dataset) == 15 + + # The default batch_size is 5. So generate_batch_completion + # will be called only once. + assert mocked_generate_example.call_count == 1 + + # Each API call returns 5 responses, so the limited_dataset_generator + # will use up all the available API calls. + assert dataset_generator.api_call_counter == 3 + + # Each API call returns 5 responses, and each response is a valid JSON. + # So the dataset will contain 15 examples. + assert len(dataset) == 15 + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MockBatchDifferentCompletions().mock_completions, +) +def test_generator_with_filter_first_batch(mocked_generate_example): + """Test PromptBasedDatasetGenerator with filter methods in the first batch.""" + dataset_generator = PromptBasedDatasetGenerator( + max_api_calls=2, + filter_duplicated_examples=True, + max_batch_size=2, + responses_per_request=3, + ) + + # Generate the dataset split using the initialized generator. + generated_dataset = dataset_generator.generate_dataset_split( + prompt_spec=MockPromptSpec(TaskType.TEXT_GENERATION), + num_examples=5, + split=DatasetSplit.TRAIN, + ) + + # Assertions for API call count and dataset matching the expected result. + assert mocked_generate_example.call_count == 1 + assert dataset_generator.api_call_counter == 2 + + # Define the expected dataset based on the given mock responses. + expected_dataset = Dataset.from_dict( + { + "input_col": ["1", "2"], + "output_col": ["a", "a"], + } + ) + + # Verify the generated dataset matches the expected dataset. + assert list(generated_dataset) == list(expected_dataset) + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MockBatchDifferentCompletions().mock_completions, +) +def test_generator_with_filter_second_batch(mocked_generate_example): + """Test PromptBasedDatasetGenerator with filter methods in the second batch. + + This test verifies the behavior of the PromptBasedDatasetGenerator with filter + methods in the second batch of API calls. It initializes an + PromptBasedDatasetGenerator with specific settings, limiting the number of + API calls to 3. After running the generation process, the test checks + whether the generated dataset matches the expected result after the + second API call. The test also ensures that the number of calls to the + API mock matches the expected number. + + Note: The first API call's max_batch_size is 2, generating 6 responses. + The second API call's max_batch_size is 1, generating 3 responses. + + Args: + mocked_generate_example (MagicMock): The patched function representing the + @patch decorator for generating example responses. 
+ """ + # Initialize the PromptBasedDatasetGenerator with specific settings. + dataset_generator = PromptBasedDatasetGenerator( + max_api_calls=3, + filter_duplicated_examples=True, + max_batch_size=2, + responses_per_request=3, + ) + + # Generate the dataset split using the initialized generator. + generated_dataset = dataset_generator.generate_dataset_split( + prompt_spec=MockPromptSpec(TaskType.TEXT_GENERATION), + num_examples=5, + split=DatasetSplit.TRAIN, + ) + + # Assertions for API call count and dataset matching the expected result. + assert mocked_generate_example.call_count == 2 + assert dataset_generator.api_call_counter == 3 + + # Define the expected dataset based on the given mock responses. + expected_dataset = Dataset.from_dict( + { + "input_col": ["1", "2", "3"], + "output_col": ["a", "a", "a"], + } + ) + + # Verify the generated dataset matches the expected dataset. + assert list(generated_dataset) == list(expected_dataset) + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MockBatchDifferentCompletions().mock_completions, +) +def test_generator_with_filter_third_batch(mocked_generate_example): + """Test PromptBasedDatasetGenerator with filter methods in the third batch. + + This test verifies the behavior of the PromptBasedDatasetGenerator with + filter methods in the third batch of API calls. It initializes an + PromptBasedDatasetGenerator with specific settings, limiting the number + of API calls to 4. After running the generation process, the test + checks whether the generated dataset matches the expected + result after the third API call. The test also ensures that the + number of calls to the API mock matches the expected number. + + Note: The first API call's max_batch_size is 2, generating 6 responses. + The second API call's max_batch_size is 1, generating 3 responses. + The third API call's max_batch_size is 1, generating 3 responses. + + Args: + mocked_generate_example (MagicMock): The patched function representing the + @patch decorator for generating example responses. + """ + # Initialize the PromptBasedDatasetGenerator with specific settings. + dataset_generator = PromptBasedDatasetGenerator( + max_api_calls=4, + filter_duplicated_examples=True, + max_batch_size=2, + responses_per_request=3, + ) + + # Generate the dataset split using the initialized generator. + generated_dataset = dataset_generator.generate_dataset_split( + prompt_spec=MockPromptSpec(TaskType.TEXT_GENERATION), + num_examples=5, + split=DatasetSplit.TRAIN, + ) + + # Assertions for API call count and dataset matching the expected result. + assert mocked_generate_example.call_count == 3 + assert dataset_generator.api_call_counter == 4 + + # Define the expected dataset based on the given mock responses. + expected_dataset = Dataset.from_dict( + { + "input_col": ["1", "2", "3"], + "output_col": ["b", "a", "a"], + } + ) + + # Verify the generated dataset matches the expected dataset. + assert list(generated_dataset) == list(expected_dataset) + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MockBatchDifferentCompletions().mock_completions, +) +def test_generator_with_filter_forth_batch(mocked_generate_example): + """Test PromptBasedDatasetGenerator with filter methods in the forth batch.""" + # Initialize the PromptBasedDatasetGenerator with specific settings. 
+ dataset_generator = PromptBasedDatasetGenerator( + max_api_calls=5, + filter_duplicated_examples=True, + max_batch_size=2, + responses_per_request=3, + ) + + # Generate the dataset split using the initialized generator. + generated_dataset = dataset_generator.generate_dataset_split( + prompt_spec=MockPromptSpec(TaskType.TEXT_GENERATION), + num_examples=5, + split=DatasetSplit.TRAIN, + ) + + # Assertions for API call count and dataset matching the expected result. + assert mocked_generate_example.call_count == 4 + assert dataset_generator.api_call_counter == 5 + + # Define the expected dataset based on the given mock responses. + expected_dataset = Dataset.from_dict( + { + "input_col": ["1", "2", "3", "4", "5"], + "output_col": ["b", "a", "a", "c", "a"], + } + ) + + # Verify the generated dataset matches the expected dataset. + assert list(generated_dataset) == list(expected_dataset) + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MockBatchDifferentCompletions().mock_completions, +) +def test_generator_with_filter_unlimited_api_calls(mocked_generate_example): + """Test PromptBasedDatasetGenerator with filter methods and unlimited API calls.""" + # Initialize the PromptBasedDatasetGenerator with + # specific settings and unlimited API calls. + dataset_generator = PromptBasedDatasetGenerator( + filter_duplicated_examples=True, + max_batch_size=2, + responses_per_request=3, + ) + + # Generate the dataset split using the initialized generator. + generated_dataset = dataset_generator.generate_dataset_split( + prompt_spec=MockPromptSpec(TaskType.TEXT_GENERATION), + num_examples=5, + split=DatasetSplit.TRAIN, + ) + + # Assertions for API call count and dataset matching the expected result. + assert mocked_generate_example.call_count == 4 + assert dataset_generator.api_call_counter == 5 + + # Define the expected dataset based on the given mock responses. + expected_dataset = Dataset.from_dict( + { + "input_col": ["1", "2", "3", "4", "5"], + "output_col": ["b", "a", "a", "c", "a"], + } + ) + + # Verify the generated dataset matches the expected dataset. + assert list(generated_dataset) == list(expected_dataset) + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MockBatchDifferentCompletions(length=5).mock_completions, +) +def test_generator_with_filter_to_generate_datasetdict(mocked_generate_example): + """Test with filter methods to generate a DatasetDict.""" + # Initialize the PromptBasedDatasetGenerator with + # specific settings and limited API calls. + dataset_generator = PromptBasedDatasetGenerator( + filter_duplicated_examples=True, + max_batch_size=2, + responses_per_request=3, + max_api_calls=7, + ) + + # Generate the DatasetDict using the initialized generator. + generated_dataset_dict = dataset_generator.generate_dataset_dict( + prompt_spec=MockPromptSpec(TaskType.TEXT_GENERATION), + num_examples={ + DatasetSplit.TRAIN: 4, + DatasetSplit.VAL: 4, + DatasetSplit.TEST: 2, + }, + ) + + # Assertions for API call count and dataset + # dictionaries matching the expected results. + assert mocked_generate_example.call_count == 5 + assert dataset_generator.api_call_counter == 7 + + # Define the expected dataset dictionaries + # based on the given mock responses. 
+ expected_dataset_dict = datasets.DatasetDict( + { + "train": Dataset.from_dict( + { + "input_col": ["1", "2", "3", "4"], + "output_col": ["b", "a", "a", "c"], + } + ), + "val": Dataset.from_dict( + { + "input_col": ["1", "2"], + "output_col": ["a", "a"], + } + ), + "test": Dataset.from_dict( + { + "input_col": [], + "output_col": [], + } + ), + } + ) + + # Verify the generated DatasetDict matches the expected DatasetDict. + assert list(generated_dataset_dict["train"]) == list(expected_dataset_dict["train"]) + assert list(generated_dataset_dict["val"]) == list(expected_dataset_dict["val"]) + assert list(generated_dataset_dict["test"]) == list(expected_dataset_dict["test"]) + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MOCK_CLASSIFICATION_EXAMPLE, +) +def test_generator_max_api_calls_dict(mocked_generate_example): + """Test generation of a dataset dict where we hit max api calls.""" + # Refresh the call_count and create a new limited_dataset_generator. + dataset_generator = PromptBasedDatasetGenerator( + filter_duplicated_examples=False, + max_api_calls=13, + ) + + prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) + num_examples = { + DatasetSplit.TRAIN: 50, + DatasetSplit.VAL: 24, + DatasetSplit.TEST: 26, + } + + dataset_dict = dataset_generator.generate_dataset_dict( + prompt_spec=prompt_spec, + num_examples=num_examples, + ) + + # Since the max_api_calls is 13, the limited_dataset_generator cannot + # generate the whole dataset_dict and will call the API 13 times. + assert dataset_generator.api_call_counter == 13 + + # The train split has 50 examples, so it will call the API 10 times and call + # generate_batch_completion 2 times. + # The validation split has 24 examples, but there are only 3 API calls + # left, so it will call the API 3 times and call + # generate_batch_completion 1 time. + # The test split has 26 examples, but there are no more API calls left, + # so it will not call generate_batch_completion. + assert mocked_generate_example.call_count == 2 + 1 + 0 + + # Each API call returns 5 responses, and each response is a valid JSON. + # So the generated_dataset_dict will contain (50, 15, 0) examples. + assert len(dataset_dict["train"]) == 50 + assert len(dataset_dict["val"]) == 15 + assert len(dataset_dict["test"]) == 0 + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MOCK_WRONG_KEY_EXAMPLE, +) +def test_wrong_key_example(mocked_generate_example): + """Test PromptBasedDatasetGenerator when the agent returns wrong keys.""" + dataset_generator = PromptBasedDatasetGenerator( + max_api_calls=3, filter_duplicated_examples=False + ) + prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) + num_examples = 1 + split = DatasetSplit.TRAIN + generated_dataset = dataset_generator.generate_dataset_split( + prompt_spec, num_examples, split + ) + assert mocked_generate_example.call_count == 3 + expected_dataset = Dataset.from_dict({"input_col": [], "output_col": []}) + assert list(expected_dataset) == list(generated_dataset) + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MOCK_INVALID_JSON, +) +def test_invalid_json_response(mocked_generate_example): + """Test when the agent returns invalid JSON responses.""" + # Init the PromptBasedDatasetGenerator with `max_api_calls = 3`. 
+ dataset_generator = PromptBasedDatasetGenerator(3, filter_duplicated_examples=False) + prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) + num_examples = 1 + split = DatasetSplit.VAL + dataset = dataset_generator.generate_dataset_split(prompt_spec, num_examples, split) + assert mocked_generate_example.call_count == 3 + expected_dataset = Dataset.from_dict({"input_col": [], "output_col": []}) + assert list(dataset) == list(expected_dataset) + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=UnknownGpt3Exception(), +) +def test_unexpected_examples_of_gpt(mocked_generate_example): + """Test PromptBasedDatasetGenerator when the agent returns unexpected examples.""" + os.environ["OPENAI_API_KEY"] = "fake_api_key" + # Init the PromptBasedDatasetGenerator with `max_api_calls = 3`. + with pytest.raises(UnknownGpt3Exception): + dataset_generator = PromptBasedDatasetGenerator( + max_api_calls=3, filter_duplicated_examples=False + ) + prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) + num_examples = 1 + split = DatasetSplit.TEST + _ = dataset_generator.generate_dataset_split(prompt_spec, num_examples, split) + assert mocked_generate_example.call_count == 1 + + +def test_filter_with_duplicate_inputs_unique_outputs(): + """Test filtering with duplicate inputs, unique outputs.""" + os.environ["OPENAI_API_KEY"] = "fake_api_key" + data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=False) + generated_examples = [ + Example(input_col="apple", output_col="A"), + Example(input_col="banana", output_col="B"), + Example(input_col="apple", output_col="E"), + Example(input_col="orange", output_col="O"), + Example(input_col="apple", output_col="D"), + ] + filtered_examples = data_generator.apply_multi_vote_filtering(generated_examples) + expected_examples = [ + Example(input_col="apple", output_col="A"), + Example(input_col="banana", output_col="B"), + Example(input_col="orange", output_col="O"), + ] + assert sorted(expected_examples) == sorted(filtered_examples) + + +def test_filter_duplicate_inputs_duplicate_outputs(): + """Test constructing a map with duplicate inputs and duplicate outputs.""" + os.environ["OPENAI_API_KEY"] = "fake_api_key" + data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=False) + generated_examples = [ + Example(input_col="apple", output_col="A"), + Example(input_col="banana", output_col="C"), + Example(input_col="apple", output_col="A"), + Example(input_col="banana", output_col="B"), + Example(input_col="apple", output_col="G"), + Example(input_col="apple", output_col="A"), + Example(input_col="orange", output_col="O"), + Example(input_col="apple", output_col="D"), + Example(input_col="banana", output_col="B"), + Example(input_col="orange", output_col="F"), + ] + filtered_examples = data_generator.apply_multi_vote_filtering(generated_examples) + expected_examples = [ + Example(input_col="apple", output_col="A"), + Example(input_col="banana", output_col="B"), + Example(input_col="orange", output_col="O"), + ] + assert expected_examples == filtered_examples + + +def test_create_all_examples_dataset_and_generated_dataset_with_unique_inputs_outputs(): + """Test constructing a map with unique inputs and outputs.""" + os.environ["OPENAI_API_KEY"] = "fake_api_key" + data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=False) + generated_examples = [ + Example(input_col="apple", output_col="A"), + Example(input_col="banana", output_col="B"), + Example(input_col="orange", output_col="O"), + 
] + filtered_examples = data_generator.apply_multi_vote_filtering(generated_examples) + assert generated_examples == filtered_examples + + +def test_create_all_examples_dataset_and_generated_dataset_with_empty_examples_list(): + """Test constructing a map with empty inputs and outputs.""" + os.environ["OPENAI_API_KEY"] = "fake_api_key" + data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=False) + generated_examples = [] + filtered_examples = data_generator.apply_multi_vote_filtering(generated_examples) + assert generated_examples == filtered_examples + + +def test_compute_batch_size_with_limited_max_api_calls(): + """Test the batch size computation with limited max API calls.""" + os.environ["OPENAI_API_KEY"] = "fake_api_key" + data_generator = PromptBasedDatasetGenerator(max_api_calls=28) + data_generator.api_call_counter = 26 + # Default batch size and responses_per_request are both 5. + # So each batch should contain 25 examples. + + # At least (125 - 110) / 5 = 3 API calls needed to get + # more than 125 examples. + + batch_size = data_generator.compute_batch_size( + num_examples=125, generated_dataset_size=110 + ) + assert ( + batch_size + == data_generator.max_api_calls - data_generator.api_call_counter + == 28 - 26 + ) + + data_generator.api_call_counter = 20 + batch_size = data_generator.compute_batch_size(125, generated_dataset_size=110) + assert ( + batch_size + == (125 - 110) / data_generator.responses_per_request + == (125 - 110) / 5 + ) + + data_generator.api_call_counter = 0 + batch_size = data_generator.compute_batch_size(125, generated_dataset_size=50) + assert batch_size == data_generator.max_batch_size + + +def test_compute_batch_size_with_unlimited_max_api_calls(): + """Test the batch size computation with unlimited max API calls.""" + os.environ["OPENAI_API_KEY"] = "fake_api_key" + data_generator = PromptBasedDatasetGenerator() + # Default batch size and responses_per_request are both 5. + # So each batch should contain 25 examples. + + # At least (125 - 110) / 5 = 3 API calls needed to get + # more than 125 examples. + + batch_size = data_generator.compute_batch_size(125, generated_dataset_size=110) + assert ( + batch_size + == (125 - 110) / data_generator.responses_per_request + == (125 - 110) / 5 + ) + + batch_size = data_generator.compute_batch_size(125, generated_dataset_size=50) + assert batch_size == data_generator.max_batch_size == 5 + + +def test_extract_responses(): + """Test the extract_responses function of DatasetGenerator.""" + mock_completion_1 = MockCompletion() + mock_completion_1.choices = [ + {"message": {"content": '{"input": "1", "output": "a"}'}}, + {"message": {"content": '{"input": "1", "output": "b"}'}}, + {"message": {"content": '{"input": "1", "output": "a"}'}}, + ] + mock_completion_2 = MockCompletion() + mock_completion_2.choices = [ + {"message": {"content": '{"input": "3", "output": "a"}'}}, + # Note that the following choice miss the right quote of JSON. + # So it should be discarded. And will log a warning. + {"message": {"content": '{"input": "3", "output": "a}'}}, + {"message": {"content": '{"input": "3", "output": "b"}'}}, + ] + mock_completion_3 = MockCompletion() + mock_completion_3.choices = [ + {"message": {"content": '{"input": "4", "output": "c"}'}}, + {"message": {"content": '{"input": "4", "output": "c"}'}}, + {"message": {"content": '{"input": "5", "output": "a"}'}}, + ] + # choices should be list of dicts. So mock_completion_4 + # is invalid. Which will be discarded and log a warning. 
+ mock_completion_4 = MockCompletion() + mock_completion_4.choices = None + + os.environ["OPENAI_API_KEY"] = "fake_api_key" + data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=True) + generated_examples = [] + with patch.object(logger, "info") as mock_info, patch.object( + logger, "warning" + ) as mock_warning: + data_generator.extract_and_append_responses( + [mock_completion_1, mock_completion_2], generated_examples + ) + mock_warning.assert_called_once_with( + 'Error happened parsing API choice: {\'message\': {\'content\': \'{"input": "3", "output": "a}\'}}' # noqa E501 + ) + # There are 5 valid examples. Each input + # and output will be logged once as info. + assert mock_info.call_count == 5 * 2 + + # The second choice in mock_completion_2 + # is invalid. So it should be discarded. + assert generated_examples == [ + Example(input_col="1", output_col="a"), + Example(input_col="1", output_col="b"), + Example(input_col="1", output_col="a"), + Example(input_col="3", output_col="a"), + Example(input_col="3", output_col="b"), + ] + data_generator.extract_and_append_responses([mock_completion_3], generated_examples) + assert generated_examples == [ + Example(input_col="1", output_col="a"), + Example(input_col="1", output_col="b"), + Example(input_col="1", output_col="a"), + Example(input_col="3", output_col="a"), + Example(input_col="3", output_col="b"), + Example(input_col="4", output_col="c"), + Example(input_col="4", output_col="c"), + Example(input_col="5", output_col="a"), + ] + with patch.object(logger, "info") as mock_info, patch.object( + logger, "warning" + ) as mock_warning: + data_generator.extract_and_append_responses( + [mock_completion_4], generated_examples + ) + mock_warning.assert_called_once_with( + "Error happened when parsing API completion: " + ) + mock_info.assert_not_called() + # The generated_examples should be the same. + assert generated_examples == [ + Example(input_col="1", output_col="a"), + Example(input_col="1", output_col="b"), + Example(input_col="1", output_col="a"), + Example(input_col="3", output_col="a"), + Example(input_col="3", output_col="b"), + Example(input_col="4", output_col="c"), + Example(input_col="4", output_col="c"), + Example(input_col="5", output_col="a"), + ] + + +def test_extract_some_empty_responses(): + """Test the extract_responses function correctly handle empty responses.""" + mock_completion_1 = MockCompletion() + mock_completion_1.choices = [ + # Note that this choice's input is empty. So it should be discarded. + {"message": {"content": '{"input": "", "output": "a"}'}}, + {"message": {"content": '{"input": "5", "output": "b"}'}}, + # Note that this choice's output is empty. So it should be discarded. + {"message": {"content": '{"input": "1", "output": ""}'}}, + ] + mock_completion_2 = MockCompletion() + mock_completion_2.choices = [ + {"message": {"content": '{"input": "3", "output": "a"}'}}, + # Note that the following choice misses the right quote of JSON. + # So it should be discarded. And will log a warning. + {"message": {"content": '{"input": "3", "output": "a}'}}, + {"message": {"content": '{"input": "3", "output": "b"}'}}, + ] + mock_completion_3 = MockCompletion() + mock_completion_3.choices = [ + {"message": {"content": '{"input": "4", "output": "c"}'}}, + {"message": {"content": '{"input": "4", "output": "c"}'}}, + {"message": {"content": '{"input": "5", "output": "a"}'}}, + ] + # choices should be list of dicts. So mock_completion_4 + # is invalid. Which will be discarded and log a warning. 
+    mock_completion_4 = MockCompletion()
+    mock_completion_4.choices = None
+
+    with tempfile.TemporaryDirectory() as cache_dir:
+        os.environ["OPENAI_API_KEY"] = "fake_api_key"
+        data_generator = PromptBasedDatasetGenerator(
+            cache_root=cache_dir, filter_duplicated_examples=True
+        )
+        generated_examples = []
+        with patch.object(logger, "info") as mock_info, patch.object(
+            logger, "warning"
+        ) as mock_warning:
+            data_generator.extract_and_append_responses(
+                [mock_completion_1, mock_completion_2], generated_examples
+            )
+            mock_warning.assert_called_once_with(
+                'Error happened parsing API choice: {\'message\': {\'content\': \'{"input": "3", "output": "a}\'}}'  # noqa E501
+            )
+            # There are 3 valid examples in [mock_completion_1,
+            # mock_completion_2]. Each input and output will be
+            # logged once as info. There are also 2 examples with
+            # an empty input or output, which should be discarded
+            # and logged as info.
+            assert mock_info.call_count == 3 * 2 + 2
+
+        # The second choice in mock_completion_2
+        # is invalid, so it should be discarded.
+        assert generated_examples == [
+            Example(input_col="5", output_col="b"),
+            Example(input_col="3", output_col="a"),
+            Example(input_col="3", output_col="b"),
+        ]
+        data_generator.extract_and_append_responses(
+            [mock_completion_3], generated_examples
+        )
+        assert generated_examples == [
+            Example(input_col="5", output_col="b"),
+            Example(input_col="3", output_col="a"),
+            Example(input_col="3", output_col="b"),
+            Example(input_col="4", output_col="c"),
+            Example(input_col="4", output_col="c"),
+            Example(input_col="5", output_col="a"),
+        ]
+        with patch.object(logger, "info") as mock_info, patch.object(
+            logger, "warning"
+        ) as mock_warning:
+            data_generator.extract_and_append_responses(
+                [mock_completion_4], generated_examples
+            )
+            mock_warning.assert_called_once_with(
+                "Error happened when parsing API completion: "
+            )
+            mock_info.assert_not_called()
+        # The generated_examples should remain unchanged.
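+        # Note that exact duplicates (e.g. the two ("4", "c") examples) are
+        # still kept at this point; with filter_duplicated_examples=True,
+        # deduplication presumably happens later via apply_multi_vote_filtering
+        # rather than inside extract_and_append_responses.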
+ assert generated_examples == [ + Example(input_col="5", output_col="b"), + Example(input_col="3", output_col="a"), + Example(input_col="3", output_col="b"), + Example(input_col="4", output_col="c"), + Example(input_col="4", output_col="c"), + Example(input_col="5", output_col="a"), + ] + + +def test_initialize_dataset_generator_with_dynamic_temperature(): + """Test the correct initialization of the dynamic temperature strategy.""" + with tempfile.TemporaryDirectory() as cache_dir: + os.environ["OPENAI_API_KEY"] = "fake_api_key" + with pytest.raises(ValueError) as exc_info: + _ = PromptBasedDatasetGenerator( + cache_root=cache_dir, initial_temperature=-0.2 + ) + error_info = exc_info.value.args[0] + assert ( + error_info + == "initial_temperature must be >= 0, but self.initial_temperature=-0.2" + ) + with pytest.raises(ValueError) as exc_info: + _ = PromptBasedDatasetGenerator(cache_root=cache_dir, max_temperature=2.3) + error_info = exc_info.value.args[0] + assert ( + error_info + == "max_temperature must be <= 2,0, but self.max_temperature=2.3" + ) + + with pytest.raises(ValueError) as exc_info: + _ = PromptBasedDatasetGenerator( + cache_root=cache_dir, max_temperature=1.2, initial_temperature=1.5 + ) + error_info = exc_info.value.args[0] + assert ( + error_info + == "self.initial_temperature=1.5 must be <= self.max_temperature=1.2" + ) + + +@patch( + "prompt2model.utils.APIAgent.generate_batch_completion", + side_effect=MOCK_CLASSIFICATION_EXAMPLE, +) +def test_dataset_generator_terminates(mocked_generate_example): + """Check to make sure that the dataset generator terminates.""" + prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) + dataset_generator = PromptBasedDatasetGenerator( + initial_temperature=0.3, + max_temperature=1.4, + responses_per_request=3, + max_api_calls=10000, + requests_per_minute=80, + filter_duplicated_examples=False, + ) + generated_dataset = dataset_generator.generate_dataset_split( + prompt_spec, 100, split=DatasetSplit.TRAIN + ) + generated_df = generated_dataset.to_pandas() + assert len(generated_dataset) == 100 + assert list(generated_df.columns) == ["input_col", "output_col"] diff --git a/tests/dataset_generator_with_filter_test.py b/tests/dataset_generator_with_filter_test.py deleted file mode 100644 index 0a6a55076..000000000 --- a/tests/dataset_generator_with_filter_test.py +++ /dev/null @@ -1,1376 +0,0 @@ -"""Testing DatasetGenerator through PromptBasedDatasetGenerator.""" - -import gc -import logging -import os -import tempfile -from collections import Counter -from functools import partial -from pathlib import Path -from unittest.mock import patch - -import datasets -import pytest -from datasets import Dataset - -from prompt2model.dataset_generator.base import DatasetSplit -from prompt2model.dataset_generator.prompt_based import ( - Example, - PromptBasedDatasetGenerator, -) -from prompt2model.prompt_parser import MockPromptSpec, TaskType -from test_helpers import ( - MockBatchDifferentCompletions, - UnknownGpt3Exception, - are_dataset_dicts_identical, - are_datasets_identical, - mock_batch_api_response_identical_completions, -) - -logger = logging.getLogger("DatasetGenerator") - -# Create partial functions to simulate different API responses. -# MOCK_EXAMPLE: Represents a mock example with identical completions. -# The content contains an input ("6") and the corresponding output ("f"). 
-MOCK_EXAMPLE = partial( - mock_batch_api_response_identical_completions, - content='{"input": "6", "output": "f"}', -) - -# MOCK_WRONG_KEY_EXAMPLE: Represents a mock example with identical completions, -# but the content contains an incorrect key "label" instead of "output". -MOCK_WRONG_KEY_EXAMPLE = partial( - mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "label": "1"}', -) - -# MOCK_INVALID_JSON: Represents a mock example with an invalid JSON content. -# The content is missing a closing double-quote for the "output" value. -MOCK_INVALID_JSON = partial( - mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "output": "1}', -) - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MOCK_WRONG_KEY_EXAMPLE, -) -def test_wrong_key_example(mocked_generate_example): - """Test PromptBasedDatasetGenerator when the agent returns a wrong key dictionary. - - This test case is designed to verify the behavior of PromptBasedDatasetGenerator - when the APIAgent returns a dictionary with a wrong key, i.e., "label" instead - of "output". - - Args: - mocked_generate_example: The function represents the @patch function and - provides the mocked behavior for API calls. - - Note: The test function assumes the existence of 'MOCK_WRONG_KEY_EXAMPLE', - which represents a mock example with identical completions but an incorrect key - in the content. - - """ - # Initialize the PromptBasedDatasetGenerator with `max_api_calls = 3`. - with tempfile.TemporaryDirectory() as cache_dir: - dataset_generator = PromptBasedDatasetGenerator( - 3, filter_duplicated_examples=True, cache_root=cache_dir - ) - - # Create a mock prompt specification. - prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) - - # Set the expected number of examples and dataset split for testing. - expected_num_examples = 1 - split = DatasetSplit.TRAIN - - # Generate the dataset split using PromptBasedDatasetGenerator. - dataset = dataset_generator.generate_dataset_split( - prompt_spec, expected_num_examples, split - ) - - # Assertions to verify the test results. - assert mocked_generate_example.call_count == 3 - assert ( - dataset["input_col"] == dataset["output_col"] and dataset["input_col"] == [] - ) - - # Collect garbage to release memory resources after the test. - gc.collect() - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MOCK_INVALID_JSON, -) -def test_invalid_json_response(mocked_generate_example): - """Test PromptBasedDatasetGenerator when the agent returns an invalid JSON response. - - This test case is designed to verify the behavior of PromptBasedDatasetGenerator - when the APIAgent returns a response with invalid JSON content. The @patch - decorator replaces the 'generate_batch_completion' function with - the 'MOCK_INVALID_JSON' side effect. - - Args: - mocked_generate_example: The function represents the @patch function and - provides the mocked behavior for API calls. - - Note: The test function assumes the existence of 'MOCK_INVALID_JSON', - which represents a mock example with an invalid JSON content. - - """ - # Initialize the PromptBasedDatasetGenerator with `max_api_calls = 3`. - with tempfile.TemporaryDirectory() as cache_dir: - dataset_generator = PromptBasedDatasetGenerator( - 3, filter_duplicated_examples=True, cache_root=cache_dir - ) - - # Create a mock prompt specification. 
- prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) - - # Set the expected number of examples and dataset split for testing. - expected_num_examples = 1 - split = DatasetSplit.VAL - - # Generate the dataset split using PromptBasedDatasetGenerator. - dataset = dataset_generator.generate_dataset_split( - prompt_spec, expected_num_examples, split - ) - - # Assertions to verify the test results. - assert mocked_generate_example.call_count == 3 - assert ( - dataset["input_col"] == dataset["output_col"] and dataset["input_col"] == [] - ) - - # Collect garbage to release memory resources after the test. - gc.collect() - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=UnknownGpt3Exception(), -) -def test_unexpected_examples_of_gpt(mocked_generate_example): - """Test PromptBasedDatasetGenerator when the agent returns a GPT-3 exception. - - This test case is designed to verify the behavior of PromptBasedDatasetGenerator - when the APIAgent returns an unknown GPT-3 exception. The @patch decorator - replaces the 'generate_batch_completion' function with the - 'UnknownGpt3Exception' side effect, simulating an unexpected exception. - - Args: - mocked_generate_example: The function represents the @patch function and - provides the mocked behavior for API calls. - - Note: The test function assumes the existence of 'UnknownGpt3Exception', - which represents an unknown GPT-3 exception raised during API calls. - - """ - api_key = "fake_api_key" - - # Set the fake API key in the environment variable for testing purposes. - os.environ["OPENAI_API_KEY"] = api_key - - # Initialize the PromptBasedDatasetGenerator with `max_api_calls = 3`. - # Use pytest.raises() to assert that an UnknownGpt3Exception is raised. - with pytest.raises( - UnknownGpt3Exception - ), tempfile.TemporaryDirectory() as cache_dir: - dataset_generator = PromptBasedDatasetGenerator( - max_api_calls=3, filter_duplicated_examples=True, cache_root=cache_dir - ) - - # Create a mock prompt specification. - prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) - - # Set the expected number of examples and dataset split for testing. - expected_num_examples = 1 - split = DatasetSplit.TEST - - # Generate the dataset split using PromptBasedDatasetGenerator and expect the - # unknown GPT-3 exception to be raised. - _ = dataset_generator.generate_dataset_split( - prompt_spec, expected_num_examples, split - ) - - # Assertions to verify the test results. - assert mocked_generate_example.call_count == 1 - - # Collect garbage to release memory resources after the test. - gc.collect() - - -def test_construct_map_with_duplicate_inputs_unique_outputs(): - """Test constructing a map with duplicate inputs but unique outputs. - - This test case verifies the behavior of the construct_input_output_map() - method in PromptBasedDatasetGenerator when there are duplicate inputs but - unique outputs in the generated examples. - - Attributes: - api_key (str): The fake API key used for testing. - expected_output (dict): The expected input-output map to be constructed. - """ - # Set a fake API key in the environment variable for testing purposes. - os.environ["OPENAI_API_KEY"] = "fake_api_key" - - # Initialize the PromptBasedDatasetGenerator with filter_duplicated_examples=True. - with tempfile.TemporaryDirectory() as cache_dir: - data_generator = PromptBasedDatasetGenerator( - filter_duplicated_examples=True, cache_root=cache_dir - ) - - # Create a list of generated examples with duplicate inputs and unique outputs. 
- generated_examples = [ - Example(input_col="apple", output_col="A"), - Example(input_col="banana", output_col="B"), - Example(input_col="apple", output_col="E"), - Example(input_col="orange", output_col="O"), - Example(input_col="apple", output_col="D"), - ] - - # Call the construct_input_output_map() - # method to create the input-output map. - input_output_map = data_generator.construct_input_output_map(generated_examples) - - # The expected input-output map afte - # r constructing it from the generated examples. - expected_output = { - "apple": Counter({"A": 1, "E": 1, "D": 1}), - "banana": Counter({"B": 1}), - "orange": Counter({"O": 1}), - } - - # Assertions to verify that the input-output - # map matches the expected output. - assert input_output_map == expected_output - - # Collect garbage to release memory - # resources after the test. - gc.collect() - - -def test_construct_map_with_duplicate_inputs_duplicate_outputs(): - """Test constructing a map with duplicate inputs and duplicate outputs. - - This test case verifies the behavior of the construct_input_output_map() - method in PromptBasedDatasetGenerator when there are duplicate inputs and - duplicate outputs in the generated examples. - - Attributes: - api_key (str): The fake API key used for testing. - expected_output (dict): The expected input-output map to be constructed. - """ - # Set a fake API key in the environment variable for testing purposes. - os.environ["OPENAI_API_KEY"] = "fake_api_key" - - # Initialize the PromptBasedDatasetGenerator with filter_duplicated_examples=True. - with tempfile.TemporaryDirectory() as cache_dir: - data_generator = PromptBasedDatasetGenerator( - filter_duplicated_examples=True, cache_root=cache_dir - ) - - # Create a list of generated examples with - # duplicate inputs and duplicate outputs. - generated_examples = [ - Example(input_col="apple", output_col="A"), - Example(input_col="banana", output_col="C"), - Example(input_col="apple", output_col="A"), - Example(input_col="banana", output_col="B"), - Example(input_col="apple", output_col="G"), - Example(input_col="apple", output_col="A"), - Example(input_col="orange", output_col="O"), - Example(input_col="apple", output_col="D"), - Example(input_col="banana", output_col="B"), - Example(input_col="orange", output_col="F"), - ] - - # Call the construct_input_output_map() - # method to create the input-output map. - input_output_map = data_generator.construct_input_output_map(generated_examples) - - # The expected input-output map after - # constructing it from the generated examples. - expected_output = { - "apple": Counter({"A": 3, "D": 1, "G": 1}), - "banana": Counter({"B": 2, "C": 1}), - "orange": Counter({"O": 1, "F": 1}), - } - - # Assertions to verify that the input-output - # map matches the expected output. - assert input_output_map == expected_output - - # Collect garbage to release memory - # resources after the test. - gc.collect() - - -def test_construct_map_with_unique_inputs_outputs(): - """Test constructing a map with unique inputs and outputs. - - This test case verifies the behavior of the construct_input_output_map() - method in PromptBasedDatasetGenerator when all generated examples have unique - inputs and outputs. - - Attributes: - api_key (str): The fake API key used for testing. - expected_output (dict): The expected input-output map to be constructed. - """ - # Set a fake API key in the environment variable for testing purposes. 
- os.environ["OPENAI_API_KEY"] = "fake_api_key" - - # Initialize the PromptBasedDatasetGenerator with filter_duplicated_examples=True. - with tempfile.TemporaryDirectory() as cache_dir: - data_generator = PromptBasedDatasetGenerator( - filter_duplicated_examples=True, cache_root=cache_dir - ) - - # Create a list of generated examples with unique inputs and outputs. - generated_examples = [ - Example(input_col="apple", output_col="A"), - Example(input_col="banana", output_col="B"), - Example(input_col="orange", output_col="O"), - ] - - # Call the construct_input_output_map() - # method to create the input-output map. - input_output_map = data_generator.construct_input_output_map(generated_examples) - - # The expected input-output map after - # constructing it from the generated examples. - expected_output = { - "apple": Counter({"A": 1}), - "banana": Counter({"B": 1}), - "orange": Counter({"O": 1}), - } - - # Assertions to verify that the input-output - # map matches the expected output. - assert input_output_map == expected_output - - # Collect garbage to release memory - # resources after the test. - gc.collect() - - -def test_construct_map_with_empty_examples_list(): - """Test constructing a map with an empty list of inputs and outputs. - - This test case verifies the behavior of the construct_input_output_map() - method in PromptBasedDatasetGenerator when no generated examples are available. - - Attributes: - api_key (str): The fake API key used for testing. - """ - # Set a fake API key in the environment variable for testing purposes. - os.environ["OPENAI_API_KEY"] = "fake_api_key" - - # Initialize the PromptBasedDatasetGenerator with filter_duplicated_examples=True. - with tempfile.TemporaryDirectory() as cache_dir: - data_generator = PromptBasedDatasetGenerator( - filter_duplicated_examples=True, cache_root=cache_dir - ) - - # Create an empty list of generated examples. - generated_examples = [] - - # Call the construct_input_output_map() - # method to create the input-output map. - input_output_map = data_generator.construct_input_output_map(generated_examples) - - # The input-output map should be empty - # when there are no generated examples. - assert input_output_map == {} - - # Collect garbage to release memory - # resources after the test. - gc.collect() - - -def test_multi_vote_with_duplicate_inputs_unique_outputs(): - """Test multi-voting with duplicate inputs but unique outputs. - - This test case verifies the application of multi-voting mechanism in the - apply_multi_vote_to_construct_generated_dataset() method of - PromptBasedDatasetGenerator. It specifically tests the scenario when - the input-output map contains duplicate inputs but unique outputs. - - Attributes: - api_key (str): The fake API key used for testing. - """ - # Set a fake API key in the environment variable for testing purposes. - os.environ["OPENAI_API_KEY"] = "fake_api_key" - - # Initialize the PromptBasedDatasetGenerator with filter_duplicated_examples=True. - with tempfile.TemporaryDirectory() as cache_dir: - data_generator = PromptBasedDatasetGenerator( - filter_duplicated_examples=True, cache_root=cache_dir - ) - - # Provide an input-output map with duplicate inputs but unique outputs. - input_output_map = { - "apple": Counter({"A": 1, "E": 1, "D": 1}), - "banana": Counter({"B": 1}), - "orange": Counter({"O": 1}), - } - - # Apply multi-voting mechanism to construct the generated dataset. 
- generated_dataset = ( - data_generator.apply_multi_vote_to_construct_generated_dataset( - input_output_map - ) - ) - - # Define the expected dataset after multi-voting. - expected_dataset = Dataset.from_dict( - {"input_col": ["apple", "banana", "orange"], "output_col": ["A", "B", "O"]} - ) - - # Verify that the generated dataset matches the expected dataset. - assert are_datasets_identical(generated_dataset, expected_dataset) - - # Collect garbage to release memory resources after the test. - gc.collect() - - -def test_multi_vote_with_duplicate_inputs_duplicate_outputs(): - """Test multi-voting with duplicate inputs and duplicate outputs. - - This test case verifies the application of multi-voting mechanism in the - apply_multi_vote_to_construct_generated_dataset() method of - PromptBasedDatasetGenerator. It specifically tests the scenario when - the input-output map contains duplicate inputs and duplicate outputs. - - Attributes: - api_key (str): The fake API key used for testing. - """ - # Set a fake API key in the environment variable for testing purposes. - os.environ["OPENAI_API_KEY"] = "fake_api_key" - - # Initialize the PromptBasedDatasetGenerator with filter_duplicated_examples=True. - with tempfile.TemporaryDirectory() as cache_dir: - data_generator = PromptBasedDatasetGenerator( - filter_duplicated_examples=True, cache_root=cache_dir - ) - - # Provide an input-output map with duplicate inputs and duplicate outputs. - input_output_map = { - "apple": Counter({"A": 3, "D": 1, "G": 1}), - "banana": Counter({"B": 2, "C": 1}), - "orange": Counter({"O": 1, "F": 1}), - } - - # Apply multi-voting mechanism to construct the generated dataset. - generated_dataset = ( - data_generator.apply_multi_vote_to_construct_generated_dataset( - input_output_map - ) - ) - - # Define the expected dataset after multi-voting. - expected_dataset = Dataset.from_dict( - {"input_col": ["apple", "banana", "orange"], "output_col": ["A", "B", "O"]} - ) - - # Verify that the generated dataset matches the expected dataset. - assert are_datasets_identical(generated_dataset, expected_dataset) - - # Collect garbage to release memory resources after the test. - gc.collect() - - -def test_multi_vote_with_unique_inputs_outputs(): - """Test multi-voting with unique inputs and outputs. - - This test case verifies the application of the multi-voting mechanism in the - apply_multi_vote_to_construct_generated_dataset() method of - PromptBasedDatasetGenerator. - It specifically tests the scenario when the input-output map contains unique - inputs and outputs. - - Attributes: - api_key (str): The fake API key used for testing. - """ - # Set a fake API key in the environment variable for testing purposes. - os.environ["OPENAI_API_KEY"] = "fake_api_key" - - # Initialize the PromptBasedDatasetGenerator with an empty input-output map. - with tempfile.TemporaryDirectory() as cache_dir: - data_generator = PromptBasedDatasetGenerator(cache_root=cache_dir) - - # Provide an input-output map with unique inputs and outputs. - input_output_map = { - "apple": Counter({"A": 1}), - "banana": Counter({"B": 1}), - "orange": Counter({"O": 1}), - } - - # Apply multi-voting mechanism to construct the generated dataset. - generated_dataset = ( - data_generator.apply_multi_vote_to_construct_generated_dataset( - input_output_map - ) - ) - - # Define the expected dataset after multi-voting. 
- expected_dataset = Dataset.from_dict( - {"input_col": ["apple", "banana", "orange"], "output_col": ["A", "B", "O"]} - ) - - # Verify that the generated dataset matches the expected dataset. - assert are_datasets_identical(generated_dataset, expected_dataset) - - # Collect garbage to release memory resources after the test. - gc.collect() - - -def test_multi_vote_with_empty_examples_list(): - """Test multi-voting with empty inputs and outputs. - - This test case verifies the application of the multi-voting mechanism in the - apply_multi_vote_to_construct_generated_dataset() method of - PromptBasedDatasetGenerator. - It specifically tests the scenario when the input-output map is empty. - - Attributes: - api_key (str): The fake API key used for testing. - """ - # Initialize the PromptBasedDatasetGenerator with an empty input-output map. - with tempfile.TemporaryDirectory() as cache_dir: - os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator( - cache_root=cache_dir, filter_duplicated_examples=True - ) - - # Set the input-output map to be empty. - input_output_map = {} - - # Apply multi-voting mechanism to construct the generated dataset. - generated_dataset = ( - data_generator.apply_multi_vote_to_construct_generated_dataset( - input_output_map - ) - ) - - # Define the expected dataset after multi-voting (empty dataset). - expected_dataset = Dataset.from_dict({}) - - # Verify that the generated dataset matches - # the expected dataset (empty dataset). - assert are_datasets_identical(generated_dataset, expected_dataset) - - # Collect garbage to release memory resources after the test. - gc.collect() - - -def test_create_all_examples_dataset_and_generated_dataset_with_duplicate_inputs_unique_outputs(): # noqa E501 - """Test constructing generated dataset with duplicate inputs but unique outputs. - - This test case verifies the construction of the generated dataset with duplicate - inputs but unique outputs. The PromptBasedDatasetGenerator object is initialized - with `filter_duplicated_examples=True` to ensure that duplicates are filtered. - - Attributes: - api_key (str): The fake API key used for testing. - """ - # Initialize the PromptBasedDatasetGenerator with `filter_duplicated_examples=True`. - with tempfile.TemporaryDirectory() as cache_dir: - os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator( - filter_duplicated_examples=True, cache_root=cache_dir - ) - - # Provide generated examples with duplicate inputs but unique outputs. - generated_examples = [ - Example(input_col="apple", output_col="A"), - Example(input_col="banana", output_col="B"), - Example(input_col="apple", output_col="E"), - Example(input_col="orange", output_col="O"), - Example(input_col="apple", output_col="D"), - ] - - # Convert the generated examples to the generated dataset. - ( - all_generated_examples_dataset, - generated_dataset, - ) = data_generator.create_all_examples_dataset_and_generated_dataset( - generated_examples - ) - - # Define the expected dataset after conversion (duplicates are filtered). - expected_dataset = Dataset.from_dict( - {"input_col": ["apple", "banana", "orange"], "output_col": ["A", "B", "O"]} - ) - - expected_all_generated_examples_dataset = Dataset.from_dict( - { - "input_col": [example.input_col for example in generated_examples], - "output_col": [example.output_col for example in generated_examples], - } - ) - - # Verify that the generated dataset matches the expected dataset. 
- assert are_datasets_identical(generated_dataset, expected_dataset) - assert are_datasets_identical( - all_generated_examples_dataset, expected_all_generated_examples_dataset - ) - - # Collect garbage to release memory resources after the test. - gc.collect() - - -def test_create_all_examples_dataset_and_generated_dataset_with_duplicate_inputs_duplicate_outputs(): # noqa E501 - """Test constructing a map with duplicate inputs and duplicate outputs. - - This test case verifies the construction of the generated dataset with duplicate - inputs and duplicate outputs. The PromptBasedDatasetGenerator object is initialized - with `filter_duplicated_examples=True` to ensure that duplicates are filtered. - - Attributes: - api_key (str): The fake API key used for testing. - """ - # Initialize the PromptBasedDatasetGenerator with `filter_duplicated_examples=True`. - with tempfile.TemporaryDirectory() as cache_dir: - os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator( - filter_duplicated_examples=True, cache_root=cache_dir - ) - - # Provide generated examples with duplicate inputs and duplicate outputs. - generated_examples = [ - Example(input_col="apple", output_col="A"), - Example(input_col="banana", output_col="C"), - Example(input_col="apple", output_col="A"), - Example(input_col="banana", output_col="B"), - Example(input_col="apple", output_col="G"), - Example(input_col="apple", output_col="A"), - Example(input_col="orange", output_col="O"), - Example(input_col="apple", output_col="D"), - Example(input_col="banana", output_col="B"), - Example(input_col="orange", output_col="F"), - ] - - # Convert the generated examples to the generated dataset. - ( - all_generated_examples_dataset, - generated_dataset, - ) = data_generator.create_all_examples_dataset_and_generated_dataset( - generated_examples - ) - - # Define the expected dataset after conversion (duplicates are filtered). - expected_dataset = Dataset.from_dict( - {"input_col": ["apple", "banana", "orange"], "output_col": ["A", "B", "O"]} - ) - - expected_all_generated_examples_dataset = Dataset.from_dict( - { - "input_col": [example.input_col for example in generated_examples], - "output_col": [example.output_col for example in generated_examples], - } - ) - - # Verify that the generated dataset matches the expected dataset. - assert are_datasets_identical(generated_dataset, expected_dataset) - assert are_datasets_identical( - all_generated_examples_dataset, expected_all_generated_examples_dataset - ) - - # Collect garbage to release memory resources after the test. - gc.collect() - - -def test_create_all_examples_dataset_and_generated_dataset_with_unique_inputs_outputs(): - """Test constructing a map with unique inputs and outputs. - - This test case verifies the construction of the generated dataset with unique - inputs and outputs. The PromptBasedDatasetGenerator object is initialized with - `filter_duplicated_examples=True` to ensure that duplicates are filtered. - - Attributes: - api_key (str): The fake API key used for testing. - """ - # Initialize the PromptBasedDatasetGenerator with `filter_duplicated_examples=True`. - with tempfile.TemporaryDirectory() as cache_dir: - os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator( - filter_duplicated_examples=True, cache_root=cache_dir - ) - - # Provide generated examples with unique inputs and outputs. 
- generated_examples = [ - Example(input_col="apple", output_col="A"), - Example(input_col="banana", output_col="B"), - Example(input_col="orange", output_col="O"), - ] - - # Convert the generated examples to the generated dataset. - ( - all_generated_examples_dataset, - generated_dataset, - ) = data_generator.create_all_examples_dataset_and_generated_dataset( - generated_examples - ) - - # Define the expected dataset after conversion (no duplicates to filter). - expected_dataset = Dataset.from_dict( - {"input_col": ["apple", "banana", "orange"], "output_col": ["A", "B", "O"]} - ) - - expected_all_generated_examples_dataset = Dataset.from_dict( - { - "input_col": [example.input_col for example in generated_examples], - "output_col": [example.output_col for example in generated_examples], - } - ) - - # Verify that the generated dataset matches the expected dataset. - assert are_datasets_identical(generated_dataset, expected_dataset) - assert are_datasets_identical( - all_generated_examples_dataset, expected_all_generated_examples_dataset - ) - - # Collect garbage to release memory resources after the test. - gc.collect() - - -def test_create_all_examples_dataset_and_generated_dataset_with_empty_examples_list(): - """Test constructing a map with empty inputs and outputs. - - This test case verifies the construction of the generated dataset when the - generated_examples list is empty. The PromptBasedDatasetGenerator object is - initialized with `filter_duplicated_examples=True` to ensure that duplicates - are filtered. - - Attributes: - api_key (str): The fake API key used for testing. - """ - # Initialize the PromptBasedDatasetGenerator with `filter_duplicated_examples=True`. - with tempfile.TemporaryDirectory() as cache_dir: - os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator( - filter_duplicated_examples=True, cache_root=cache_dir - ) - - # Provide an empty list of generated examples. - generated_examples = [] - - # Convert the empty generated examples to the generated dataset. - ( - all_generated_examples_dataset, - generated_dataset, - ) = data_generator.create_all_examples_dataset_and_generated_dataset( - generated_examples - ) - - # Define the expected dataset (empty dataset when there are no examples). - expected_dataset = Dataset.from_dict({}) - - expected_all_generated_examples_dataset = Dataset.from_dict( - { - "input_col": [example.input_col for example in generated_examples], - "output_col": [example.output_col for example in generated_examples], - } - ) - - # Verify that the generated dataset matches the expected dataset. - assert are_datasets_identical(generated_dataset, expected_dataset) - assert are_datasets_identical( - all_generated_examples_dataset, expected_all_generated_examples_dataset - ) - - # Collect garbage to release memory resources after the test. - gc.collect() - - -def test_load_cache_dataset_with_filter_duplicated_examples(): - """Test the cached dataset loading with filtering duplicated examples. - - This test case verifies the loading of the cached dataset and its filtering - to eliminate duplicated examples. The PromptBasedDatasetGenerator object is - initialized with `filter_duplicated_examples=True`. - - Attributes: - api_key (str): The fake API key used for testing. - """ - # Set up a temporary directory for cache. 
- with tempfile.TemporaryDirectory() as cache_dir: - os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator( - cache_root=cache_dir, filter_duplicated_examples=True - ) - - # Create a cached dataset and save it to the disk. - examples_cache_path = Path( - data_generator.cache_root / f"generated_examples_{DatasetSplit.TEST.value}" - ) - cached_dataset = Dataset.from_dict( - { - "input_col": ["1", "1", "1", "1", "2", "3"], - "output_col": ["a", "a", "b", "c", "a", "d"], - } - ) - cached_dataset.save_to_disk(examples_cache_path) - - # The generate_dataset_split would first load the cached dataset into - # generated_examples. Then, in the while loop, - # create_all_examples_dataset_and_generated_dataset would be called to - # construct the generated_dataset. Note that filter_duplicated_examples - # is True, so the generated_examples will be filtered to 3 examples - # in generated_dataset. Since expected_num_examples is 3, the while loop - # would exit immediately. So the generated_dataset would be the filtered - # cached dataset. - with patch.object(logger, "info") as mock_info, patch.object( - logger, "warning" - ) as mock_warning: - generated_dataset = data_generator.generate_dataset_split( - expected_num_examples=3, - prompt_spec=MockPromptSpec, - split=DatasetSplit.TEST, - ) - - # Verify that logger.info was called with the correct message. - mock_info.assert_called_once_with( - f"Loading cache from {str(examples_cache_path)}." - ) - mock_warning.assert_not_called() - - # Define the expected filtered dataset after loading the cache. - excepted_generated_dataset = Dataset.from_dict( - { - "input_col": ["1", "2", "3"], - "output_col": ["a", "a", "d"], - } - ) - - # Verify that the generated dataset matches the expected filtered dataset. - assert are_datasets_identical(generated_dataset, excepted_generated_dataset) - - # Collect garbage to release memory resources after the test. - gc.collect() - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MOCK_EXAMPLE, -) -def test_load_cache_dataset_with_filter_duplicated_examples_and_continue_generation( - mocked_generate_example, -): - """Test PromptBasedDatasetGenerator can load cache and continue generation. - - This test case verifies the ability of PromptBasedDatasetGenerator to - load a cached dataset and continue generation when - `filter_duplicated_examples` is True. The PromptBasedDatasetGenerator - object is initialized with `filter_duplicated_examples=True`. - - Attributes: - api_key (str): The fake API key used for testing. - """ - # Set up a temporary directory for cache. - with tempfile.TemporaryDirectory() as cache_dir: - os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator( - cache_root=cache_dir, filter_duplicated_examples=True - ) - - # Create cached examples and save them to the disk. - examples_cache_path = ( - Path(cache_dir) / f"generated_examples_{DatasetSplit.TEST.value}" - ) - cached_examples = Dataset.from_dict( - { - "input_col": ["1", "1", "1", "1", "2", "3"], - "output_col": ["a", "a", "b", "c", "a", "d"], - } - ) - cached_examples.save_to_disk(examples_cache_path) - - # The generate_dataset_split would first load the cached dataset into - # generated_examples. Then, in the while loop, - # create_all_examples_dataset_and_generated_dataset would be called to - # construct the generated_dataset. 
Note that filter_duplicated_examples - # is True, so the generated_examples will be filtered to 3 examples - # in generated_dataset. Since expected_num_examples is 4, the generation - # would continue, and the max_batch_size = 1. After one batch of API calls, - # generated_dataset meets the requirement and stop generation. - with patch.object(logger, "info") as mock_info, patch.object( - logger, "warning" - ) as mock_warning: - generated_dataset = data_generator.generate_dataset_split( - expected_num_examples=4, - prompt_spec=MockPromptSpec, - split=DatasetSplit.TEST, - ) - - # Verify that logger.info was called with - # the correct message for loading cache. - info_list = [each.args[0] for each in mock_info.call_args_list] - assert info_list[0] == f"Loading cache from {str(examples_cache_path)}." - # The first logger.info is for loading cache, and there are - # 5 * 2 additional logger.info messages in extract_responses. - assert len(info_list) == 1 + 5 * 2 - mock_warning.assert_not_called() - - # Define the expected generated dataset after continuing generation. - excepted_generated_dataset = Dataset.from_dict( - { - "input_col": ["1", "2", "3", "6"], - "output_col": ["a", "a", "d", "f"], - } - ) - - # Verify that the generated dataset matches the expected dataset. - assert are_datasets_identical(generated_dataset, excepted_generated_dataset) - - # Verify that the API was called once to generate responses. - assert mocked_generate_example.call_count == 1 - - # Collect garbage to release memory resources after the test. - gc.collect() - - -""" -These tests validate the generation process with `filter_duplicated_examples=True`. - -These tests collaborate with the `MockBatchDifferentCompletions().mock_completions` -function to imitate the generation process of the PromptBasedDatasetGenerator. - -The first five tests check the generation of a single dataset split using -a shared PromptBasedDatasetGenerator with the following settings: - - max_batch_size = 2 - - responses_per_request = 3 - - filter_duplicated_examples = True - - expected_num_examples = 5 - -In the first API call, the generator produces 2 * 3 = 6 responses. -After filtering duplicates, the generated_dataset will be: - Dataset.from_dict( - { - "input_col": ["1", "2"], - "output_col": ["a", "a"], - }) - -max_batch_size = (expected_num_examples - len(generated_dataset)) -/ responses_per_request = (5 - 2) / 3 = 1. - -The second API call reduces max_batch_size to 1 and generates 3 more responses. -After filtering duplicates, the generated_dataset will be: - Dataset.from_dict( - { - "input_col": ["1", "2", "3"], - "output_col": ["a", "a", "a"], - }) - -The third API call again uses max_batch_size = 1 and generates another 3 responses. -After filtering duplicates, the generated_dataset will be: - Dataset.from_dict( - { - "input_col": ["1", "2", "3"], - "output_col": ["b", "a", "a"], - }) - -The fourth API call also uses max_batch_size = 1 and generates 3 responses. -After filtering duplicates, the generated_dataset will be: - Dataset.from_dict( - { - "input_col": ["1", "2", "3", "4", "5"], - "output_col": ["b", "a", "a", "c", "a"], - }) - -The test suite contains five test cases, each using a different -PromptBasedDatasetGenerator. -These generators have the same settings (max_batch_size = 2, responses_per_request = 3, -expected_num_examples = 5, filter_duplicated_examples = True), but their max_api_calls -attribute is set to 2, 3, 4, 5, and unlimited, respectively. 
- -Each test runs the generation of its generator and verifies that the -generated dataset matches the expected result. -""" - - -api_key = "fake_api_key" -prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) -split = DatasetSplit.TRAIN -filter_duplicated_examples = True -expected_num_examples = 5 -max_batch_size = 2 -responses_per_request = 3 - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MockBatchDifferentCompletions().mock_completions, -) -def test_generator_with_filter_first_batch(mocked_generate_example): - """Test PromptBasedDatasetGenerator with filter methods in the first batch. - - This test verifies the behavior of the PromptBasedDatasetGenerator with - filter methods in the first batch of API calls. It initializes an - PromptBasedDatasetGenerator with specific settings, limiting the number - of API calls to 2. After running the generation process, the test - checks whether the generated dataset matches the expected - result after the second API call. The test also ensures that the - number of calls to the API mock matches the expected number. - - Note: The first API call's max_batch_size is 2, generating 6 responses. - - Args: - mocked_generate_example (MagicMock): The patched function representing the - @patch decorator for generating example responses. - """ - with tempfile.TemporaryDirectory() as cache_dir: - # Initialize the PromptBasedDatasetGenerator with specific settings. - dataset_generator = PromptBasedDatasetGenerator( - max_api_calls=2, - filter_duplicated_examples=filter_duplicated_examples, - cache_root=cache_dir, - max_batch_size=max_batch_size, - responses_per_request=responses_per_request, - ) - - # Generate the dataset split using the initialized generator. - generated_dataset = dataset_generator.generate_dataset_split( - prompt_spec, expected_num_examples, split - ) - - # Assertions for API call count and dataset matching the expected result. - assert mocked_generate_example.call_count == 1 - assert dataset_generator.api_call_counter == 2 - - # Define the expected dataset based on the given mock responses. - expected_dataset = Dataset.from_dict( - { - "input_col": ["1", "2"], - "output_col": ["a", "a"], - } - ) - - # Verify the generated dataset matches the expected dataset. - assert are_datasets_identical(generated_dataset, expected_dataset) - - # Collect garbage to release memory resources after the test. - gc.collect() - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MockBatchDifferentCompletions().mock_completions, -) -def test_generator_with_filter_second_batch(mocked_generate_example): - """Test PromptBasedDatasetGenerator with filter methods in the second batch. - - This test verifies the behavior of the PromptBasedDatasetGenerator with filter - methods in the second batch of API calls. It initializes an - PromptBasedDatasetGenerator with specific settings, limiting the number of - API calls to 3. After running the generation process, the test checks - whether the generated dataset matches the expected result after the - second API call. The test also ensures that the number of calls to the - API mock matches the expected number. - - Note: The first API call's max_batch_size is 2, generating 6 responses. - The second API call's max_batch_size is 1, generating 3 responses. - - Args: - mocked_generate_example (MagicMock): The patched function representing the - @patch decorator for generating example responses. 
- """ - with tempfile.TemporaryDirectory() as cache_dir: - # Initialize the PromptBasedDatasetGenerator with specific settings. - dataset_generator = PromptBasedDatasetGenerator( - max_api_calls=3, - filter_duplicated_examples=filter_duplicated_examples, - cache_root=cache_dir, - max_batch_size=max_batch_size, - responses_per_request=responses_per_request, - ) - - # Generate the dataset split using the initialized generator. - generated_dataset = dataset_generator.generate_dataset_split( - prompt_spec, expected_num_examples, split - ) - - # Assertions for API call count and dataset matching the expected result. - assert mocked_generate_example.call_count == 2 - assert dataset_generator.api_call_counter == 3 - - # Define the expected dataset based on the given mock responses. - expected_dataset = Dataset.from_dict( - { - "input_col": ["1", "2", "3"], - "output_col": ["a", "a", "a"], - } - ) - - # Verify the generated dataset matches the expected dataset. - assert are_datasets_identical(generated_dataset, expected_dataset) - - # Collect garbage to release memory resources after the test. - gc.collect() - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MockBatchDifferentCompletions().mock_completions, -) -def test_generator_with_filter_third_batch(mocked_generate_example): - """Test PromptBasedDatasetGenerator with filter methods in the third batch. - - This test verifies the behavior of the PromptBasedDatasetGenerator with - filter methods in the third batch of API calls. It initializes an - PromptBasedDatasetGenerator with specific settings, limiting the number - of API calls to 4. After running the generation process, the test - checks whether the generated dataset matches the expected - result after the third API call. The test also ensures that the - number of calls to the API mock matches the expected number. - - Note: The first API call's max_batch_size is 2, generating 6 responses. - The second API call's max_batch_size is 1, generating 3 responses. - The third API call's max_batch_size is 1, generating 3 responses. - - Args: - mocked_generate_example (MagicMock): The patched function representing the - @patch decorator for generating example responses. - """ - with tempfile.TemporaryDirectory() as cache_dir: - # Reset the mock responses to ensure predictable behavior. - - # Initialize the PromptBasedDatasetGenerator with specific settings. - dataset_generator = PromptBasedDatasetGenerator( - max_api_calls=4, - filter_duplicated_examples=filter_duplicated_examples, - cache_root=cache_dir, - max_batch_size=max_batch_size, - responses_per_request=responses_per_request, - ) - - # Generate the dataset split using the initialized generator. - generated_dataset = dataset_generator.generate_dataset_split( - prompt_spec, expected_num_examples, split - ) - - # Assertions for API call count and dataset matching the expected result. - assert mocked_generate_example.call_count == 3 - assert dataset_generator.api_call_counter == 4 - - # Define the expected dataset based on the given mock responses. - expected_dataset = Dataset.from_dict( - { - "input_col": ["1", "2", "3"], - "output_col": ["b", "a", "a"], - } - ) - - # Verify the generated dataset matches the expected dataset. - assert are_datasets_identical(generated_dataset, expected_dataset) - - # Collect garbage to release memory resources after the test. 
- gc.collect() - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MockBatchDifferentCompletions().mock_completions, -) -def test_generator_with_filter_forth_batch(mocked_generate_example): - """Test PromptBasedDatasetGenerator with filter methods in the forth batch. - - This test verifies the behavior of the PromptBasedDatasetGenerator with - filter methods in the forth batch of API calls. It initializes an - PromptBasedDatasetGenerator with specific settings, limiting the number - of API calls to 5. After running the generation process, the test checks - whether the generated dataset matches the expected result after the - forth API call. The test also ensures that the number of calls to the - API mock matches the expected number. - - Note: The first API call's max_batch_size is 2, generating 6 responses. - The second API call's max_batch_size is 1, generating 3 responses. - The third API call's max_batch_size is 1, generating 3 responses. - The forth and last API call's max_batch_size is 1, generating 3 responses. - - Args: - mocked_generate_example (MagicMock): The patched function representing the - @patch decorator for generating example responses. - """ - with tempfile.TemporaryDirectory() as cache_dir: - # Reset the mock responses to ensure predictable behavior. - - # Initialize the PromptBasedDatasetGenerator with specific settings. - dataset_generator = PromptBasedDatasetGenerator( - max_api_calls=5, - filter_duplicated_examples=filter_duplicated_examples, - cache_root=cache_dir, - max_batch_size=max_batch_size, - responses_per_request=responses_per_request, - ) - - # Generate the dataset split using the initialized generator. - generated_dataset = dataset_generator.generate_dataset_split( - prompt_spec, expected_num_examples, split - ) - - # Assertions for API call count and dataset matching the expected result. - assert mocked_generate_example.call_count == 4 - assert dataset_generator.api_call_counter == 5 - - # Define the expected dataset based on the given mock responses. - expected_dataset = Dataset.from_dict( - { - "input_col": ["1", "2", "3", "4", "5"], - "output_col": ["b", "a", "a", "c", "a"], - } - ) - - # Verify the generated dataset matches the expected dataset. - assert are_datasets_identical(generated_dataset, expected_dataset) - - # Collect garbage to release memory resources after the test. - gc.collect() - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MockBatchDifferentCompletions().mock_completions, -) -def test_generator_with_filter_unlimited_api_calls(mocked_generate_example): - """Test PromptBasedDatasetGenerator with filter methods and unlimited API calls. - - This test verifies the behavior of the PromptBasedDatasetGenerator with - filter methods and unlimited API calls. It initializes the generator - with specific settings but does not limit the number of API calls. - After running the generation process, the test checks whether - the generated dataset matches the expected result after the - final API call. The test also ensures that the number of calls - to the API mock matches the expected number. - - Note: The first API call's max_batch_size is 2, generating 6 responses. - The second API call's max_batch_size is 1, generating 3 responses. - The third API call's max_batch_size is 1, generating 3 responses. - The forth and last API call's max_batch_size is 1, generating 3 responses. - After the forth batch, the generation ends. No further API calls are required. 
- - Args: - mocked_generate_example (MagicMock): The patched function representing the - @patch decorator for generating example responses. - """ - with tempfile.TemporaryDirectory() as cache_dir: - # Reset the mock responses to ensure predictable behavior. - - # Initialize the PromptBasedDatasetGenerator with - # specific settings and unlimited API calls. - dataset_generator = PromptBasedDatasetGenerator( - filter_duplicated_examples=filter_duplicated_examples, - cache_root=cache_dir, - max_batch_size=max_batch_size, - responses_per_request=responses_per_request, - ) - - # Generate the dataset split using the initialized generator. - generated_dataset = dataset_generator.generate_dataset_split( - prompt_spec, expected_num_examples, split - ) - - # Assertions for API call count and dataset matching the expected result. - assert mocked_generate_example.call_count == 4 - assert dataset_generator.api_call_counter == 5 - - # Define the expected dataset based on the given mock responses. - expected_dataset = Dataset.from_dict( - { - "input_col": ["1", "2", "3", "4", "5"], - "output_col": ["b", "a", "a", "c", "a"], - } - ) - - # Verify the generated dataset matches the expected dataset. - assert are_datasets_identical(generated_dataset, expected_dataset) - - # Collect garbage to release memory resources after the test. - gc.collect() - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MockBatchDifferentCompletions(length=5).mock_completions, -) -def test_generator_with_filter_to_generate_datasetdict(mocked_generate_example): - """Test PromptBasedDatasetGenerator with filter methods to generate a DatasetDict. - - This test checks the generation of a DatasetDict using the - PromptBasedDatasetGenerator - with filter methods. It initializes the generator with specific settings and - expected_num_examples for each split. The test verifies the generated dataset - dictionaries for the train, val, and test splits match the expected results. - - The generation involves multiple API calls, and filtering duplicates is applied to - the generated examples at each step. The API calls continue until the number of - generated examples meets the expected_num_examples for each split or reaches the - maximum allowed API calls. - - Args: - mocked_generate_example (MagicMock): The patched function representing the - @patch decorator for generating example responses. - """ - with tempfile.TemporaryDirectory() as cache_dir: - # Reset the mock responses to ensure predictable behavior. - - # Initialize the PromptBasedDatasetGenerator with - # specific settings and limited API calls. - dataset_generator = PromptBasedDatasetGenerator( - filter_duplicated_examples=filter_duplicated_examples, - cache_root=cache_dir, - max_batch_size=max_batch_size, - responses_per_request=responses_per_request, - max_api_calls=7, - ) - - # Generate the DatasetDict using the initialized generator. - generated_dataset_dict = dataset_generator.generate_dataset_dict( - prompt_spec, - expected_num_examples={ - DatasetSplit.TRAIN: 4, - DatasetSplit.VAL: 4, - DatasetSplit.TEST: 2, - }, - ) - - # Assertions for API call count and dataset - # dictionaries matching the expected results. - assert mocked_generate_example.call_count == 5 - assert dataset_generator.api_call_counter == 7 - - # Define the expected dataset dictionaries - # based on the given mock responses. 
- expected_dataset_dict = datasets.DatasetDict( - { - "train": Dataset.from_dict( - { - "input_col": ["1", "2", "3", "4", "5"], - "output_col": ["b", "a", "a", "c", "a"], - } - ), - "val": Dataset.from_dict( - { - "input_col": ["1", "2"], - "output_col": ["a", "a"], - } - ), - "test": Dataset.from_dict( - { - "input_col": [], - "output_col": [], - } - ), - } - ) - - # Verify the generated DatasetDict matches the expected DatasetDict. - assert are_dataset_dicts_identical( - generated_dataset_dict, expected_dataset_dict - ) - - # Collect garbage to release memory resources after the test. - gc.collect() diff --git a/tests/dataset_generator_without_filter_test.py b/tests/dataset_generator_without_filter_test.py deleted file mode 100644 index 3ed217c7a..000000000 --- a/tests/dataset_generator_without_filter_test.py +++ /dev/null @@ -1,1036 +0,0 @@ -"""Testing DatasetGenerator through PromptBasedDatasetGenerator.""" - -import gc -import logging -import os -import tempfile -from functools import partial -from pathlib import Path -from unittest.mock import patch - -import pytest -from datasets import Dataset - -from prompt2model.dataset_generator.base import DatasetSplit -from prompt2model.dataset_generator.prompt_based import ( - Example, - PromptBasedDatasetGenerator, -) -from prompt2model.prompt_parser import MockPromptSpec, TaskType -from test_helpers import ( - MockCompletion, - UnknownGpt3Exception, - are_datasets_identical, - mock_batch_api_response_identical_completions, -) - -logger = logging.getLogger("DatasetGenerator") - -MOCK_CLASSIFICATION_EXAMPLE = partial( - mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "output": "1"}', -) -MOCK_WRONG_KEY_EXAMPLE = partial( - mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "label": "1"}', -) -MOCK_INVALID_JSON = partial( - mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "output": "1}', -) - -MOCK_CLASSIFICATION_EXAMPLE = partial( - mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "output": "1"}', -) -MOCK_WRONG_KEY_EXAMPLE = partial( - mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "label": "1"}', -) -MOCK_INVALID_JSON = partial( - mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "output": "1}', -) - -MOCK_CLASSIFICATION_EXAMPLE = partial( - mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "output": "1"}', -) -MOCK_WRONG_KEY_EXAMPLE = partial( - mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "label": "1"}', -) -MOCK_INVALID_JSON = partial( - mock_batch_api_response_identical_completions, - content='{"input": "This is a great movie!", "output": "1}', -) - - -def check_generate_dataset(dataset_generator: PromptBasedDatasetGenerator): - """Test the `generate_dataset_split()` function of `PromptBasedDatasetGenerator`. - - This function generates a Dataset for a specified split of the data - (train, validation, or test) using a simple prompt specification - and saves them to a temporary directory. Then, it checks that the - generated dataset has the expected number of examples, the expected - columns, and each example is not empty. - - Args: - dataset_generator: The dataset_generator will be tested - with limited max_api_calls or unlimited max_api_calls. 
- """ - prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) - split = DatasetSplit.TRAIN - expected_num_examples = 29 - # If expected_num_examples >= max_api_calls, the returned dataset's - # length will be less than or equal to max_api_calls. - dataset = dataset_generator.generate_dataset_split( - prompt_spec, expected_num_examples, split - ) - # Since each API call would return one completion object with 5 responses - # and some of the responses are invalid JSON objects, the upper bound of - # the length of the dataset is expected_num_examples + 5, where 5 is the - # default number of responses per API call. - assert len(dataset) < expected_num_examples + 5 - expected_columns = {"input_col", "output_col"} - assert set(dataset.column_names) == expected_columns - return dataset - - -def check_generate_dataset_dict(dataset_generator: PromptBasedDatasetGenerator): - """Test the `generate_dataset_dict()` function of `PromptBasedDatasetGenerator`. - - This function generates movie comments datasets by creating a specified - number of examples for each split of the data, which includes train, - validation, and test. It uses a simple prompt specification and saves the - generated datasets to a temporary directory. Afterward, the function - checks whether the dataset dictionary contains all the expected keys, - each split has the anticipated number of examples, every dataset has - the anticipated columns, each example is not empty, and whether - the dataset dictionary is saved to the output directory. - - Args: - dataset_generator: The dataset_generator will be tested - with limited max_api_calls or unlimited max_api_calls. - """ - prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) - expected_num_examples = { - DatasetSplit.TRAIN: 50, - DatasetSplit.VAL: 24, - DatasetSplit.TEST: 26, - } - with tempfile.TemporaryDirectory() as tmpdirname: - output_dir = os.path.join(tmpdirname, "output") - dataset_dict = dataset_generator.generate_dataset_dict( - prompt_spec=prompt_spec, - expected_num_examples=expected_num_examples, - output_dir=output_dir, - ) - - assert set(dataset_dict.keys()) == {"train", "val", "test"} - for split, num in expected_num_examples.items(): - # As explained previously, the upper bound of the length of - # generated dataset is expected_num_examples + 5, where - # 5 is the default number of responses per API call. - assert len(dataset_dict[split.value]) < num + 5 - expected_columns = {"input_col", "output_col"} - for dataset in dataset_dict.values(): - assert set(dataset.column_names) == expected_columns - assert os.path.isdir(output_dir) - assert set(os.listdir(output_dir)) == { - "dataset_dict.json", - "test", - "train", - "val", - } - return dataset_dict - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MOCK_CLASSIFICATION_EXAMPLE, -) -def test_generator_without_filter(mocked_generate_example): - """Test classification dataset generation using the PromptBasedDatasetGenerator. - - This function first tests unlimited generation. Then, it tests generation - when expected_num_examples >= max_api_calls. In the second test, the API agent - will only be called max_api_calls times. - - Args: - mocked_generate_example: The function representing the @patch function. 
- """ - with tempfile.TemporaryDirectory() as cache_dir: - os.environ["OPENAI_API_KEY"] = "fake_api_key" - unlimited_dataset_generator = PromptBasedDatasetGenerator( - filter_duplicated_examples=False, cache_root=cache_dir - ) - unlimited_generated_dataset = check_generate_dataset( - unlimited_dataset_generator - ) - # The default responses_per_request is 5. So each API call will return - # 5 responses, i.e. 5 choices in openai.Completion.choices. - # Each API call will return 5 responses, and each response is a valid JSON. - # So the unlimited_dataset_generator will call the API (29 // 5 + 1) times. - assert unlimited_dataset_generator.api_call_counter == (29 // 5 + 1) - # The default batch_size is 5. So generate_batch_completion - # will be called 2 times with first batch_size = 5 and second batch_size = 1. - assert mocked_generate_example.call_count == 2 - # Since all the responses are valid JSON and the api_call_counter is 6, - # the unlimited_generated_dataset will contain 30 examples. - assert len(unlimited_generated_dataset) == 30 - - # Refresh the call_count and dataset_generator. - with tempfile.TemporaryDirectory() as cache_dir: - mocked_generate_example.call_count = 0 - unlimited_dataset_generator = PromptBasedDatasetGenerator( - filter_duplicated_examples=False, cache_root=cache_dir - ) - unlimited_generated_dataset_dict = check_generate_dataset_dict( - unlimited_dataset_generator - ) - - # Each API call returns five responses. So the unlimited_dataset_generator will - # call the API (50 // 5 + 24 // 5 + 1 + 26 // 5 + 1) = 21 times. - assert unlimited_dataset_generator.api_call_counter == ( - 50 // 5 + 24 // 5 + 1 + 26 // 5 + 1 - ) - # The default batch_size is 5. So generate_batch_completion - # will be called 2 times for 50 examples in the train split, - # 1 time for 24 examples in the validation split, - # and 2 times for 26 examples in the test split. - assert mocked_generate_example.call_count == 2 + 1 + 2 - - # Each API call returns 5 responses, and each response is a valid JSON. - # So the unlimited_generated_dataset_dict will contain (50, 25, 30) examples. - assert len(unlimited_generated_dataset_dict["train"]) == 50 - assert len(unlimited_generated_dataset_dict["val"]) == 25 - assert len(unlimited_generated_dataset_dict["test"]) == 30 - - with tempfile.TemporaryDirectory() as cache_dir: - # Refresh the call_count. - mocked_generate_example.call_count = 0 - - limited_dataset_generator = PromptBasedDatasetGenerator( - max_api_calls=3, filter_duplicated_examples=False, cache_root=cache_dir - ) - limited_generated_dataset = check_generate_dataset(limited_dataset_generator) - # The max_api_calls is 3. So the limited_dataset_generator calls the - # API 3 times. Each API call returns 5 responses. So the - # limited_dataset_generator will have 3 * 5 = 15 examples. - assert len(limited_generated_dataset) == 15 - - # The default batch_size is 5. So generate_batch_completion - # will be called only once. - assert mocked_generate_example.call_count == 1 - - # Each API call returns 5 responses, so the limited_dataset_generator - # will use up all the available API calls. - assert limited_dataset_generator.api_call_counter == 3 - - # Each API call returns 5 responses, and each response is a valid JSON. - # So the limited_generated_dataset will contain 15 examples. - assert len(limited_generated_dataset) == 15 - - with tempfile.TemporaryDirectory() as cache_dir: - # Refresh the call_count and create a new limited_dataset_generator. 
- mocked_generate_example.call_count = 0 - limited_dataset_generator = PromptBasedDatasetGenerator( - max_api_calls=13, filter_duplicated_examples=False, cache_root=cache_dir - ) - - limited_generated_dataset_dict = check_generate_dataset_dict( - limited_dataset_generator - ) - # Since the max_api_calls is 13, the limited_dataset_generator cannot - # generate the whole dataset_dict and will call the API 13 times. - assert limited_dataset_generator.api_call_counter == 13 - - # The train split has 50 examples, so it will call the API 10 times and call - # generate_batch_completion 2 times. - # The validation split has 24 examples, but there are only 3 API calls - # left, so it will call the API 3 times and call - # generate_batch_completion 1 time. - # The test split has 26 examples, but there are no more API calls left, - # so it will not call generate_batch_completion. - assert mocked_generate_example.call_count == 2 + 1 + 0 - - # Each API call returns 5 responses, and each response is a valid JSON. - # So the generated_dataset_dict will contain (50, 15, 0) examples. - assert len(limited_generated_dataset_dict["train"]) == 50 - assert len(limited_generated_dataset_dict["val"]) == 15 - assert len(limited_generated_dataset_dict["test"]) == 0 - gc.collect() - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MOCK_WRONG_KEY_EXAMPLE, -) -def test_wrong_key_example(mocked_generate_example): - """Test PromptBasedDatasetGenerator when the agent returns wrong keys. - - Args: - mocked_generate_example: The function representing the @patch function. - """ - # Init the PromptBasedDatasetGenerator with `max_api_calls = 3`. - with tempfile.TemporaryDirectory() as cache_dir: - dataset_generator = PromptBasedDatasetGenerator( - 3, filter_duplicated_examples=False, cache_root=cache_dir - ) - prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) - expected_num_examples = 1 - split = DatasetSplit.TRAIN - generated_dataset = dataset_generator.generate_dataset_split( - prompt_spec, expected_num_examples, split - ) - assert mocked_generate_example.call_count == 3 - expected_dataset = Dataset.from_dict({"input_col": [], "output_col": []}) - assert are_datasets_identical(expected_dataset, generated_dataset) - gc.collect() - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MOCK_INVALID_JSON, -) -def test_invalid_json_response(mocked_generate_example): - """Test PromptBasedDatasetGenerator when the agent returns invalid JSON responses. - - Args: - mocked_generate_example: The function representing the @patch function. - """ - # Init the PromptBasedDatasetGenerator with `max_api_calls = 3`. - with tempfile.TemporaryDirectory() as cache_dir: - dataset_generator = PromptBasedDatasetGenerator( - 3, filter_duplicated_examples=False, cache_root=cache_dir - ) - prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) - expected_num_examples = 1 - split = DatasetSplit.VAL - dataset = dataset_generator.generate_dataset_split( - prompt_spec, expected_num_examples, split - ) - assert mocked_generate_example.call_count == 3 - expected_dataset = Dataset.from_dict({"input_col": [], "output_col": []}) - assert are_datasets_identical(dataset, expected_dataset) - gc.collect() - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=UnknownGpt3Exception(), -) -def test_unexpected_examples_of_gpt(mocked_generate_example): - """Test PromptBasedDatasetGenerator when the agent returns unexpected examples. 
- - This function tests the scenario when the agent raises an UnknownGpt3Exception - during dataset generation. The test ensures that the exception is correctly raised. - - Args: - mocked_generate_example: The function representing the @patch function. - """ - os.environ["OPENAI_API_KEY"] = "fake_api_key" - # Init the PromptBasedDatasetGenerator with `max_api_calls = 3`. - with pytest.raises( - UnknownGpt3Exception - ), tempfile.TemporaryDirectory() as cache_dir: - dataset_generator = PromptBasedDatasetGenerator( - max_api_calls=3, filter_duplicated_examples=False, cache_root=cache_dir - ) - prompt_spec = MockPromptSpec(TaskType.TEXT_GENERATION) - expected_num_examples = 1 - split = DatasetSplit.TEST - _ = dataset_generator.generate_dataset_split( - prompt_spec, expected_num_examples, split - ) - assert mocked_generate_example.call_count == 1 - gc.collect() - - -def test_create_all_examples_dataset_and_generated_dataset_with_duplicate_inputs_unique_outputs(): # noqa: 501 - """Test constructing the generated dataset with duplicate inputs but unique outputs. - - This function tests the scenario when the generator has generated examples with - duplicate inputs but unique outputs. It ensures that the generator successfully - converts the generated examples into a generated dataset while preserving the - correct mappings between input and output. - - The test uses the `PromptBasedDatasetGenerator` with - `filter_duplicated_examples=False`. - The `generating_split` attribute of the generator is set to `DatasetSplit.TEST`, - and the `generated_examples` list contains examples with some duplicate inputs but - unique outputs. - - The function then calls the `create_all_examples_dataset_and_generated_dataset()` - method to create the generated dataset. - - Finally, the function checks whether the generated dataset matches the expected - dataset constructed from the input examples. - - Note: The test uses a temporary directory as the cache root to ensure that the cache - directory is cleaned up after the test finishes. - - Raises: - AssertionError: If the generated dataset does not match the expected dataset. - """ - with tempfile.TemporaryDirectory() as cache_dir: - os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator( - filter_duplicated_examples=False, cache_root=cache_dir - ) - generated_examples = [ - Example(input_col="apple", output_col="A"), - Example(input_col="banana", output_col="B"), - Example(input_col="apple", output_col="E"), - Example(input_col="orange", output_col="O"), - Example(input_col="apple", output_col="D"), - ] - ( - all_generated_examples_dataset, - generated_dataset, - ) = data_generator.create_all_examples_dataset_and_generated_dataset( - generated_examples - ) - expected_dataset = Dataset.from_dict( - { - "input_col": [example.input_col for example in generated_examples], - "output_col": [example.output_col for example in generated_examples], - } - ) - assert are_datasets_identical(all_generated_examples_dataset, expected_dataset) - assert are_datasets_identical(generated_dataset, expected_dataset) - gc.collect() - - -def test_create_all_examples_dataset_and_generated_dataset_with_duplicate_inputs_duplicate_outputs(): # noqa: 501 - """Test constructing a map with duplicate inputs and duplicate outputs. - - This function tests the scenario when the generator has generated examples with - duplicate inputs and duplicate outputs. 
It ensures that the generator successfully - converts the generated examples into a generated dataset while preserving the - correct mappings between input and output. - - The test uses the `PromptBasedDatasetGenerator` with - `filter_duplicated_examples=False`. - The `generating_split` attribute of the generator is set to `DatasetSplit.TEST`, - and the `generated_examples` list contains examples with both duplicate inputs and - duplicate outputs. The function then calls the - `create_all_examples_dataset_and_generated_dataset()` method to create the generated - dataset. - - Finally, the function checks whether the generated dataset matches the expected - dataset constructed from the input examples. - - Note: The test uses a temporary directory as the cache root to ensure that the cache - directory is cleaned up after the test finishes. - - Raises: - AssertionError: If the generated dataset does not match the expected dataset. - """ - with tempfile.TemporaryDirectory() as cache_dir: - os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator( - filter_duplicated_examples=False, cache_root=cache_dir - ) - generated_examples = [ - Example(input_col="apple", output_col="A"), - Example(input_col="banana", output_col="C"), - Example(input_col="apple", output_col="A"), - Example(input_col="banana", output_col="B"), - Example(input_col="apple", output_col="G"), - Example(input_col="apple", output_col="A"), - Example(input_col="orange", output_col="O"), - Example(input_col="apple", output_col="D"), - Example(input_col="banana", output_col="B"), - Example(input_col="orange", output_col="F"), - ] - ( - all_generated_examples_dataset, - generated_dataset, - ) = data_generator.create_all_examples_dataset_and_generated_dataset( - generated_examples - ) - expected_dataset = Dataset.from_dict( - { - "input_col": [example.input_col for example in generated_examples], - "output_col": [example.output_col for example in generated_examples], - } - ) - assert are_datasets_identical(all_generated_examples_dataset, expected_dataset) - assert are_datasets_identical(generated_dataset, expected_dataset) - gc.collect() - - -def test_create_all_examples_dataset_and_generated_dataset_with_unique_inputs_outputs(): - """Test constructing a map with unique inputs and outputs. - - This function tests the scenario when the generator has generated examples with - unique inputs and unique outputs. It ensures that the generator successfully - converts the generated examples into a generated dataset while preserving the - correct mappings between input and output. - - The test uses the `PromptBasedDatasetGenerator` with - `filter_duplicated_examples=False`. - The `generating_split` attribute of the generator is set to `DatasetSplit.TEST`, - and the `generated_examples` list contains examples with unique inputs and - unique outputs. The function then calls the - `create_all_examples_dataset_and_generated_dataset()` method to create the generated - dataset. - - Finally, the function checks whether the generated dataset matches the expected - dataset constructed from the input examples. - - Note: The test uses a temporary directory as the cache root to ensure that the cache - directory is cleaned up after the test finishes. - - Raises: - AssertionError: If the generated dataset does not match the expected dataset. 
- """ - with tempfile.TemporaryDirectory() as cache_dir: - os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator( - filter_duplicated_examples=False, cache_root=cache_dir - ) - generated_examples = [ - Example(input_col="apple", output_col="A"), - Example(input_col="banana", output_col="B"), - Example(input_col="orange", output_col="O"), - ] - ( - all_generated_examples_dataset, - generated_dataset, - ) = data_generator.create_all_examples_dataset_and_generated_dataset( - generated_examples - ) - expected_dataset = Dataset.from_dict( - { - "input_col": [example.input_col for example in generated_examples], - "output_col": [example.output_col for example in generated_examples], - } - ) - assert are_datasets_identical(all_generated_examples_dataset, expected_dataset) - assert are_datasets_identical(generated_dataset, expected_dataset) - gc.collect() - - -def test_create_all_examples_dataset_and_generated_dataset_with_empty_examples_list(): - """Test constructing a map with empty inputs and outputs. - - This function tests the scenario when the generator has an empty list of generated - examples. It ensures that the generator successfully converts the empty examples - list into an empty generated dataset. - - The test uses the `PromptBasedDatasetGenerator` with - `filter_duplicated_examples=False`. - The `generating_split` attribute of the generator is set to `DatasetSplit.TEST`, - and the `generated_examples` list is empty. The function then calls the - `create_all_examples_dataset_and_generated_dataset()` method to create the generated - dataset. - - Finally, the function checks whether the generated dataset is empty. - - Note: The test uses a temporary directory as the cache root to ensure that the cache - directory is cleaned up after the test finishes. - - Raises: - AssertionError: If the generated dataset is not empty. - """ - with tempfile.TemporaryDirectory() as cache_dir: - os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator( - filter_duplicated_examples=False, cache_root=cache_dir - ) - generated_examples = [] - ( - all_generated_examples_dataset, - generated_dataset, - ) = data_generator.create_all_examples_dataset_and_generated_dataset( - generated_examples - ) - expected_dataset = Dataset.from_dict( - { - "input_col": [example.input_col for example in generated_examples], - "output_col": [example.output_col for example in generated_examples], - } - ) - assert are_datasets_identical(all_generated_examples_dataset, expected_dataset) - assert are_datasets_identical(generated_dataset, expected_dataset) - gc.collect() - - -def test_compute_batch_size_with_limited_max_api_calls(): - """Test the batch size computation with limited max API calls. - - This function tests the computation of batch size when the generator has limited - max API calls. It covers scenarios where the API calls are close to reaching the - maximum limit and when the API calls are far from the maximum limit. - - The test uses the `PromptBasedDatasetGenerator` with `max_api_calls=28`. The - `api_call_counter` attribute of the generator is set to `26`, and the - `generated_dataset` contains 110 examples. The function then calls the - `compute_batch_size()` method with an `expected_num_examples` of `125`. - - Finally, the function checks whether the computed batch size matches the expected - batch size based on the remaining API calls and the number of examples needed to - reach the expected number of examples. 
- - Note: The test uses a temporary directory as the cache root to ensure that the cache - directory is cleaned up after the test finishes. - - Raises: - AssertionError: If the computed batch size does - not match the expected batch size. - """ - os.environ["OPENAI_API_KEY"] = "fake_api_key" - with tempfile.TemporaryDirectory() as cache_dir: - data_generator = PromptBasedDatasetGenerator( - max_api_calls=28, cache_root=cache_dir - ) - data_generator.api_call_counter = 26 - generated_dataset = Dataset.from_dict( - { - "input_col": ["1"] * 110, - "output_col": ["2"] * 110, - } - ) - # Default batch size and responses_per_request are both 5. - # So each batch should contain 25 examples. - - # At least (125 - 110) / 5 = 3 API calls needed to get - # more than 125 examples. - - batch_size = data_generator.compute_batch_size( - expected_num_examples=125, generated_dataset=generated_dataset - ) - assert ( - batch_size - == data_generator.max_api_calls - data_generator.api_call_counter - == 28 - 26 - ) - - data_generator.api_call_counter = 20 - batch_size = data_generator.compute_batch_size(125, generated_dataset) - assert ( - batch_size - == ((125 - len(generated_dataset))) / data_generator.responses_per_request - == (125 - 110) / 5 - ) - - data_generator.api_call_counter = 0 - generated_dataset = Dataset.from_dict( - { - "input_col": [1] * 50, - "output_col": [2] * 50, - } - ) - batch_size = data_generator.compute_batch_size(125, generated_dataset) - assert batch_size == data_generator.max_batch_size - gc.collect() - - -def test_compute_batch_size_with_unlimited_max_api_calls(): - """Test the batch size computation with unlimited max API calls. - - This function tests the computation of batch size when the generator has unlimited - max API calls. It covers scenarios where the number of examples needed to reach the - expected number of examples is greater than the default batch size. - - The test uses the `PromptBasedDatasetGenerator` with default `max_api_calls`. The - `generated_dataset` contains 110 examples. The function then calls the - `compute_batch_size()` method with an `expected_num_examples` of `125`. - - Finally, the function checks whether the computed batch size matches the expected - batch size based on the number of examples needed to reach the expected number of - examples. - - Note: The test uses a temporary directory as the cache root to ensure that the cache - directory is cleaned up after the test finishes. - - Raises: - AssertionError: If the computed batch size - ddoes not match the expected batch size. - """ - with tempfile.TemporaryDirectory() as cache_dir: - os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator(cache_root=cache_dir) - generated_dataset = Dataset.from_dict( - { - "input_col": ["1"] * 110, - "output_col": ["2"] * 110, - } - ) - # Default batch size and responses_per_request are both 5. - # So each batch should contain 25 examples. - - # At least (125 - 110) / 5 = 3 API calls needed to get - # more than 125 examples. 
- - batch_size = data_generator.compute_batch_size(125, generated_dataset) - assert ( - batch_size - == (125 - len(generated_dataset)) / data_generator.responses_per_request - == (125 - 110) / 5 - ) - - generated_dataset = Dataset.from_dict( - { - "input_col": [1] * 50, - "output_col": [2] * 50, - } - ) - batch_size = data_generator.compute_batch_size(125, generated_dataset) - assert batch_size == data_generator.max_batch_size == 5 - gc.collect() - - -def test_load_cache_dataset_without_filter_duplicated_examples(): - """Test the cached dataset loading without filtering duplicated examples. - - This function tests the cached dataset loading without filtering duplicated - examples. It first saves a dataset to the cache directory and then initializes - the `PromptBasedDatasetGenerator` with `filter_duplicated_examples=False`. - - The `generate_dataset_split()` method is then called with an - `expected_num_examples` of `110`, which is equal to the size of the cached - dataset. The function checks that the cached dataset is successfully loaded, - and the generation stops because the expected number of examples is - already met. - - Note: The test uses a temporary directory as the cache root to ensure that the cache - directory is cleaned up after the test finishes. - - Raises: - AssertionError: If the cached dataset is not loaded correctly, or if the - generation does not stop when the expected number of examples is met. - """ - with tempfile.TemporaryDirectory() as cache_dir: - os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator( - cache_root=cache_dir, filter_duplicated_examples=False - ) - examples_cache_path = Path( - data_generator.cache_root / f"generated_examples_{DatasetSplit.TEST.value}" - ) - cached_examples = Dataset.from_dict( - { - "input_col": ["1"] * 110, - "output_col": ["2"] * 110, - } - ) - cached_examples.save_to_disk(examples_cache_path) - # The generate_dataset_split would first load the cached - # dataset into generated_examples. Then in the while - # loop, create_all_examples_dataset_and_generated_dataset - # would be called to construct the generated_dataset. - # Note that filter_duplicated_examples is False, so the - # generated_examples won't be filtered. And since the - # expected_num_examples is 110, the while loop would exit - # immediately. So the generated_dataset would be the - # same as the cached dataset. - with patch.object(logger, "info") as mock_info, patch.object( - logger, "warning" - ) as mock_warning: - generated_dataset = data_generator.generate_dataset_split( - expected_num_examples=110, - prompt_spec=MockPromptSpec, - split=DatasetSplit.TEST, - ) - mock_info.assert_called_once_with( - f"Loading cache from {str(examples_cache_path)}." - ) - mock_warning.assert_not_called() - assert are_datasets_identical(generated_dataset, cached_examples) - gc.collect() - - -@patch( - "prompt2model.utils.APIAgent.generate_batch_completion", - side_effect=MOCK_CLASSIFICATION_EXAMPLE, -) -def test_load_cache_dataset_without_filter_duplicated_examples_and_continue_generation( - mocked_generate_example, -): - """Test PromptBasedDatasetGenerator can load cache and continue generation. - - This function tests that the `PromptBasedDatasetGenerator` can load the cached - dataset and continue generation if the expected number of examples is - greater than the size of the cached dataset. 
The test first saves a dataset - to the cache directory and then initializes the `PromptBasedDatasetGenerator` - with `filter_duplicated_examples=False`. The `generate_dataset_split()` - method is then called with an `expected_num_examples` of `117`, which - is greater than the size of the cached dataset. The function checks that - the cached dataset is successfully loaded, and the generation continues - to meet the expected number of examples. - - Note: The test uses a temporary directory as the cache root to ensure - that the cache directory is cleaned up after the test finishes. - - Raises: - AssertionError: If the cached dataset is not loaded correctly, or if the - generation does not continue to meet the expected number of examples. - """ - with tempfile.TemporaryDirectory() as cache_dir: - os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator( - cache_root=cache_dir, filter_duplicated_examples=False - ) - examples_cache_path = Path( - data_generator.cache_root / f"generated_examples_{DatasetSplit.TEST.value}" - ) - cached_dataset = Dataset.from_dict( - { - "input_col": ["1"] * 110, - "output_col": ["2"] * 110, - } - ) - cached_dataset.save_to_disk(examples_cache_path) - # The generate_dataset_split would first load the cached - # dataset into generated_examples. Then in the while - # loop, create_all_examples_dataset_and_generated_dataset - # would be called to construct the generated_dataset. - # Note that filter_duplicated_examples is False, so the - # generated_examples won't be filtered. And since the - # expected_num_examples is 117, the generation would - # continue and the batch_size = 2. After one batch of API - # calls, generated_dataset meets the requirement and - # stop generation. - with patch.object(logger, "info") as mock_info, patch.object( - logger, "warning" - ) as mock_warning: - generated_dataset = data_generator.generate_dataset_split( - expected_num_examples=117, - prompt_spec=MockPromptSpec, - split=DatasetSplit.TEST, - ) - info_list = [each.args[0] for each in mock_info.call_args_list] - assert info_list[0] == f"Loading cache from {str(examples_cache_path)}." - # The first logger.info is loaded cache, and there is - # another 2 * 5 * 2 logger.info in extract_responses. - assert len(info_list) == 1 + 2 * 5 * 2 - mock_warning.assert_not_called() - excepted_generated_dataset = Dataset.from_dict( - { - "input_col": ["1"] * 110 + ["This is a great movie!"] * 10, - "output_col": ["2"] * 110 + ["1"] * 10, - } - ) - assert are_datasets_identical(generated_dataset, excepted_generated_dataset) - assert mocked_generate_example.call_count == 1 - gc.collect() - - -def test_extract_responses(): - """Test the extract_responses function of DatasetGenerator.""" - mock_completion_1 = MockCompletion() - mock_completion_1.choices = [ - {"message": {"content": '{"input": "1", "output": "a"}'}}, - {"message": {"content": '{"input": "1", "output": "b"}'}}, - {"message": {"content": '{"input": "1", "output": "a"}'}}, - ] - mock_completion_2 = MockCompletion() - mock_completion_2.choices = [ - {"message": {"content": '{"input": "3", "output": "a"}'}}, - # Note that the following choice miss the right quote of JSON. - # So it should be discarded. And will log a warning. 
- {"message": {"content": '{"input": "3", "output": "a}'}}, - {"message": {"content": '{"input": "3", "output": "b"}'}}, - ] - mock_completion_3 = MockCompletion() - mock_completion_3.choices = [ - {"message": {"content": '{"input": "4", "output": "c"}'}}, - {"message": {"content": '{"input": "4", "output": "c"}'}}, - {"message": {"content": '{"input": "5", "output": "a"}'}}, - ] - # choices should be list of dicts. So mock_completion_4 - # is invalid. Which will be discarded and log a warning. - mock_completion_4 = MockCompletion() - mock_completion_4.choices = None - - with tempfile.TemporaryDirectory() as cache_dir: - os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator( - cache_root=cache_dir, filter_duplicated_examples=True - ) - generated_examples = [] - with patch.object(logger, "info") as mock_info, patch.object( - logger, "warning" - ) as mock_warning: - generated_examples = data_generator.extract_responses( - [mock_completion_1, mock_completion_2], generated_examples - ) - mock_warning.assert_called_once_with( - 'Error happened parsing API choice: {\'message\': {\'content\': \'{"input": "3", "output": "a}\'}}' # noqa E501 - ) - # There are 5 valid examples. Each input - # and output will be logged once as info. - assert mock_info.call_count == 5 * 2 - - # The second choice in mock_completion_2 - # is invalid. So it should be discarded. - assert generated_examples == [ - Example(input_col="1", output_col="a"), - Example(input_col="1", output_col="b"), - Example(input_col="1", output_col="a"), - Example(input_col="3", output_col="a"), - Example(input_col="3", output_col="b"), - ] - generated_examples = data_generator.extract_responses( - [mock_completion_3], generated_examples - ) - assert generated_examples == [ - Example(input_col="1", output_col="a"), - Example(input_col="1", output_col="b"), - Example(input_col="1", output_col="a"), - Example(input_col="3", output_col="a"), - Example(input_col="3", output_col="b"), - Example(input_col="4", output_col="c"), - Example(input_col="4", output_col="c"), - Example(input_col="5", output_col="a"), - ] - with patch.object(logger, "info") as mock_info, patch.object( - logger, "warning" - ) as mock_warning: - generated_examples = data_generator.extract_responses( - [mock_completion_4], generated_examples - ) - mock_warning.assert_called_once_with( - "Error happened when parsing API completion: " - ) - mock_info.assert_not_called() - # The generated_examples should be the same. - assert generated_examples == [ - Example(input_col="1", output_col="a"), - Example(input_col="1", output_col="b"), - Example(input_col="1", output_col="a"), - Example(input_col="3", output_col="a"), - Example(input_col="3", output_col="b"), - Example(input_col="4", output_col="c"), - Example(input_col="4", output_col="c"), - Example(input_col="5", output_col="a"), - ] - gc.collect() - - -def test_extract_some_empty_responses(): - """Test the extract_responses function correctly handle empty responses.""" - mock_completion_1 = MockCompletion() - mock_completion_1.choices = [ - # Note that this choice's input is empty. So it should be discarded. - {"message": {"content": '{"input": "", "output": "a"}'}}, - {"message": {"content": '{"input": "5", "output": "b"}'}}, - # Note that this choice's output is empty. So it should be discarded. 
- {"message": {"content": '{"input": "1", "output": ""}'}}, - ] - mock_completion_2 = MockCompletion() - mock_completion_2.choices = [ - {"message": {"content": '{"input": "3", "output": "a"}'}}, - # Note that the following choice misses the right quote of JSON. - # So it should be discarded. And will log a warning. - {"message": {"content": '{"input": "3", "output": "a}'}}, - {"message": {"content": '{"input": "3", "output": "b"}'}}, - ] - mock_completion_3 = MockCompletion() - mock_completion_3.choices = [ - {"message": {"content": '{"input": "4", "output": "c"}'}}, - {"message": {"content": '{"input": "4", "output": "c"}'}}, - {"message": {"content": '{"input": "5", "output": "a"}'}}, - ] - # choices should be list of dicts. So mock_completion_4 - # is invalid. Which will be discarded and log a warning. - mock_completion_4 = MockCompletion() - mock_completion_4.choices = None - - with tempfile.TemporaryDirectory() as cache_dir: - os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator( - cache_root=cache_dir, filter_duplicated_examples=True - ) - generated_examples = [] - with patch.object(logger, "info") as mock_info, patch.object( - logger, "warning" - ) as mock_warning: - generated_examples = data_generator.extract_responses( - [mock_completion_1, mock_completion_2], generated_examples - ) - mock_warning.assert_called_once_with( - 'Error happened parsing API choice: {\'message\': {\'content\': \'{"input": "3", "output": "a}\'}}' # noqa E501 - ) - # There are 3 valid examples in [mock_completion_1, - # mock_completion_2] Each input - # and output will be logged once as info. - # And there are 2 examples with empty - # input or output, which should be discarded - # and be logged as info. - assert mock_info.call_count == 3 * 2 + 2 - - # The second choice in mock_completion_2 - # is invalid. So it should be discarded. - assert generated_examples == [ - Example(input_col="5", output_col="b"), - Example(input_col="3", output_col="a"), - Example(input_col="3", output_col="b"), - ] - generated_examples = data_generator.extract_responses( - [mock_completion_3], generated_examples - ) - assert generated_examples == [ - Example(input_col="5", output_col="b"), - Example(input_col="3", output_col="a"), - Example(input_col="3", output_col="b"), - Example(input_col="4", output_col="c"), - Example(input_col="4", output_col="c"), - Example(input_col="5", output_col="a"), - ] - with patch.object(logger, "info") as mock_info, patch.object( - logger, "warning" - ) as mock_warning: - generated_examples = data_generator.extract_responses( - [mock_completion_4], generated_examples - ) - mock_warning.assert_called_once_with( - "Error happened when parsing API completion: " - ) - mock_info.assert_not_called() - # The generated_examples should be the same. 
- assert generated_examples == [ - Example(input_col="5", output_col="b"), - Example(input_col="3", output_col="a"), - Example(input_col="3", output_col="b"), - Example(input_col="4", output_col="c"), - Example(input_col="4", output_col="c"), - Example(input_col="5", output_col="a"), - ] - gc.collect() - - -def test_initialize_dataset_generator_with_dynamic_temperature(): - """Test the correct initialization of the dynamic temperature strategy.""" - with tempfile.TemporaryDirectory() as cache_dir: - os.environ["OPENAI_API_KEY"] = "fake_api_key" - with pytest.raises(ValueError) as exc_info: - _ = PromptBasedDatasetGenerator( - cache_root=cache_dir, initial_temperature=-0.2 - ) - error_info = exc_info.value.args[0] - assert ( - error_info - == "initial_temperature must be >= 0, but self.initial_temperature=-0.2" - ) - with pytest.raises(ValueError) as exc_info: - _ = PromptBasedDatasetGenerator(cache_root=cache_dir, max_temperature=2.3) - error_info = exc_info.value.args[0] - assert ( - error_info - == "max_temperature must be <= 2,0, but self.max_temperature=2.3" - ) - - with pytest.raises(ValueError) as exc_info: - _ = PromptBasedDatasetGenerator( - cache_root=cache_dir, max_temperature=1.2, initial_temperature=1.5 - ) - error_info = exc_info.value.args[0] - assert ( - error_info - == "self.initial_temperature=1.5 must be <= self.max_temperature=1.2" - ) diff --git a/tests/dataset_processor_test.py b/tests/dataset_processor_test.py index b00bfb519..ece1b78ec 100644 --- a/tests/dataset_processor_test.py +++ b/tests/dataset_processor_test.py @@ -9,11 +9,7 @@ import pytest from prompt2model.dataset_processor.textualize import TextualizeProcessor -from test_helpers import ( - are_dataset_dicts_identical, - create_gpt2_model_and_tokenizer, - create_t5_model_and_tokenizer, -) +from test_helpers import create_gpt2_model_and_tokenizer, create_t5_model_and_tokenizer logger = logging.getLogger("DatasetProcessor") @@ -128,10 +124,12 @@ def test_dataset_processor_t5_style(): INSTRUCTION, DATASET_DICTS ) # Ensure the dataset_dicts themselves are the same after processing. 
- assert all( - are_dataset_dicts_identical(raw, origin) - for (raw, origin) in zip(raw_dataset_dicts, DATASET_DICTS) - ) + for raw, origin in zip(raw_dataset_dicts, DATASET_DICTS): + assert list(raw["train"]) == list(origin["train"]) + if "val" in raw: + assert list(raw["val"]) == list(origin["val"]) + if "test" in raw: + assert list(raw["test"]) == list(origin["test"]) t5_expected_dataset_dicts = [ datasets.DatasetDict( { @@ -179,7 +177,11 @@ def test_dataset_processor_t5_style(): ), ] for exp, act in zip(t5_expected_dataset_dicts, t5_modified_dataset_dicts): - assert are_dataset_dicts_identical(exp, act) + assert list(exp["train"]) == list(act["train"]) + if "val" in exp: + assert list(exp["val"]) == list(act["val"]) + if "test" in exp: + assert list(exp["test"]) == list(act["test"]) gc.collect() @@ -233,7 +235,7 @@ def test_dataset_processor_with_numerical_column(): "convert to text2text\nExample:\nfoo\nLabel:\n", "convert to text2text\nExample:\nbar\nLabel:\n", ], - "model_output": ["foo", "bar", "0", "1"], + "model_output": ["baz", "qux", "0", "1"], } ), "test": datasets.Dataset.from_dict( @@ -260,9 +262,8 @@ def test_dataset_processor_with_numerical_column(): actual_dataset_dict = datasets.DatasetDict( {"train": concatenated_training_dataset, "test": concatenated_test_dataset} ) - are_dataset_dicts_identical(expected_dataset_dict, actual_dataset_dict) - - gc.collect() + assert list(expected_dataset_dict["train"]) == list(actual_dataset_dict["train"]) + assert list(expected_dataset_dict["test"]) == list(actual_dataset_dict["test"]) def test_dataset_processor_decoder_only_style(): @@ -276,10 +277,12 @@ def test_dataset_processor_decoder_only_style(): INSTRUCTION, DATASET_DICTS ) # Ensure the dataset_dicts themselves are the same after processing. 
- assert all( - are_dataset_dicts_identical(raw, origin) - for raw, origin in zip(raw_dataset_dicts, DATASET_DICTS) - ) + for raw, origin in zip(raw_dataset_dicts, DATASET_DICTS): + assert list(raw["train"]) == list(origin["train"]) + if "val" in raw: + assert list(raw["val"]) == list(origin["val"]) + if "test" in raw: + assert list(raw["test"]) == list(origin["test"]) # Check that the modified dataset dicts have the expected content gpt_expected_dataset_dicts = [ datasets.DatasetDict( @@ -327,13 +330,12 @@ def test_dataset_processor_decoder_only_style(): } ), ] - assert all( - are_dataset_dicts_identical(exp, modified) - for (exp, modified) in zip( - gpt_expected_dataset_dicts, gpt_modified_dataset_dicts - ) - ) - gc.collect() + for exp, modified in zip(gpt_expected_dataset_dicts, gpt_modified_dataset_dicts): + assert list(exp["train"]) == list(modified["train"]) + if "val" in exp: + assert list(exp["val"]) == list(modified["val"]) + if "test" in exp: + assert list(exp["test"]) == list(modified["test"]) def test_unexpected_dataset_split(): @@ -436,11 +438,12 @@ def test_empty_filter_t5_type(): } ), ] - assert all( - are_dataset_dicts_identical(exp, act) - for exp, act in zip(t5_expected_dataset_dicts, t5_modified_dataset_dicts) - ) - gc.collect() + for exp, modified in zip(t5_expected_dataset_dicts, t5_modified_dataset_dicts): + assert list(exp["train"]) == list(modified["train"]) + if "val" in exp: + assert list(exp["val"]) == list(modified["val"]) + if "test" in exp: + assert list(exp["test"]) == list(modified["test"]) def test_empty_filter_decoder_only_style(): @@ -486,12 +489,12 @@ def test_empty_filter_decoder_only_style(): } ), ] - assert all( - are_dataset_dicts_identical(expected, modified) - for expected, modified in zip( - gpt_expected_dataset_dicts, gpt_modified_dataset_dicts - ) - ) + for exp, modified in zip(gpt_expected_dataset_dicts, gpt_modified_dataset_dicts): + assert list(exp["train"]) == list(modified["train"]) + if "val" in exp: + assert list(exp["val"]) == list(modified["val"]) + if "test" in exp: + assert list(exp["test"]) == list(modified["test"]) gc.collect() @@ -539,7 +542,7 @@ def test_raise_value_error_of_process_dataset_lists(): def test_process_dataset_lists(): """Test the `process_dataset_lists` function.""" processor = TextualizeProcessor(has_encoder=True) - modified_dataset_dicsts = processor.process_dataset_lists( + modified_dataset_dicts = processor.process_dataset_lists( INSTRUCTION, DATASET_LIST, 0.6, 0.2 ) expected_modified_generated_dataset_dict = datasets.DatasetDict( @@ -604,22 +607,24 @@ def test_process_dataset_lists(): ), } ) - assert all( - are_dataset_dicts_identical(raw, origin) - for (raw, origin) in zip( - [ - expected_modified_generated_dataset_dict, - expected_modified_retrieved_dataset_dict, - ], - modified_dataset_dicsts, - ) - ) + for exp, modified in zip( + [ + expected_modified_generated_dataset_dict, + expected_modified_retrieved_dataset_dict, + ], + modified_dataset_dicts, + ): + assert list(exp["train"]) == list(modified["train"]) + if "val" in exp: + assert list(exp["val"]) == list(modified["val"]) + if "test" in exp: + assert list(exp["test"]) == list(modified["test"]) def test_process_dataset_lists_with_maximum_example_num(): """Test the maximum_example_num parameter.""" processor = TextualizeProcessor(has_encoder=True) - modified_dataset_dicsts = processor.process_dataset_lists( + modified_dataset_dicts = processor.process_dataset_lists( INSTRUCTION, DATASET_LIST, 0.6, 0.2, 3000 ) # Before applying the maximum_example_num, 
train_num = 6000, @@ -688,13 +693,15 @@ def test_process_dataset_lists_with_maximum_example_num(): ), } ) - assert all( - are_dataset_dicts_identical(raw, origin) - for (raw, origin) in zip( - [ - expected_modified_generated_dataset_dict, - expected_modified_retrieved_dataset_dict, - ], - modified_dataset_dicsts, - ) - ) + for exp, modified in zip( + [ + expected_modified_generated_dataset_dict, + expected_modified_retrieved_dataset_dict, + ], + modified_dataset_dicts, + ): + assert list(exp["train"]) == list(modified["train"]) + if "val" in exp: + assert list(exp["val"]) == list(modified["val"]) + if "test" in exp: + assert list(exp["test"]) == list(modified["test"]) From 2f8065ebb41cc67ede3454cd7972258595402ad4 Mon Sep 17 00:00:00 2001 From: Graham Neubig Date: Thu, 7 Sep 2023 18:17:58 -0400 Subject: [PATCH 2/7] Fixes error handling --- prompt2model/prompt_parser/instr_parser.py | 59 +++++++++------------- tests/prompt_parser_test.py | 8 +-- 2 files changed, 29 insertions(+), 38 deletions(-) diff --git a/prompt2model/prompt_parser/instr_parser.py b/prompt2model/prompt_parser/instr_parser.py index 2afbed4f7..a9c69823d 100644 --- a/prompt2model/prompt_parser/instr_parser.py +++ b/prompt2model/prompt_parser/instr_parser.py @@ -14,6 +14,7 @@ ) from prompt2model.utils import APIAgent, get_formatted_logger +from prompt2model.utils.api_tools import API_ERRORS, handle_api_error logger = get_formatted_logger("PromptParser") @@ -85,39 +86,29 @@ def parse_from_prompt(self, prompt: str) -> None: parsing_prompt_for_chatgpt = construct_prompt_for_instruction_parsing(prompt) chat_api = APIAgent() + last_error = None while True: self.api_call_counter += 1 - response = chat_api.generate_one_completion( - parsing_prompt_for_chatgpt, - temperature=0, - presence_penalty=0, - frequency_penalty=0, - ) - - if isinstance(response, Exception): - # Generation failed due to an API related error and requires retry. - - if self.max_api_calls and self.api_call_counter >= self.max_api_calls: - # In case we reach maximum number of API calls, we raise an error. - logger.error("Maximum number of API calls reached.") - raise ValueError( - "Maximum number of API calls reached." - ) from response - - continue # no need to proceed with extracting - # response if API call failed. - - extraction = self.extract_response(response) - - if extraction is not None: - # extraction is successful - - self._instruction, self._examples = extraction - return None - - if self.max_api_calls and self.api_call_counter == self.max_api_calls: - # In case we reach maximum number of API calls without a - # successful extraction, we return None. - - logger.warning("Maximum number of API calls reached for PromptParser.") - return None + try: + response: openai.ChatCompletion | Exception = ( + chat_api.generate_one_completion( + parsing_prompt_for_chatgpt, + temperature=0, + presence_penalty=0, + frequency_penalty=0, + ) + ) + extraction = self.extract_response(response) + if extraction is not None: + self._instruction, self._examples = extraction + return + except API_ERRORS as e: + last_error = e + handle_api_error(e) + + if self.max_api_calls and self.api_call_counter >= self.max_api_calls: + # In case we reach maximum number of API calls, we raise an error. + logger.error("Maximum number of API calls reached.") + raise RuntimeError( + "Maximum number of API calls reached." 
+ ) from last_error diff --git a/tests/prompt_parser_test.py b/tests/prompt_parser_test.py index 6b4a4f363..9faa4fb33 100644 --- a/tests/prompt_parser_test.py +++ b/tests/prompt_parser_test.py @@ -126,7 +126,7 @@ def test_instruction_parser_with_invalid_json(mocked_parsing_method): @patch("time.sleep") @patch( - "openai.ChatCompletion.create", + "prompt2model.utils.APIAgent.generate_one_completion", side_effect=openai.error.Timeout("timeout"), ) def test_instruction_parser_with_timeout(mocked_parsing_method, mocked_sleep_method): @@ -141,7 +141,7 @@ def test_instruction_parser_with_timeout(mocked_parsing_method, mocked_sleep_met some time after each API timeout. """ prompt = """This prompt will be ignored by the parser in this test.""" - with pytest.raises(ValueError) as exc_info: + with pytest.raises(RuntimeError) as exc_info: prompt_spec = PromptBasedInstructionParser( task_type=TaskType.TEXT_GENERATION, max_api_calls=3 ) @@ -152,8 +152,8 @@ def test_instruction_parser_with_timeout(mocked_parsing_method, mocked_sleep_met assert mocked_sleep_method.call_count == 3 assert mocked_parsing_method.call_count == 3 - # Check if the ValueError was raised - assert isinstance(exc_info.value, ValueError) + # Check if the RuntimeError was raised + assert isinstance(exc_info.value, RuntimeError) # Check if the original exception (e) is present as the cause original_exception = exc_info.value.__cause__ assert isinstance(original_exception, openai.error.Timeout) From c7c29f7500115c6fde32da6b85feaca121f9d7ce Mon Sep 17 00:00:00 2001 From: Graham Neubig Date: Thu, 7 Sep 2023 18:22:52 -0400 Subject: [PATCH 3/7] Fix typechecking error --- prompt2model/model_retriever/generate_hypothetical_document.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prompt2model/model_retriever/generate_hypothetical_document.py b/prompt2model/model_retriever/generate_hypothetical_document.py index f49af74e1..67bdc54ad 100644 --- a/prompt2model/model_retriever/generate_hypothetical_document.py +++ b/prompt2model/model_retriever/generate_hypothetical_document.py @@ -443,7 +443,7 @@ def generate_hypothetical_model_description( ) return chatgpt_completion.choices[0]["message"]["content"] except API_ERRORS as e: - api_call_counter = handle_api_error(e, api_call_counter) + handle_api_error(e) if max_api_calls and api_call_counter >= max_api_calls: logging.error("Maximum number of API calls reached.") raise ValueError("Maximum number of API calls reached.") from e From e880426b010a0aa00a344019a9e4ef0903ff57a6 Mon Sep 17 00:00:00 2001 From: Graham Neubig Date: Thu, 7 Sep 2023 19:52:08 -0400 Subject: [PATCH 4/7] Fix prompt parser test --- tests/prompt_parser_test.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/prompt_parser_test.py b/tests/prompt_parser_test.py index 9faa4fb33..895d19003 100644 --- a/tests/prompt_parser_test.py +++ b/tests/prompt_parser_test.py @@ -112,12 +112,11 @@ def test_instruction_parser_with_invalid_json(mocked_parsing_method): with patch.object(logger, "info") as mock_info, patch.object( logger, "warning" ) as mock_warning: - prompt_spec.parse_from_prompt(prompt) + with pytest.raises(RuntimeError): + prompt_spec.parse_from_prompt(prompt) mock_info.assert_not_called() warning_list = [each.args[0] for each in mock_warning.call_args_list] - assert warning_list == ["API response was not a valid JSON"] * 3 + [ - "Maximum number of API calls reached for PromptParser." 
- ] + assert warning_list == ["API response was not a valid JSON"] * 3 assert mocked_parsing_method.call_count == 3 assert prompt_spec._instruction is None assert prompt_spec._examples is None From 659bd3d7eb9ab12bde3963c3604fff44cce2307e Mon Sep 17 00:00:00 2001 From: Graham Neubig Date: Fri, 8 Sep 2023 08:22:36 -0400 Subject: [PATCH 5/7] Update prompt2model/model_retriever/generate_hypothetical_document.py Co-authored-by: Vijay Viswanathan --- prompt2model/model_retriever/generate_hypothetical_document.py | 1 + 1 file changed, 1 insertion(+) diff --git a/prompt2model/model_retriever/generate_hypothetical_document.py b/prompt2model/model_retriever/generate_hypothetical_document.py index 67bdc54ad..5f17aed6c 100644 --- a/prompt2model/model_retriever/generate_hypothetical_document.py +++ b/prompt2model/model_retriever/generate_hypothetical_document.py @@ -444,6 +444,7 @@ def generate_hypothetical_model_description( return chatgpt_completion.choices[0]["message"]["content"] except API_ERRORS as e: handle_api_error(e) + api_call_counter += 1 if max_api_calls and api_call_counter >= max_api_calls: logging.error("Maximum number of API calls reached.") raise ValueError("Maximum number of API calls reached.") from e From 9a23cf555e389f56f449c51d52336d93bbbda0a1 Mon Sep 17 00:00:00 2001 From: Graham Neubig Date: Fri, 8 Sep 2023 08:27:47 -0400 Subject: [PATCH 6/7] Reflect review comments --- prompt2model/dataset_generator/prompt_based.py | 10 +++++----- tests/dataset_generator_test.py | 8 ++++---- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/prompt2model/dataset_generator/prompt_based.py b/prompt2model/dataset_generator/prompt_based.py index 793a791a9..0f0be8d47 100644 --- a/prompt2model/dataset_generator/prompt_based.py +++ b/prompt2model/dataset_generator/prompt_based.py @@ -147,8 +147,8 @@ def construct_prompt( parsed from the user's prompt, which quality is higher than the generated examples. generated_examples: A list of currently generated examples. - context_cutoff: If the total length of the prompt exceeds this value, - repeat the prompt generation process to generate a shorter one. + context_cutoff: If the total length of the prompt in tokens exceeds this + value, repeat prompt generation process to generate a shorter one. Returns: The generated prompt string. @@ -224,6 +224,9 @@ def apply_multi_vote_filtering( Returns: Currently generated dataset with multi-vote filtering applied. """ + # Ensure that multi-vote filtering is enabled. 
+ if not self.filter_duplicated_examples: + raise ValueError("Multi-vote filtering is not enabled.") filtered_examples = [] input_output_map: dict[str, Counter] = defaultdict(Counter) @@ -231,9 +234,6 @@ def apply_multi_vote_filtering( for ex in generated_examples: input_output_map[ex.input_col][ex.output_col] += 1 - if len(generated_examples) != 0 and input_output_map is None: - raise ValueError("input_output_map is not correctly constructed.") - for input_str, output_counter in input_output_map.items(): most_common_count = output_counter.most_common(1)[0][1] diff --git a/tests/dataset_generator_test.py b/tests/dataset_generator_test.py index b15990778..837e18574 100644 --- a/tests/dataset_generator_test.py +++ b/tests/dataset_generator_test.py @@ -598,7 +598,7 @@ def test_unexpected_examples_of_gpt(mocked_generate_example): def test_filter_with_duplicate_inputs_unique_outputs(): """Test filtering with duplicate inputs, unique outputs.""" os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=False) + data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=True) generated_examples = [ Example(input_col="apple", output_col="A"), Example(input_col="banana", output_col="B"), @@ -618,7 +618,7 @@ def test_filter_with_duplicate_inputs_unique_outputs(): def test_filter_duplicate_inputs_duplicate_outputs(): """Test constructing a map with duplicate inputs and duplicate outputs.""" os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=False) + data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=True) generated_examples = [ Example(input_col="apple", output_col="A"), Example(input_col="banana", output_col="C"), @@ -643,7 +643,7 @@ def test_filter_duplicate_inputs_duplicate_outputs(): def test_create_all_examples_dataset_and_generated_dataset_with_unique_inputs_outputs(): """Test constructing a map with unique inputs and outputs.""" os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=False) + data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=True) generated_examples = [ Example(input_col="apple", output_col="A"), Example(input_col="banana", output_col="B"), @@ -656,7 +656,7 @@ def test_create_all_examples_dataset_and_generated_dataset_with_unique_inputs_ou def test_create_all_examples_dataset_and_generated_dataset_with_empty_examples_list(): """Test constructing a map with empty inputs and outputs.""" os.environ["OPENAI_API_KEY"] = "fake_api_key" - data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=False) + data_generator = PromptBasedDatasetGenerator(filter_duplicated_examples=True) generated_examples = [] filtered_examples = data_generator.apply_multi_vote_filtering(generated_examples) assert generated_examples == filtered_examples From d35bbceb07f22daba6e7bde73b075aff07a33da1 Mon Sep 17 00:00:00 2001 From: Graham Neubig Date: Fri, 8 Sep 2023 08:30:52 -0400 Subject: [PATCH 7/7] Update prompt2model/dataset_generator/prompt_based.py Co-authored-by: Vijay Viswanathan --- prompt2model/dataset_generator/prompt_based.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prompt2model/dataset_generator/prompt_based.py b/prompt2model/dataset_generator/prompt_based.py index 0f0be8d47..806e05735 100644 --- a/prompt2model/dataset_generator/prompt_based.py +++ b/prompt2model/dataset_generator/prompt_based.py @@ -154,7 
+154,7 @@ def construct_prompt( The generated prompt string. """ while True: - # Choose a few examples to add to the prompt if examples exist + # Choose a few examples to add to the prompt if examples exist. if len(generated_examples) == 0: low_quality_example_string = "N/A\n" random_selected_generated_example_num = 0