From 2d864c78496cbfab9c50c110c7892c007b490c5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Knes?= Date: Mon, 17 Feb 2025 12:17:24 +0100 Subject: [PATCH] support for remote inference and wx.ai integration --- bigcode_eval/arguments.py | 17 ++ bigcode_eval/generation.py | 8 +- bigcode_eval/remote_inference/__init__.py | 0 bigcode_eval/remote_inference/base.py | 174 +++++++++++++++++ bigcode_eval/remote_inference/utils.py | 55 ++++++ bigcode_eval/remote_inference/wx_ai.py | 109 +++++++++++ main.py | 216 +++++++++++----------- tests/test_generation_evaluation.py | 1 + 8 files changed, 476 insertions(+), 104 deletions(-) create mode 100644 bigcode_eval/remote_inference/__init__.py create mode 100644 bigcode_eval/remote_inference/base.py create mode 100644 bigcode_eval/remote_inference/utils.py create mode 100644 bigcode_eval/remote_inference/wx_ai.py diff --git a/bigcode_eval/arguments.py b/bigcode_eval/arguments.py index 4c0c39264..ec364082d 100644 --- a/bigcode_eval/arguments.py +++ b/bigcode_eval/arguments.py @@ -36,3 +36,20 @@ class EvalArguments: seed: Optional[int] = field( default=0, metadata={"help": "Random seed used for evaluation."} ) + length_penalty: Optional[dict[str, int | float]] = field( + default=None, + metadata={"help": "A dictionary with length penalty options (for watsonx.ai)."} + ) + max_new_tokens: Optional[int] = field( + default=None, metadata={"help": "Maximum number of generated tokens (for watsonx.ai)."} + ) + min_new_tokens: Optional[int] = field( + default=None, metadata={"help": "Minimum number of generated tokens (for watsonx.ai)."} + ) + stop_sequences: Optional[list[str]] = field( + default=None, metadata={"help": "List of stop sequences (for watsonx.ai)."} + ) + repetition_penalty: Optional[float] = field( + default=None, + metadata={"help": "A float value of repetition penalty (for watsonx.ai)."} + ) diff --git a/bigcode_eval/generation.py b/bigcode_eval/generation.py index 98e15a7be..2c1b75477 100644 --- a/bigcode_eval/generation.py +++ b/bigcode_eval/generation.py @@ -7,6 +7,7 @@ from torch.utils.data.dataloader import DataLoader from transformers import StoppingCriteria, StoppingCriteriaList +from bigcode_eval.remote_inference.utils import remote_inference from bigcode_eval.utils import TokenizedDataset, complete_code @@ -62,6 +63,11 @@ def parallel_generations( ) return generations[:n_tasks] + if args.inference_platform != "hf": + return remote_inference( + args.inference_platform, dataset, task, args + ) + set_seed(args.seed, device_specific=True) # Setup generation settings @@ -89,7 +95,7 @@ def parallel_generations( stopping_criteria.append( TooLongFunctionCriteria(0, task.max_length_multiplier) ) - + if stopping_criteria: gen_kwargs["stopping_criteria"] = StoppingCriteriaList(stopping_criteria) diff --git a/bigcode_eval/remote_inference/__init__.py b/bigcode_eval/remote_inference/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bigcode_eval/remote_inference/base.py b/bigcode_eval/remote_inference/base.py new file mode 100644 index 000000000..00c9a3bdc --- /dev/null +++ b/bigcode_eval/remote_inference/base.py @@ -0,0 +1,174 @@ +from abc import abstractmethod +from argparse import Namespace +from typing import Any, Optional + +from datasets import Dataset as HfDataset + +from bigcode_eval.base import Task +from bigcode_eval.utils import _parse_instruction + + +Dataset = HfDataset | list[dict[str, Any]] + + +class RemoteInferenceInterface: + @abstractmethod + def __init__(self): + raise NotImplementedError + + 
@abstractmethod + def _prepare_generation_params(self, args: Namespace) -> dict[str, Any]: + """Method maps HF generation parameters to platform-specific ones.""" + + raise NotImplementedError + + @staticmethod + def _limit_inputs( + dataset: Dataset, limit: Optional[int], offset: Optional[int] + ) -> Dataset: + """Method limits input dataset based on provided `limit` and `limit_start` args.""" + + is_hf = isinstance(dataset, HfDataset) + + if offset: + dataset = ( + dataset.select(range(offset, len(dataset))) + if is_hf + else dataset[offset:] + ) + + if limit: + dataset = ( + dataset.take(limit) + if is_hf + else dataset[:limit] + ) + + return dataset + + @staticmethod + def _make_instruction_prompt( + instruction: str, + context: str, + prefix: str, + instruction_tokens: Optional[str], + ) -> str: + """Method creates a prompt for instruction-tuning based on a given prefix and instruction tokens.""" + + user_token, end_token, assistant_token = "", "", "\n" + if instruction_tokens: + user_token, end_token, assistant_token = instruction_tokens.split(",") + + return "".join( + ( + prefix, + user_token, + instruction, + end_token, + assistant_token, + context, + ) + ) + + @staticmethod + def _make_infill_prompt(prefix: str, content_prefix: str, content_suffix: str) -> str: + """Method creates a prompt for infilling. + As it depends on particular models, it may be necessary to implement the method separately for each platform. + """ + + return f"{prefix}{content_prefix}{content_suffix}" + + def _create_prompt_from_dict( + self, content: dict[str, str], prefix: str, instruction_tokens: Optional[str] + ) -> str: + """Method prepares a prompt in similar way to the `TokenizedDataset` class for either instruction or infilling mode.""" + + if all(key in ("instruction", "context") for key in content): + return self._make_instruction_prompt( + content["instruction"], content["context"], prefix, instruction_tokens + ) + + elif all(key in ("prefix", "suffix") for key in content): + return self._make_infill_prompt(prefix, content["prefix"], content["suffix"]) + + else: + raise ValueError(f"Unsupported prompt format:\n{content}.") + + def _prepare_prompts( + self, + dataset: Dataset, + task: Task, + prefix: str, + instruction_tokens: Optional[str], + ) -> list[str]: + """Method creates prompts for inputs based on the task prompt, prefix and instruction tokens (if applicable).""" + + is_string = isinstance(task.get_prompt(dataset[0]), str) + + return [ + prefix + task.get_prompt(instance) + if is_string + else self._create_prompt_from_dict( + task.get_prompt(instance), prefix, instruction_tokens + ) + for instance in dataset + ] + + @abstractmethod + def _infer( + self, inputs: list[str], params: dict[str, Any], args: Namespace + ) -> list[list[str]]: + """Method responsible for inference on a given platform.""" + + raise NotImplementedError + + @staticmethod + def _postprocess_predictions( + predictions: list[list[str]], + prompts: list[str], + task: Task, + instruction_tokens: Optional[str], + ) -> list[list[str]]: + """Method postprocess model's predictions based on a given task and instruction tokens (if applicable).""" + + if instruction_tokens: + predictions = [ + [_parse_instruction(prediction[0], instruction_tokens.split(","))] + for prediction in predictions + ] + + return [ + [ + task.postprocess_generation( + prompts[i] + predictions[i][0], i + ) + ] + for i in range(len(predictions)) + ] + + def prepare_generations( + self, + dataset: Dataset, + task: Task, + args: Namespace, + prefix: str = 
"", + postprocess: bool = True, + ) -> list[list[str]]: + """Method generates (and postprocess) code using given platform. It follows the same process as HF inference.""" + + gen_params = self._prepare_generation_params(args) + + dataset = self._limit_inputs(dataset, args.limit, args.limit_start) + + prompts = self._prepare_prompts( + dataset, task, prefix, args.instruction_tokens + ) + + predictions = self._infer(prompts, gen_params, args) + + if postprocess: + return self._postprocess_predictions( + predictions, prompts, task, args.instruction_tokens + ) + + return predictions diff --git a/bigcode_eval/remote_inference/utils.py b/bigcode_eval/remote_inference/utils.py new file mode 100644 index 000000000..4b8a797c0 --- /dev/null +++ b/bigcode_eval/remote_inference/utils.py @@ -0,0 +1,55 @@ +from argparse import Namespace +from importlib import import_module + +from bigcode_eval.base import Task + +from bigcode_eval.remote_inference.base import Dataset, RemoteInferenceInterface +from bigcode_eval.remote_inference.wx_ai import WxInference + + +required_packages = { + "wx": ["ibm_watsonx_ai"], +} + + +def check_packages_installed(names: list[str]) -> bool: + for name in names: + try: + import_module(name) + except (ImportError, ModuleNotFoundError, NameError): + return False + return True + + +def remote_inference( + inference_platform: str, + dataset: Dataset, + task: Task, + args: Namespace, +) -> list[list[str]]: + packages = required_packages.get(inference_platform) + if packages and not check_packages_installed(packages): + raise RuntimeError( + f"In order to run inference with '{inference_platform}', the " + f"following packages are required: '{packages}'. However, they " + f"could not be properly imported. Check if the packages are " + f"installed correctly." + ) + + inference_cls: RemoteInferenceInterface + + if inference_platform == "wx": + inference_cls = WxInference() + + else: + raise ValueError( + f"Unsupported remote inference platform: '{inference_platform}'." + ) + + return inference_cls.prepare_generations( + dataset=dataset, + task=task, + args=args, + prefix=args.prefix, + postprocess=args.postprocess, + ) diff --git a/bigcode_eval/remote_inference/wx_ai.py b/bigcode_eval/remote_inference/wx_ai.py new file mode 100644 index 000000000..977a6aa07 --- /dev/null +++ b/bigcode_eval/remote_inference/wx_ai.py @@ -0,0 +1,109 @@ +import logging +import os +from argparse import Namespace +from typing import Any + +from ibm_watsonx_ai import APIClient +from ibm_watsonx_ai.foundation_models import ModelInference + +from bigcode_eval.remote_inference.base import RemoteInferenceInterface + + +class WxInference(RemoteInferenceInterface): + def __init__(self): + creds = self._read_wx_credentials() + + self.client = APIClient(credentials=creds) + + if "project_id" in creds: + self.client.set.default_project(creds["project_id"]) + if "space_id" in creds: + self.client.set.default_space(creds["space_id"]) + + @staticmethod + def _read_wx_credentials() -> dict[str, str]: + credentials = {} + + url = os.environ.get("WX_URL") + if not url: + raise EnvironmentError( + "You need to specify the URL address by setting the env " + "variable 'WX_URL', if you want to run watsonx.ai inference." + ) + credentials["url"] = url + + project_id = os.environ.get("WX_PROJECT_ID") + space_id = os.environ.get("WX_SPACE_ID") + if project_id and space_id: + logging.warning( + "Both the project ID and the space ID were specified. " + "The class 'WxInference' will access the project by default." 
+ ) + credentials["project_id"] = project_id + elif project_id: + credentials["project_id"] = project_id + elif space_id: + credentials["space_id"] = space_id + else: + raise EnvironmentError( + "You need to specify the project ID or the space id by setting the " + "appropriate env variable (either 'WX_PROJECT_ID' or 'WX_SPACE_ID'), " + "if you want to run watsonx.ai inference." + ) + + apikey = os.environ.get("WX_APIKEY") + username = os.environ.get("WX_USERNAME") + password = os.environ.get("WX_PASSWORD") + if apikey and username and password: + logging.warning( + "All of API key, username and password were specified. " + "The class 'WxInference' will use the API key for authorization " + "by default." + ) + credentials["apikey"] = apikey + elif apikey: + credentials["apikey"] = apikey + elif username and password: + credentials["username"] = username + credentials["password"] = password + else: + raise EnvironmentError( + "You need to specify either the API key, or both the username and " + "password by setting appropriate env variable ('WX_APIKEY', 'WX_USERNAME', " + "'WX_PASSWORD'), if you want to run watsonx.ai inference." + ) + + return credentials + + def _prepare_generation_params(self, args: Namespace) -> dict[str, Any]: + """Method maps generation parameters from args to be compatible with watsonx.ai.""" + + return { + "decoding_method": "sample" if args.do_sample else "greedy", + "random_seed": None if args.seed == 0 else args.seed, # seed must be greater than 0 + "temperature": args.temperature, + "top_p": args.top_p, + "top_k": None if args.top_k == 0 else args.top_k, # top_k cannot be 0 + "max_new_tokens": args.max_new_tokens, + "min_new_tokens": args.min_new_tokens, + "length_penalty": args.length_penalty, + "stop_sequences": args.stop_sequences, + "repetition_penalty": args.repetition_penalty, + } + + def _infer( + self, inputs: list[str], params: dict[str, Any], args: Namespace + ) -> list[list[str]]: + model = ModelInference( + model_id=args.model, + api_client=self.client, + ) + + return [ + [result["results"][0]["generated_text"]] + for result in + model.generate( + prompt=inputs, + params=params, + ) + ] diff --git a/main.py b/main.py index 5d030909c..b5c1d5138 100644 --- a/main.py +++ b/main.py @@ -210,6 +210,13 @@ def parse_args(): action="store_true", help="Don't run generation but benchmark groundtruth (useful for debugging)", ) + parser.add_argument( + "--inference_platform", + type=str, + default="hf", + choices=["hf", "wx"], + help="Inference platform for code generation. 
Default is 'hf', which infers locally using Huggingface.", + ) return parser.parse_args() @@ -253,112 +260,115 @@ def main(): results[task] = evaluator.evaluate(task) else: # here we generate code and save it (evaluation is optional but True by default) - dict_precisions = { - "fp32": torch.float32, - "fp16": torch.float16, - "bf16": torch.bfloat16, - } - if args.precision not in dict_precisions: - raise ValueError( - f"Non valid precision {args.precision}, choose from: fp16, fp32, bf16" - ) - - model_kwargs = { - "revision": args.revision, - "trust_remote_code": args.trust_remote_code, - "token": args.use_auth_token, - } - if args.load_in_8bit: - print("Loading model in 8bit") - model_kwargs["load_in_8bit"] = args.load_in_8bit - model_kwargs["device_map"] = {"": accelerator.process_index} - elif args.load_in_4bit: - print("Loading model in 4bit") - model_kwargs["load_in_4bit"] = args.load_in_4bit - model_kwargs["torch_dtype"] = torch.float16 - model_kwargs["bnb_4bit_compute_dtype"] = torch.float16 - model_kwargs["device_map"] = {"": accelerator.process_index} - else: - print(f"Loading model in {args.precision}") - model_kwargs["torch_dtype"] = dict_precisions[args.precision] - - if args.max_memory_per_gpu: - if args.max_memory_per_gpu != "auto": - model_kwargs["max_memory"] = get_gpus_max_memory( - args.max_memory_per_gpu, accelerator.num_processes - ) - model_kwargs["offload_folder"] = "offload" - else: - model_kwargs["device_map"] = "auto" - print("Loading model in auto mode") + model, tokenizer = None, None + + if args.inference_platform == "hf": + dict_precisions = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, + } + if args.precision not in dict_precisions: + raise ValueError( + f"Non valid precision {args.precision}, choose from: fp16, fp32, bf16" + ) - if args.modeltype == "causal": - model = AutoModelForCausalLM.from_pretrained( - args.model, - **model_kwargs, - ) - elif args.modeltype == "seq2seq": - warnings.warn( - "Seq2Seq models have only been tested for HumanEvalPack & CodeT5+ models." - ) - model = AutoModelForSeq2SeqLM.from_pretrained( - args.model, - **model_kwargs, - ) - else: - raise ValueError( - f"Non valid modeltype {args.modeltype}, choose from: causal, seq2seq" - ) + model_kwargs = { + "revision": args.revision, + "trust_remote_code": args.trust_remote_code, + "token": args.use_auth_token, + } + if args.load_in_8bit: + print("Loading model in 8bit") + model_kwargs["load_in_8bit"] = args.load_in_8bit + model_kwargs["device_map"] = {"": accelerator.process_index} + elif args.load_in_4bit: + print("Loading model in 4bit") + model_kwargs["load_in_4bit"] = args.load_in_4bit + model_kwargs["torch_dtype"] = torch.float16 + model_kwargs["bnb_4bit_compute_dtype"] = torch.float16 + model_kwargs["device_map"] = {"": accelerator.process_index} + else: + print(f"Loading model in {args.precision}") + model_kwargs["torch_dtype"] = dict_precisions[args.precision] + + if args.max_memory_per_gpu: + if args.max_memory_per_gpu != "auto": + model_kwargs["max_memory"] = get_gpus_max_memory( + args.max_memory_per_gpu, accelerator.num_processes + ) + model_kwargs["offload_folder"] = "offload" + else: + model_kwargs["device_map"] = "auto" + print("Loading model in auto mode") + + if args.modeltype == "causal": + model = AutoModelForCausalLM.from_pretrained( + args.model, + **model_kwargs, + ) + elif args.modeltype == "seq2seq": + warnings.warn( + "Seq2Seq models have only been tested for HumanEvalPack & CodeT5+ models." 
+ ) + model = AutoModelForSeq2SeqLM.from_pretrained( + args.model, + **model_kwargs, + ) + else: + raise ValueError( + f"Non valid modeltype {args.modeltype}, choose from: causal, seq2seq" + ) - if args.peft_model: - from peft import PeftModel # dynamic import to avoid dependency on peft - - model = PeftModel.from_pretrained(model, args.peft_model) - print("Loaded PEFT model. Merging...") - model.merge_and_unload() - print("Merge complete.") - - if args.left_padding: - # left padding is required for some models like chatglm3-6b - tokenizer = AutoTokenizer.from_pretrained( - args.model, - revision=args.revision, - trust_remote_code=args.trust_remote_code, - token=args.use_auth_token, - padding_side="left", - ) - else: - # used by default for most models - tokenizer = AutoTokenizer.from_pretrained( - args.model, - revision=args.revision, - trust_remote_code=args.trust_remote_code, - token=args.use_auth_token, - truncation_side="left", - padding_side="right", - ) - if not tokenizer.eos_token: - if tokenizer.bos_token: - tokenizer.eos_token = tokenizer.bos_token - print("bos_token used as eos_token") + if args.peft_model: + from peft import PeftModel # dynamic import to avoid dependency on peft + + model = PeftModel.from_pretrained(model, args.peft_model) + print("Loaded PEFT model. Merging...") + model.merge_and_unload() + print("Merge complete.") + + if args.left_padding: + # left padding is required for some models like chatglm3-6b + tokenizer = AutoTokenizer.from_pretrained( + args.model, + revision=args.revision, + trust_remote_code=args.trust_remote_code, + token=args.use_auth_token, + padding_side="left", + ) else: - raise ValueError("No eos_token or bos_token found") - try: - tokenizer.pad_token = tokenizer.eos_token - - # Some models like CodeGeeX2 have pad_token as a read-only property - except AttributeError: - print("Not setting pad_token to eos_token") - pass - WIZARD_LLAMA_MODELS = [ - "WizardLM/WizardCoder-Python-34B-V1.0", - "WizardLM/WizardCoder-34B-V1.0", - "WizardLM/WizardCoder-Python-13B-V1.0" - ] - if args.model in WIZARD_LLAMA_MODELS: - tokenizer.bos_token = "" - tokenizer.bos_token_id = 1 - print("Changing bos_token to ") + # used by default for most models + tokenizer = AutoTokenizer.from_pretrained( + args.model, + revision=args.revision, + trust_remote_code=args.trust_remote_code, + token=args.use_auth_token, + truncation_side="left", + padding_side="right", + ) + if not tokenizer.eos_token: + if tokenizer.bos_token: + tokenizer.eos_token = tokenizer.bos_token + print("bos_token used as eos_token") + else: + raise ValueError("No eos_token or bos_token found") + try: + tokenizer.pad_token = tokenizer.eos_token + + # Some models like CodeGeeX2 have pad_token as a read-only property + except AttributeError: + print("Not setting pad_token to eos_token") + pass + WIZARD_LLAMA_MODELS = [ + "WizardLM/WizardCoder-Python-34B-V1.0", + "WizardLM/WizardCoder-34B-V1.0", + "WizardLM/WizardCoder-Python-13B-V1.0" + ] + if args.model in WIZARD_LLAMA_MODELS: + tokenizer.bos_token = "" + tokenizer.bos_token_id = 1 + print("Changing bos_token to ") evaluator = Evaluator(accelerator, model, tokenizer, args) diff --git a/tests/test_generation_evaluation.py b/tests/test_generation_evaluation.py index 2f5062bff..0aae6c964 100644 --- a/tests/test_generation_evaluation.py +++ b/tests/test_generation_evaluation.py @@ -54,6 +54,7 @@ def update_args(args): args.precision = None args.modeltype = None args.max_memory_per_gpu = None + args.inference_platform = "hf" return args
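
Usage note (not part of the patch itself): with these changes applied, watsonx.ai inference is selected with "--inference_platform wx". Credentials are read from the environment variables WX_URL, WX_PROJECT_ID or WX_SPACE_ID, and either WX_APIKEY or WX_USERNAME/WX_PASSWORD, as implemented in bigcode_eval/remote_inference/wx_ai.py. The sketch below shows one way such a run could be launched from Python; the endpoint, credential values, model ID and task are placeholders, and it assumes the new EvalArguments fields (e.g. --max_new_tokens) are exposed through main.py's existing argument-parsing setup.

import os
import subprocess

# Placeholder values -- replace with real credentials, model ID and task.
env = dict(
    os.environ,
    WX_URL="https://us-south.ml.cloud.ibm.com",  # example watsonx.ai endpoint
    WX_PROJECT_ID="<project-id>",                # or WX_SPACE_ID
    WX_APIKEY="<api-key>",                       # or WX_USERNAME and WX_PASSWORD
)

subprocess.run(
    [
        "python", "main.py",
        "--model", "<watsonx-model-id>",  # any model ID available on watsonx.ai
        "--tasks", "humaneval",
        "--inference_platform", "wx",
        "--max_new_tokens", "512",
        "--limit", "10",
        "--allow_code_execution",
    ],
    env=env,
    check=True,
)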
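
For reference, a minimal sketch (also not part of the patch) of how a further backend could plug into the RemoteInferenceInterface added in bigcode_eval/remote_inference/base.py. The EchoInference class and its trivial behaviour are hypothetical; a real backend would call its platform's SDK inside _infer and would additionally be registered in bigcode_eval/remote_inference/utils.py and in the --inference_platform choices, alongside WxInference.

from argparse import Namespace
from typing import Any

from bigcode_eval.remote_inference.base import RemoteInferenceInterface


class EchoInference(RemoteInferenceInterface):
    """Hypothetical backend that simply returns each prompt unchanged."""

    def __init__(self):
        # A real implementation would create and authenticate a client here.
        pass

    def _prepare_generation_params(self, args: Namespace) -> dict[str, Any]:
        # Map the HF-style harness arguments onto the platform's parameter names.
        return {
            "temperature": args.temperature,
            "top_p": args.top_p,
            "max_new_tokens": args.max_new_tokens,
        }

    def _infer(
        self, inputs: list[str], params: dict[str, Any], args: Namespace
    ) -> list[list[str]]:
        # Return one list of candidate generations per prompt; prompt building,
        # input limiting and postprocessing are handled by prepare_generations()
        # in the base class.
        return [[prompt] for prompt in inputs]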