Commit

Merge branch 'main' into add_litellm_inference
JoelNiklaus authored Nov 20, 2024
2 parents dabb4a7 + 85c0d9f commit 65f759c
Showing 16 changed files with 459 additions and 27 deletions.
4 changes: 2 additions & 2 deletions .github/ISSUE_TEMPLATE/evaluation-task-request.md
@@ -13,6 +13,6 @@ assignees: ''

 ## Evaluation metadata
 Provide all available
-- Paper url:
-- Github url:
+- Paper url:
+- Github url:
 - Dataset url:
1 change: 0 additions & 1 deletion .github/ISSUE_TEMPLATE/feature-request.md
@@ -15,4 +15,3 @@ A clear and concise description of what you want to happen.

 ## Posssible alternatives
 A clear and concise description of any alternative solutions or features you've considered.
-
12 changes: 10 additions & 2 deletions README.md
@@ -104,9 +104,17 @@ Harness and HELM teams for their pioneering work on LLM evaluations.
 Got ideas? Found a bug? Want to add a
 [task](https://github.com/huggingface/lighteval/wiki/Adding-a-Custom-Task) or
 [metric](https://github.com/huggingface/lighteval/wiki/Adding-a-New-Metric)?
-Contributions are warmly
-welcomed!
+Contributions are warmly welcomed!
 
+If you're adding a new feature, please open an issue first.
+
+If you open a PR, don't forget to run the styling!
+
+```bash
+pip install -e .[dev]
+pre-commit install
+pre-commit run --all-files
+```
## 📜 Citation

```bibtex
4 changes: 2 additions & 2 deletions examples/model_configs/peft_model.yaml
@@ -1,8 +1,8 @@
 model:
-  type: "base"
+  type: "base"
   base_params:
     model_args: "pretrained=predibase/customer_support,revision=main" # pretrained=model_name,trust_remote_code=boolean,revision=revision_to_use,model_parallel=True ... For a PEFT model, the pretrained model should be the one trained with PEFT and the base model below will contain the original model on which the adapters will be applied.
-    dtype: "4bit" # Specifying the model to be loaded in 4 bit uses BitsAndBytesConfig. The other option is to use "8bit" quantization.
+    dtype: "4bit" # Specifying the model to be loaded in 4 bit uses BitsAndBytesConfig. The other option is to use "8bit" quantization.
     compile: true
   merged_weights: # Ignore this section if you are not using PEFT models
     delta_weights: false # set to True of your model should be merged with a base model, also need to provide the base model name
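For readers unfamiliar with the PEFT setup described in the comments of this config: `pretrained` points at the adapter repository, while `base_model` (further down in this file) names the original model the adapters are applied to. A rough standalone sketch of that loading flow using the `peft` library follows; it is not code from this commit, and the base model name is an assumption for illustration.

```python
# Illustrative sketch only -- not code from this commit.
# "mistralai/Mistral-7B-v0.1" is an assumed base model, named only for illustration.
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")  # the base_model entry
model = PeftModel.from_pretrained(base, "predibase/customer_support")     # the pretrained= adapter repo
model = model.merge_and_unload()  # merge the adapter weights into the base model before evaluation
```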
4 changes: 2 additions & 2 deletions examples/model_configs/quantized_model.yaml
@@ -1,8 +1,8 @@
 model:
-  type: "base"
+  type: "base"
   base_params:
     model_args: "pretrained=HuggingFaceH4/zephyr-7b-beta,revision=main" # pretrained=model_name,trust_remote_code=boolean,revision=revision_to_use,model_parallel=True ...
-    dtype: "4bit" # Specifying the model to be loaded in 4 bit uses BitsAndBytesConfig. The other option is to use "8bit" quantization.
+    dtype: "4bit" # Specifying the model to be loaded in 4 bit uses BitsAndBytesConfig. The other option is to use "8bit" quantization.
     compile: true
   merged_weights: # Ignore this section if you are not using PEFT models
     delta_weights: false # set to True of your model should be merged with a base model, also need to provide the base model name
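As context for the `dtype: "4bit"` comment in this config, 4-bit loading goes through `BitsAndBytesConfig` in `transformers`. A minimal sketch of the equivalent direct call is shown below; it is not part of this commit, and the compute dtype is an assumption not specified by the config file.

```python
# Illustrative sketch only -- not code from this commit.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceH4/zephyr-7b-beta",
    revision="main",
    quantization_config=quant_config,  # "8bit" would correspond to load_in_8bit=True instead
)
```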
3 changes: 3 additions & 0 deletions src/lighteval/logging/hierarchical_logger.py
@@ -34,8 +34,11 @@

     logger = get_logger(__name__, log_level="INFO")
 elif is_accelerate_available():
+    from accelerate import Accelerator, InitProcessGroupKwargs
     from accelerate.logging import get_logger
 
+    # We must init the accelerator before using the logger
+    accelerator = Accelerator(kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(seconds=3000))])
     logger = get_logger(__name__, log_level="INFO")
 else:
     logger = Logger(__name__, level="INFO")
10 changes: 9 additions & 1 deletion src/lighteval/metrics/metrics_corpus.py
@@ -30,6 +30,7 @@
 import sacrebleu
 import sklearn.metrics
 
+from lighteval.logging.hierarchical_logger import hlog_warn
 from lighteval.metrics.sample_preparator import (
     GenerativeCorpusMetricInput,
     LogprobCorpusMetricInput,
@@ -103,7 +104,14 @@ def __init__(self, metric_type: str):
     def compute(self, items: list[GenerativeCorpusMetricInput]) -> float:
         """Computes the metric score over all the corpus generated items, by using the sacrebleu implementation."""
         golds = [i.golds for i in items]
-        preds = [as_list(i.preds) for i in items]
+        preds = []
+        for i in items:
+            pred = as_list(i.preds)
+            if len(pred) > 1:
+                hlog_warn(
+                    f"Multiple predictions present, keeping only the first prediction (when computing sacrebleu.{self.metric.__name__})."
+                )
+            preds.append(pred[0])
         return float(self.metric(hypotheses=preds, references=golds).score)


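For context on the change above (not part of the commit itself): sacrebleu's corpus-level metrics expect exactly one hypothesis string per segment, which is why only the first prediction is kept. A minimal standalone call looks roughly like this:

```python
# Illustrative sketch only -- not code from this commit.
import sacrebleu

hypotheses = ["the cat sat on the mat"]           # one prediction per evaluated item
references = [["the cat is sitting on the mat"]]  # one list per reference stream
print(sacrebleu.corpus_bleu(hypotheses, references).score)
```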
1 change: 1 addition & 0 deletions src/lighteval/metrics/sample_preparator.py
@@ -55,6 +55,7 @@ class PerplexityCorpusMetricInput(CorpusMetricInput):


 class GenerativePreparator:
+    @staticmethod
     def prepare(golds: list[str], predictions: list[str], **kwargs):
         """Prepares an individual generative example to the format expected by metrics computed at the corpus level (aggregated).
9 changes: 8 additions & 1 deletion src/lighteval/models/adapter_model.py
@@ -41,7 +41,14 @@ class AdapterModel(BaseModel):
     def _create_auto_tokenizer(self, config: AdapterModelConfig, env_config: EnvConfig) -> PreTrainedTokenizer:
         # By default, we look at the model config for the model stored in `base_model`
         # (= the parent model, not the model of interest)
-        return self._create_auto_tokenizer_with_name(config.base_model, config=config, env_config=env_config)
+        return self._create_auto_tokenizer_with_name(
+            model_name=config.base_model,
+            revision=config.revision,
+            env_config=env_config,
+            tokenizer_name=config.tokenizer,
+            subfolder=config.subfolder,
+            trust_remote_code=config.trust_remote_code,
+        )
 
     def _create_auto_model(self, config: AdapterModelConfig, env_config: EnvConfig) -> AutoModelForCausalLM:
         """Returns a PeftModel from a base model and a version fined tuned using PEFT."""
109 changes: 100 additions & 9 deletions src/lighteval/models/base_model.py
@@ -30,6 +30,7 @@
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 
 from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset, LoglikelihoodSingleTokenDataset
 from lighteval.logging.hierarchical_logger import hlog, hlog_err, hlog_warn
@@ -57,6 +58,7 @@


 if is_accelerate_available():
+    from accelerate import Accelerator
     from accelerate.utils import calculate_maximum_sizes, convert_bytes, get_max_memory
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -67,8 +69,8 @@
 class BaseModel(LightevalModel):
     def __init__(
         self,
-        config: BaseModelConfig,
         env_config: EnvConfig,
+        config: BaseModelConfig,
     ):
         """Initializes a HuggingFace `AutoModel` and `AutoTokenizer` for evaluation."""
         self._config = config.init_configs(env_config)
@@ -114,6 +116,72 @@ def __init__(

         self.pairwise_tokenization = config.pairwise_tokenization
 
+    @classmethod
+    def from_model(
+        cls,
+        model: Union[AutoModelForCausalLM, LightevalModel],
+        env_config: EnvConfig,
+        accelerator: "Accelerator" = None,
+        tokenizer_name: str = None, # custom tokenizer
+        trust_remote_code: bool = False,
+        use_chat_template: bool = False,
+        add_special_tokens: bool = True,
+        pairwise_tokenization: bool = False,
+        multichoice_continuations_start_space: bool = None,
+    ):
+        # Slightly hackish way to test if the model is a AutoModelForCausalLM, since the instances don't
+        # derive from this class explicitely
+        assert isinstance(model, LightevalModel) or type(model).__name__ in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values()
+
+        if isinstance(model, LightevalModel):
+            return model
+
+        # Instanciate the object without using __init__
+        self = cls.__new__(cls)
+        self._config = model.config
+        self._max_length = self._init_max_length(max_length=model.config.max_length)
+        self._tokenizer = self._create_auto_tokenizer_with_name(
+            model_name=model.name_or_path,
+            revision=model.config._commit_hash,
+            env_config=env_config,
+            trust_remote_code=trust_remote_code,
+            tokenizer_name=tokenizer_name,
+        )
+        self.model_name = _simplify_name(model.name_or_path)
+        self.model_sha = model.config._commit_hash
+
+        # If model_parallel is not set we compare the number of processes with the number of GPUs
+        self.model = model
+        self.model.eval()
+        torch.set_grad_enabled(False)
+
+        self.accelerator = accelerator
+        if accelerator is not None:
+            self._device = accelerator.device
+            self.model = self.accelerator.prepare(self.model.to(accelerator.device))
+        else:
+            self._device = "cpu"
+
+        self.use_chat_template = use_chat_template
+        self._add_special_tokens = add_special_tokens if add_special_tokens is not None else False
+        self.pairwise_tokenization = pairwise_tokenization
+        self.multichoice_continuations_start_space = multichoice_continuations_start_space
+
+        self.precision = _get_dtype(model.dtype, config=self._config)
+
+        if is_accelerate_available():
+            model_size, _ = calculate_maximum_sizes(self.model)
+            model_size = convert_bytes(model_size)
+        else:
+            model_size = -1
+        self.model_info = ModelInfo(
+            model_name=self.model_name,
+            model_sha=self.model_sha,
+            model_dtype=self.precision,
+            model_size=model_size,
+        )
+        return self
+
     @property
     def tokenizer(self):
         return self._tokenizer
@@ -207,10 +275,23 @@ def _create_auto_model(self, config: BaseModelConfig, env_config: EnvConfig) ->
     def _create_auto_tokenizer(
         self, config: BaseModelConfig, env_config: EnvConfig
     ) -> transformers.PreTrainedTokenizer:
-        return self._create_auto_tokenizer_with_name(config.pretrained, config=config, env_config=env_config)
+        return self._create_auto_tokenizer_with_name(
+            model_name=config.pretrained,
+            revision=config.revision,
+            env_config=env_config,
+            tokenizer_name=config.tokenizer,
+            subfolder=config.subfolder,
+            trust_remote_code=config.trust_remote_code,
+        )
 
     def _create_auto_tokenizer_with_name(
-        self, model_name: str, config: BaseModelConfig, env_config: EnvConfig
+        self,
+        model_name: str,
+        revision: str,
+        env_config: EnvConfig,
+        tokenizer_name: str = None,
+        subfolder: str = None,
+        trust_remote_code: bool = False,
     ) -> transformers.PreTrainedTokenizer:
         """
         Create a Hugging Face AutoTokenizer for language model.
@@ -231,25 +312,35 @@ def _create_auto_tokenizer_with_name(
"""
         try:
             tokenizer = AutoTokenizer.from_pretrained(
-                model_name if config.tokenizer is None else config.tokenizer,
-                revision=config.revision + (f"/{config.subfolder}" if config.subfolder is not None else ""),
+                model_name if tokenizer_name is None else tokenizer_name,
+                revision=revision + (f"/{subfolder}" if subfolder is not None else ""),
                 cache_dir=env_config.cache_dir,
                 token=env_config.token,
-                trust_remote_code=config.trust_remote_code,
+                trust_remote_code=trust_remote_code,
                 padding_side="left",
                 truncation_side="left",
             )
         except RecursionError:
             tokenizer = AutoTokenizer.from_pretrained(
-                model_name if config.tokenizer is None else config.tokenizer,
-                revision=config.revision + (f"/{config.subfolder}" if config.subfolder is not None else ""),
+                model_name if tokenizer_name is None else tokenizer_name,
+                revision=revision + (f"/{subfolder}" if subfolder is not None else ""),
                 cache_dir=env_config.cache_dir,
                 token=env_config.token,
-                trust_remote_code=config.trust_remote_code,
+                trust_remote_code=trust_remote_code,
                 unk_token="<unk>",
                 padding_side="left",
                 truncation_side="left",
             )
+        except FileNotFoundError:
+            hlog_warn("Problem when loading the tokenizer in the cache - discarding the provided cache path value.")
+            tokenizer = AutoTokenizer.from_pretrained(
+                model_name if tokenizer_name is None else tokenizer_name,
+                revision=revision + (f"/{subfolder}" if subfolder is not None else ""),
+                token=env_config.token,
+                trust_remote_code=trust_remote_code,
+                padding_side="left",
+                truncation_side="left",
+            )
         tokenizer.pad_token = tokenizer.eos_token
         tokenizer.model_max_length = self.max_length
         hlog("Tokenizer truncation and padding size set to the left side.")
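The new `from_model` classmethod above wraps an already-instantiated `transformers` causal LM in a `BaseModel` without going through `__init__`. A hedged sketch of a direct call is shown below; the model name, the `EnvConfig` import path, and the cache directory are assumptions for illustration, not taken from this commit.

```python
# Illustrative sketch only -- not code from this commit.
from transformers import AutoModelForCausalLM

from lighteval.models.base_model import BaseModel
from lighteval.utils.utils import EnvConfig  # import path assumed

hf_model = AutoModelForCausalLM.from_pretrained("gpt2")
lighteval_model = BaseModel.from_model(
    model=hf_model,
    env_config=EnvConfig(cache_dir="~/.cache/huggingface"),  # cache_dir value assumed
    use_chat_template=False,
)
```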
3 changes: 2 additions & 1 deletion src/lighteval/models/vllm_model.py
@@ -98,7 +98,8 @@ def tokenizer(self):

     def cleanup(self):
         destroy_model_parallel()
-        del self.model.llm_engine.model_executor.driver_worker
+        if self.model is not None:
+            del self.model.llm_engine.model_executor.driver_worker
         self.model = None
         gc.collect()
         ray.shutdown()
12 changes: 10 additions & 2 deletions src/lighteval/pipeline.py
@@ -34,7 +34,7 @@
 from lighteval.logging.evaluation_tracker import EvaluationTracker
 from lighteval.logging.hierarchical_logger import hlog, htrack_block
 from lighteval.metrics.utils.metric_utils import MetricCategory
-from lighteval.models.model_loader import load_model
+from lighteval.models.model_loader import BaseModel, load_model
 from lighteval.models.model_output import ModelResponse
 from lighteval.tasks.lighteval_task import LightevalTask, create_requests_from_tasks
 from lighteval.tasks.registry import Registry, taskinfo_selector
@@ -164,7 +164,15 @@ def _init_model(self, model_config, model):
                 )
             else:
                 return load_model(config=model_config, env_config=self.pipeline_parameters.env_config)
-        return model
+        if isinstance(model, BaseModel):
+            return model
+        else:
+            return BaseModel.from_model(
+                model=model,
+                use_chat_template=self.pipeline_parameters.use_chat_template,
+                env_config=self.pipeline_parameters.env_config,
+                accelerator=self.accelerator,
+            )
 
     def _init_tasks_and_requests(self, tasks: str):
         with htrack_block("Tasks loading"):
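Together with the `from_model` addition in base_model.py, this change lets `Pipeline` accept a model object that is not yet a lighteval `BaseModel` and wrap it on the fly. A rough sketch of the user-facing effect follows; the `Pipeline` arguments shown are assumptions about the API at this point in history, not taken from the diff.

```python
# Illustrative sketch only -- not code from this commit; argument names are assumed.
from transformers import AutoModelForCausalLM

from lighteval.logging.evaluation_tracker import EvaluationTracker
from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters

hf_model = AutoModelForCausalLM.from_pretrained("gpt2")
pipeline = Pipeline(
    tasks="leaderboard|truthfulqa:mc|0|0",
    pipeline_parameters=PipelineParameters(launcher_type=ParallelismManager.NONE),
    evaluation_tracker=EvaluationTracker(output_dir="./results"),
    model=hf_model,  # a plain AutoModelForCausalLM; _init_model wraps it via BaseModel.from_model
)
pipeline.evaluate()
```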
10 changes: 9 additions & 1 deletion src/lighteval/tasks/templates/continuation.py
@@ -86,6 +86,7 @@ def get_continuation_prompt_function(
     language: Language,
     adapter: Callable[[dict], ContinuationInput | None] | ContinuationDictAdapter,
     formulation: Formulation = MCFFormulation(),
+    fix_formatting: bool = True,
 ):
     """
     Create a templated prompt function for a Continuation task.
@@ -118,6 +119,7 @@ def get_continuation_prompt_function(
         adapter (Callable[[dict], ContinuationInput] | ContinuationDictAdapter): Either a function that takes a dataset row and returns a ContinuationInput, or a dictionary with keys corresponding to the field names in the dataset row.
             Note: Both ContinuationDictAdapter and ContinuationInput are TypeDicts, this means that the caller provides dictionary and doesn't initialize any class!
         formulation (Formulation, optional): The formulation (MCF/Hybrid/CF) to use for the task. Defaults to MCFFormulation().
+        fix_formatting (bool, optional): Whether to fix the formatting of the text by capitalizing and fixing punctuation based on language. If False, the text will be used as-is. Defaults to True.
     Returns:
         Callable: A function that generates Continuation prompt based on the given parameters.
     """
@@ -132,10 +134,16 @@ def prepare_prompt(line: dict):
         instruction_val = cont_input.get("instruction")
         instruction = f"{instruction_val}\n" if instruction_val else ""
 
-        context = f"{capitalize(fix_ending_punct(cont_input['context'], translation_literals))}"
+        context = (
+            f"{capitalize(fix_ending_punct(cont_input['context'], translation_literals))}"
+            if fix_formatting
+            else cont_input["context"]
+        )
 
         continuations = [
             fix_capitalization(context, fix_ending_punct(continuation, translation_literals), translation_literals)
+            if fix_formatting
+            else continuation
             for continuation in cont_input["continuations"]
         ]
 
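The new `fix_formatting` flag above lets a task keep dataset text verbatim instead of auto-capitalizing and fixing punctuation. A hedged usage sketch follows; the dataset column names and the `gold_idx` key are assumptions for illustration, not taken from this commit.

```python
# Illustrative sketch only -- not code from this commit; dataset fields are assumed.
from lighteval.tasks.templates.continuation import get_continuation_prompt_function
from lighteval.utils.language import Language

prompt_fn = get_continuation_prompt_function(
    language=Language.ENGLISH,
    adapter=lambda line: {
        "context": line["ctx"],
        "continuations": line["endings"],
        "gold_idx": int(line["label"]),
    },
    fix_formatting=False,  # keep the raw dataset text as-is
)
```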
(Diffs for the remaining changed files were not loaded in this view.)