Support for remote inference #302

Open
wants to merge 1 commit into main
17 changes: 17 additions & 0 deletions bigcode_eval/arguments.py
@@ -36,3 +36,20 @@ class EvalArguments:
seed: Optional[int] = field(
default=0, metadata={"help": "Random seed used for evaluation."}
)
length_penalty: Optional[dict[str, int | float]] = field(
default=None,
metadata={"help": "A dictionary with length penalty options (for watsonx.ai)."}
)
max_new_tokens: Optional[int] = field(
default=None, metadata={"help": "Maximum number of generated tokens (for watsonx.ai)."}
)
min_new_tokens: Optional[int] = field(
default=None, metadata={"help": "Minimum number of generated tokens (for watsonx.ai)."}
)
stop_sequences: Optional[list[str]] = field(
default=None, metadata={"help": "List of stop sequences (for watsonx.ai)."}
)
repetition_penalty: Optional[float] = field(
default=None,
metadata={"help": "A float value of repetition penalty (for watsonx.ai)."}
)
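These five fields are optional, watsonx.ai-specific generation controls; they default to None and are only consumed by the remote-inference path added below. A minimal sketch of setting them, assuming the remaining EvalArguments fields keep their upstream defaults and that length_penalty takes the watsonx.ai dict shape (the key names here are an assumption, and all values are illustrative):

from bigcode_eval.arguments import EvalArguments

# Illustrative values only; every new field above defaults to None.
args = EvalArguments(
    seed=1,
    max_new_tokens=512,
    min_new_tokens=1,
    stop_sequences=["\nclass ", "\ndef "],
    repetition_penalty=1.05,
    # Assumed key names for the watsonx.ai length-penalty options.
    length_penalty={"decay_factor": 2.5, "start_index": 128},
)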
8 changes: 7 additions & 1 deletion bigcode_eval/generation.py
@@ -7,6 +7,7 @@
from torch.utils.data.dataloader import DataLoader
from transformers import StoppingCriteria, StoppingCriteriaList

from bigcode_eval.remote_inference.utils import remote_inference
from bigcode_eval.utils import TokenizedDataset, complete_code


@@ -62,6 +63,11 @@ def parallel_generations(
)
return generations[:n_tasks]

if args.inference_platform != "hf":
return remote_inference(
args.inference_platform, dataset, task, args
)

set_seed(args.seed, device_specific=True)

# Setup generation settings
@@ -89,7 +95,7 @@
stopping_criteria.append(
TooLongFunctionCriteria(0, task.max_length_multiplier)
)

if stopping_criteria:
gen_kwargs["stopping_criteria"] = StoppingCriteriaList(stopping_criteria)

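The early return above hands everything except the default "hf" backend to remote_inference before a local model is seeded or loaded. The inference_platform flag itself is not part of this diff; one hypothetical way to expose it from the CLI parser (name, default and choices inferred from the dispatch code) would be:

# Hypothetical wiring, not included in this PR.
parser.add_argument(
    "--inference_platform",
    type=str,
    default="hf",
    choices=["hf", "wx"],
    help="Inference backend: 'hf' for local generation, 'wx' for watsonx.ai.",
)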
Empty file.
174 changes: 174 additions & 0 deletions bigcode_eval/remote_inference/base.py
@@ -0,0 +1,174 @@
from abc import abstractmethod
from argparse import Namespace
from typing import Any, Optional

from datasets import Dataset as HfDataset

from bigcode_eval.base import Task
from bigcode_eval.utils import _parse_instruction


Dataset = HfDataset | list[dict[str, Any]]


class RemoteInferenceInterface:
@abstractmethod
def __init__(self):
raise NotImplementedError

@abstractmethod
def _prepare_generation_params(self, args: Namespace) -> dict[str, Any]:
"""Method maps HF generation parameters to platform-specific ones."""

raise NotImplementedError

@staticmethod
def _limit_inputs(
dataset: Dataset, limit: Optional[int], offset: Optional[int]
) -> Dataset:
"""Method limits input dataset based on provided `limit` and `limit_start` args."""

is_hf = isinstance(dataset, HfDataset)

if offset:
dataset = (
dataset.select(range(offset, len(dataset)))
if is_hf
else dataset[offset:]
)

if limit:
dataset = (
dataset.take(limit)
if is_hf
else dataset[:limit]
)

return dataset

@staticmethod
def _make_instruction_prompt(
instruction: str,
context: str,
prefix: str,
instruction_tokens: Optional[str],
) -> str:
"""Method creates a prompt for instruction-tuning based on a given prefix and instruction tokens."""

user_token, end_token, assistant_token = "", "", "\n"
if instruction_tokens:
user_token, end_token, assistant_token = instruction_tokens.split(",")

return "".join(
(
prefix,
user_token,
instruction,
end_token,
assistant_token,
context,
)
)

@staticmethod
def _make_infill_prompt(prefix: str, content_prefix: str, content_suffix: str) -> str:
"""Method creates a prompt for infilling.
As it depends on particular models, it may be necessary to implement the method separately for each platform.
"""

return f"{prefix}{content_prefix}{content_suffix}"

def _create_prompt_from_dict(
self, content: dict[str, str], prefix: str, instruction_tokens: Optional[str]
) -> str:
"""Method prepares a prompt in similar way to the `TokenizedDataset` class for either instruction or infilling mode."""

if all(key in ("instruction", "context") for key in content):
return self._make_instruction_prompt(
content["instruction"], content["context"], prefix, instruction_tokens
)

elif all(key in ("prefix", "suffix") for key in content):
return self._make_infill_prompt(prefix, content["prefix"], content["suffix"])

else:
raise ValueError(f"Unsupported prompt format:\n{content}.")

def _prepare_prompts(
self,
dataset: Dataset,
task: Task,
prefix: str,
instruction_tokens: Optional[str],
) -> list[str]:
"""Method creates prompts for inputs based on the task prompt, prefix and instruction tokens (if applicable)."""

is_string = isinstance(task.get_prompt(dataset[0]), str)

return [
prefix + task.get_prompt(instance)
if is_string
else self._create_prompt_from_dict(
task.get_prompt(instance), prefix, instruction_tokens
)
for instance in dataset
]

@abstractmethod
def _infer(
self, inputs: list[str], params: dict[str, Any], args: Namespace
) -> list[list[str]]:
"""Method responsible for inference on a given platform."""

raise NotImplementedError

@staticmethod
def _postprocess_predictions(
predictions: list[list[str]],
prompts: list[str],
task: Task,
instruction_tokens: Optional[str],
) -> list[list[str]]:
"""Method postprocess model's predictions based on a given task and instruction tokens (if applicable)."""

if instruction_tokens:
predictions = [
[_parse_instruction(prediction[0], instruction_tokens.split(","))]
for prediction in predictions
]

return [
[
task.postprocess_generation(
prompts[i] + predictions[i][0], i
)
]
for i in range(len(predictions))
]

def prepare_generations(
self,
dataset: Dataset,
task: Task,
args: Namespace,
prefix: str = "",
postprocess: bool = True,
) -> list[list[str]]:
"""Method generates (and postprocess) code using given platform. It follows the same process as HF inference."""

gen_params = self._prepare_generation_params(args)

dataset = self._limit_inputs(dataset, args.limit, args.limit_start)

prompts = self._prepare_prompts(
dataset, task, prefix, args.instruction_tokens
)

predictions = self._infer(prompts, gen_params, args)

if postprocess:
return self._postprocess_predictions(
predictions, prompts, task, args.instruction_tokens
)

return predictions
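RemoteInferenceInterface leaves exactly three hooks abstract (__init__, _prepare_generation_params and _infer); prompt construction, input limiting and postprocessing are shared. A minimal sketch of what a hypothetical extra backend could look like, with a stub client standing in for a real SDK:

from argparse import Namespace
from typing import Any

from bigcode_eval.remote_inference.base import RemoteInferenceInterface


class _EchoClient:
    """Stand-in for a real platform SDK client (illustrative only)."""

    def generate(self, prompt: str, **params: Any) -> str:
        return prompt + "    pass"


class EchoInference(RemoteInferenceInterface):
    def __init__(self):
        # A real backend would read credentials and build an API client here.
        self.client = _EchoClient()

    def _prepare_generation_params(self, args: Namespace) -> dict[str, Any]:
        # Map HF-style args onto whatever the platform expects.
        return {"temperature": args.temperature, "top_p": args.top_p}

    def _infer(
        self, inputs: list[str], params: dict[str, Any], args: Namespace
    ) -> list[list[str]]:
        # One inner list per prompt, mirroring the HF generation output shape.
        return [[self.client.generate(prompt, **params)] for prompt in inputs]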
55 changes: 55 additions & 0 deletions bigcode_eval/remote_inference/utils.py
@@ -0,0 +1,55 @@
from argparse import Namespace
from importlib import import_module

from bigcode_eval.base import Task

from bigcode_eval.remote_inference.base import Dataset, RemoteInferenceInterface
from bigcode_eval.remote_inference.wx_ai import WxInference


required_packages = {
"wx": ["ibm_watsonx_ai"],
}


def check_packages_installed(names: list[str]) -> bool:
for name in names:
try:
import_module(name)
except (ImportError, ModuleNotFoundError, NameError):
return False
return True


def remote_inference(
inference_platform: str,
dataset: Dataset,
task: Task,
args: Namespace,
) -> list[list[str]]:
packages = required_packages.get(inference_platform)
if packages and not check_packages_installed(packages):
raise RuntimeError(
f"In order to run inference with '{inference_platform}', the "
f"following packages are required: '{packages}'. However, they "
f"could not be properly imported. Check if the packages are "
f"installed correctly."
)

inference_cls: RemoteInferenceInterface

if inference_platform == "wx":
inference_cls = WxInference()

else:
raise ValueError(
f"Unsupported remote inference platform: '{inference_platform}'."
)

return inference_cls.prepare_generations(
dataset=dataset,
task=task,
args=args,
prefix=args.prefix,
postprocess=args.postprocess,
)
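Adding a backend currently means touching both required_packages and the if/elif dispatch above. A registry-based variant of the same dispatch, not how this PR does it but one way to keep the mapping in a single place, could look like this (the sketch reuses the real WxInference class):

from bigcode_eval.remote_inference.base import RemoteInferenceInterface
from bigcode_eval.remote_inference.wx_ai import WxInference

# One entry per supported platform; new backends only add a line here.
PLATFORM_REGISTRY: dict[str, type[RemoteInferenceInterface]] = {
    "wx": WxInference,
}


def get_backend(inference_platform: str) -> RemoteInferenceInterface:
    try:
        return PLATFORM_REGISTRY[inference_platform]()
    except KeyError:
        raise ValueError(
            f"Unsupported remote inference platform: '{inference_platform}'."
        ) from None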
109 changes: 109 additions & 0 deletions bigcode_eval/remote_inference/wx_ai.py
@@ -0,0 +1,109 @@
import logging
import os
from argparse import Namespace
from typing import Any

from ibm_watsonx_ai import APIClient
from ibm_watsonx_ai.foundation_models import ModelInference

from bigcode_eval.remote_inference.base import RemoteInferenceInterface


class WxInference(RemoteInferenceInterface):
def __init__(self):
creds = self._read_wx_credentials()

self.client = APIClient(credentials=creds)

if "project_id" in creds:
self.client.set.default_project(creds["project_id"])
if "space_id" in creds:
self.client.set.default_space(creds["space_id"])

@staticmethod
def _read_wx_credentials() -> dict[str, str]:
credentials = {}

url = os.environ.get("WX_URL")
if not url:
raise EnvironmentError(
"You need to specify the URL address by setting the env "
"variable 'WX_URL', if you want to run watsonx.ai inference."
)
credentials["url"] = url

project_id = os.environ.get("WX_PROJECT_ID")
space_id = os.environ.get("WX_SPACE_ID")
if project_id and space_id:
logging.warning(
"Both the project ID and the space ID were specified. "
"The class 'WxInference' will access the project by default."
)
credentials["project_id"] = project_id
elif project_id:
credentials["project_id"] = project_id
elif space_id:
credentials["space_id"] = space_id
else:
raise EnvironmentError(
"You need to specify the project ID or the space id by setting the "
"appropriate env variable (either 'WX_PROJECT_ID' or 'WX_SPACE_ID'), "
"if you want to run watsonx.ai inference."
)

apikey = os.environ.get("WX_APIKEY")
username = os.environ.get("WX_USERNAME")
password = os.environ.get("WX_PASSWORD")
if apikey and username and password:
logging.warning(
"All of API key, username and password were specified. "
"The class 'WxInference' will use the API key for authorization "
"by default."
)
credentials["apikey"] = apikey
elif apikey:
credentials["apikey"] = apikey
elif username and password:
credentials["username"] = username
credentials["password"] = password
else:
raise EnvironmentError(
"You need to specify either the API key, or both the username and "
"password by setting appropriate env variable ('WX_APIKEY', 'WX_USERNAME', "
"'WX_PASSWORD'), if you want to run watsonx.ai inference."
)

return credentials

def _prepare_generation_params(self, args: Namespace) -> dict[str, Any]:
"""Method maps generation parameters from args to be compatible with watsonx.ai."""

return {
"decoding_method": "sample" if args.do_sample else "greedy",
"random_seed": None if args.seed == 0 else args.seed, # seed must be greater than 0
"temperature": args.temperature,
"top_p": args.top_p,
"top_k": None if args.top_k == 0 else args.top_k, # top_k cannot be 0
"max_new_tokens": args.max_new_tokens,
"min_new_tokens": args.min_new_tokens,
"length_penalty": args.length_penalty,
"stop_sequences": args.stop_sequences,
"repetition_penalty": args.repetition_penalty,
}

def _infer(
self, inputs: list[str], params: dict[str, Any], args: Namespace
) -> list[list[str]]:
model = ModelInference(
model_id=args.model,
api_client=self.client,
)

return [
[result["results"][0]["generated_text"]]
for result in
model.generate(
prompt=inputs,
params=params,
)
]
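Everything the constructor needs comes from environment variables checked in _read_wx_credentials. A minimal sketch of configuring and instantiating the backend, with placeholder values and assuming API-key plus project authentication:

import os

# Placeholder values; WxInference raises if WX_URL, an ID, or an auth method is missing.
os.environ["WX_URL"] = "https://us-south.ml.cloud.ibm.com"  # example region endpoint
os.environ["WX_PROJECT_ID"] = "<your-project-id>"
os.environ["WX_APIKEY"] = "<your-api-key>"

from bigcode_eval.remote_inference.wx_ai import WxInference

backend = WxInference()  # validates the variables and builds the APIClient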