Showing 15 changed files with 1,055 additions and 4 deletions.
8 changes: 8 additions & 0 deletions
packages/ragbits-evaluate/examples/document-search/config/config.yaml
@@ -0,0 +1,8 @@
defaults:
  - data: hf-docs
  - setup: baseline
  - _self_

neptune:
  project: ragbits
  run: False
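For a quick sanity check of how Hydra composes these defaults, the compose API can be used instead of running the script. This is a sketch, assuming it is executed from the document-search example directory so that config_path="config" resolves:

from hydra import compose, initialize

# Compose the same config tree the script sees, without running it.
with initialize(config_path="config", version_base=None):
    cfg = compose(config_name="config", overrides=["neptune.run=False"])
    print(cfg.data.path)   # "m-ric/huggingface_doc_qa_eval"
    print(cfg.setup.name)  # "BASELINE"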
2 changes: 2 additions & 0 deletions
packages/ragbits-evaluate/examples/document-search/config/data/hf-docs.yaml
@@ -0,0 +1,2 @@
path: "m-ric/huggingface_doc_qa_eval"
split: "train"
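These two fields are passed straight through to datasets.load_dataset by the HuggingFaceDataLoader added later in this commit; the equivalent direct call is:

from datasets import load_dataset

# Same call the HuggingFaceDataLoader below ends up making.
dataset = load_dataset(path="m-ric/huggingface_doc_qa_eval", split="train")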
1 change: 1 addition & 0 deletions
packages/ragbits-evaluate/examples/document-search/config/setup/baseline.yaml
@@ -0,0 +1 @@
name: BASELINE
75 changes: 75 additions & 0 deletions
packages/ragbits-evaluate/examples/document-search/evaluate.py
@@ -0,0 +1,75 @@
import asyncio
import logging
from pathlib import Path

import hydra
import neptune
from hydra.core.hydra_config import HydraConfig
from neptune.utils import stringify_unsupported
from omegaconf import DictConfig

from ragbits.evaluate.evaluator import Evaluator
from ragbits.evaluate.loaders import HuggingFaceDataLoader
from ragbits.evaluate.metrics import MetricSet
from ragbits.evaluate.pipelines import DocumentSearchEvaluationPipeline
from ragbits.evaluate.utils import save

logging.getLogger("LiteLLM").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
log = logging.getLogger(__name__)


async def bench(config: DictConfig) -> None:
    """
    Run the evaluation for all datasets and evaluation tasks defined in the Hydra config.

    Args:
        config: Hydra configuration.
    """
    log.info("Starting evaluation: %s", config.setup.name)

    dataloader = HuggingFaceDataLoader(config)
    pipeline = DocumentSearchEvaluationPipeline(config)
    metrics = MetricSet()(config)

    evaluator = Evaluator(task="document_search")
    results = await evaluator.compute(
        pipeline=pipeline,
        dataloader=dataloader,
        metrics=metrics,
    )

    log.info("Evaluation finished. Saving results...")

    output_dir = Path(HydraConfig.get().runtime.output_dir)
    metrics_file = output_dir / "metrics.json"
    results_file = output_dir / "results.json"

    save(metrics_file, metrics=results["metrics"], time_perf=results["time_perf"])
    save(results_file, results=results["results"])

    log.info("Evaluation results saved under directory: %s", output_dir)

    if config.neptune.run:
        run = neptune.init_run(project=config.neptune.project)
        run["sys/tags"].add(config.setup.name)
        run["config"] = stringify_unsupported(config)
        run["evaluation/metrics"] = stringify_unsupported(results["metrics"])
        run["evaluation/time_perf"] = stringify_unsupported(results["time_perf"])
        run["evaluation/metrics.json"].upload(metrics_file.as_posix())
        run["evaluation/results.json"].upload(results_file.as_posix())


@hydra.main(config_path="config", config_name="config", version_base="3.2")
def main(config: DictConfig) -> None:
    """
    Entry point: parses the Hydra config and runs the evaluation.

    Args:
        config: Hydra configuration.
    """
    asyncio.run(bench(config))


if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter
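The save helper imported from ragbits.evaluate.utils is not among the files shown in this view; only its call signature above is confirmed. A minimal sketch of what such a helper could look like, assuming it simply serializes its keyword arguments to the given JSON file:

import json
from pathlib import Path
from typing import Any


def save(file_path: Path, **data: Any) -> None:
    # Hypothetical implementation: dump the keyword arguments as one JSON object.
    with open(file_path, "w", encoding="utf-8") as file:
        json.dump(data, file, indent=4)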
15 changes: 15 additions & 0 deletions
packages/ragbits-evaluate/src/ragbits/evaluate/__init__.py
@@ -0,0 +1,15 @@
from .evaluator import Evaluator
from .loaders import DataLoader, HuggingFaceDataLoader
from .metrics import Metric, MetricSet
from .pipelines import DocumentSearchEvaluationPipeline
from .utils import save

__all__ = [
    "Evaluator",
    "DataLoader",
    "HuggingFaceDataLoader",
    "MetricSet",
    "Metric",
    "DocumentSearchEvaluationPipeline",
    "save",
]
121 changes: 121 additions & 0 deletions
packages/ragbits-evaluate/src/ragbits/evaluate/evaluator.py
@@ -0,0 +1,121 @@
import time
from dataclasses import asdict
from typing import Any, Iterable

from tqdm.asyncio import tqdm

from .loaders import DataLoader
from .metrics.base import MetricSet
from .pipelines.base import EvaluationPipeline, EvaluationResult


class Evaluator:
    """
    Evaluator: runs an evaluation pipeline over a dataset and computes metrics.
    """

    def __init__(self, task: str) -> None:
        """
        Constructs the evaluator.

        Args:
            task: The task for the evaluator.
        """
        self.task = task

    async def compute(
        self,
        pipeline: EvaluationPipeline,
        dataloader: DataLoader,
        metrics: MetricSet,
    ) -> dict[str, Any]:
        """
        Compute the evaluation results for the given pipeline and data.

        Args:
            pipeline: The pipeline to be evaluated.
            dataloader: The dataloader to load the data.
            metrics: The metrics to be computed.

        Returns:
            The evaluation results.
        """
        dataset = await dataloader.load()
        results, perf_results = await self._call_pipeline(pipeline, dataset)
        computed_metrics = self._compute_metrics(metrics, results)
        processed_results = self._results_processor(results)

        return {
            **perf_results,
            **computed_metrics,
            **processed_results,
        }

    async def _call_pipeline(
        self,
        pipeline: EvaluationPipeline,
        dataset: Iterable,
    ) -> tuple[list[EvaluationResult], dict[str, Any]]:
        """
        Call the pipeline once per sample, gathering the calls concurrently.

        Args:
            pipeline: The pipeline to be called.
            dataset: The evaluation data.

        Returns:
            The evaluation results and performance metrics.
        """
        start_time = time.perf_counter()
        pipe_outputs = await tqdm.gather(*[pipeline(data) for data in dataset], desc="Evaluation")
        end_time = time.perf_counter()
        return pipe_outputs, self._compute_time_perf(start_time, end_time, len(pipe_outputs))

    def _results_processor(self, results: list[EvaluationResult]) -> dict[str, Any]:
        """
        Convert the result dataclasses to plain dictionaries.

        Args:
            results: The evaluation results.

        Returns:
            The processed results.
        """
        return {"results": [asdict(result) for result in results]}

    def _compute_metrics(self, metrics: MetricSet, results: list[EvaluationResult]) -> dict[str, Any]:
        """
        Compute the metrics over the collected results.

        Args:
            metrics: The metrics to be computed.
            results: The evaluation results.

        Returns:
            The computed metrics.
        """
        return {"metrics": metrics.compute(results)}

    def _compute_time_perf(self, start_time: float, end_time: float, num_samples: int) -> dict[str, Any]:
        """
        Compute wall-clock performance statistics for the run.

        Args:
            start_time: The start time.
            end_time: The end time.
            num_samples: The number of samples.

        Returns:
            The performance metrics.
        """
        latency = end_time - start_time
        throughput = num_samples / latency
        latency_sample = 1.0 / throughput if throughput > 0 else 0.0

        return {
            "time_perf": {
                "total_time_in_seconds": latency,
                "samples_per_second": throughput,
                "latency_in_seconds": latency_sample,
            },
        }
44 changes: 44 additions & 0 deletions
packages/ragbits-evaluate/src/ragbits/evaluate/loaders.py
@@ -0,0 +1,44 @@
from abc import ABC, abstractmethod
from typing import Generic, TypeVar, Union

from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict, load_dataset
from omegaconf import DictConfig

DataT = TypeVar("DataT")
HFData = Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset]


class DataLoader(Generic[DataT], ABC):
    """
    Data loader.
    """

    def __init__(self, config: DictConfig) -> None:
        self.config = config

    @abstractmethod
    async def load(self) -> DataT:
        """
        Load the data.

        Returns:
            The loaded data.
        """


class HuggingFaceDataLoader(DataLoader[HFData]):
    """
    Hugging Face data loader.
    """

    async def load(self) -> HFData:
        """
        Load the data from Hugging Face.

        Returns:
            The loaded data.
        """
        return load_dataset(
            path=self.config.data.path,
            split=self.config.data.split,
        )
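Other sources can be added by subclassing DataLoader; a sketch of a local-JSONL variant that reuses the same config.data.path field (the JSONLDataLoader name is hypothetical):

import json

from ragbits.evaluate.loaders import DataLoader


class JSONLDataLoader(DataLoader[list[dict]]):
    """Hypothetical loader reading one JSON object per line from a local file."""

    async def load(self) -> list[dict]:
        with open(self.config.data.path, encoding="utf-8") as file:
            return [json.loads(line) for line in file]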
3 changes: 3 additions & 0 deletions
packages/ragbits-evaluate/src/ragbits/evaluate/metrics/__init__.py
@@ -0,0 +1,3 @@
from .base import Metric, MetricSet

__all__ = ["Metric", "MetricSet"]
76 changes: 76 additions & 0 deletions
packages/ragbits-evaluate/src/ragbits/evaluate/metrics/base.py
@@ -0,0 +1,76 @@
from abc import ABC, abstractmethod
from typing import Any, Optional

from omegaconf import DictConfig
from typing_extensions import Self

from ..pipelines import EvaluationResult


class Metric(ABC):
    """
    Base class for metrics.
    """

    def __init__(self, config: Optional[DictConfig] = None) -> None:
        """
        Initializes the metric.

        Args:
            config: The metric configuration.
        """
        super().__init__()
        self.config = config or {}

    @abstractmethod
    def compute(self, results: list[EvaluationResult]) -> dict[str, Any]:
        """
        Compute the metric.

        Args:
            results: The evaluation results.

        Returns:
            The computed metric.
        """


class MetricSet:
    """
    Represents a set of metrics.
    """

    def __init__(self, *metrics: type[Metric]) -> None:
        """
        Initializes the metric set.

        Args:
            metrics: The metrics.
        """
        self._metrics = metrics
        self.metrics: list[Metric] = []

    def __call__(self, config: Optional[DictConfig] = None) -> Self:
        """
        Initializes the metrics.

        Args:
            config: The configuration for the metrics.

        Returns:
            The initialized metric set.
        """
        self.metrics = [metric(config) for metric in self._metrics]
        return self

    def compute(self, results: list[EvaluationResult]) -> dict[str, Any]:
        """
        Compute the metrics.

        Args:
            results: The evaluation results.

        Returns:
            The computed metrics.
        """
        return {name: value for metric in self.metrics for name, value in metric.compute(results).items()}
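Note the two-step construction: metric classes are passed to MetricSet, and calling the set instantiates each class with a shared config, which is exactly how evaluate.py invokes MetricSet()(config). A sketch with two hypothetical metrics (the result field .predicted is also an assumption):

from typing import Any

from ragbits.evaluate import Metric, MetricSet


class AnswerLength(Metric):
    # Hypothetical metric: mean length of predicted answers.
    def compute(self, results: list) -> dict[str, Any]:
        return {"avg_answer_length": sum(len(r.predicted) for r in results) / len(results)}


class WithinBudget(Metric):
    # Hypothetical metric: share of answers within a config-driven length budget.
    def compute(self, results: list) -> dict[str, Any]:
        budget = self.config.get("max_length", 200)  # self.config is {} when no config is given
        return {"within_budget": sum(len(r.predicted) <= budget for r in results) / len(results)}


# Classes in, instances out: compute() merges each metric's dict into one flat mapping.
metric_set = MetricSet(AnswerLength, WithinBudget)(None)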
8 changes: 8 additions & 0 deletions
packages/ragbits-evaluate/src/ragbits/evaluate/pipelines/__init__.py
@@ -0,0 +1,8 @@
from .base import EvaluationPipeline, EvaluationResult
from .document_search import DocumentSearchEvaluationPipeline

__all__ = [
    "DocumentSearchEvaluationPipeline",
    "EvaluationPipeline",
    "EvaluationResult",
]
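The diff view ends before pipelines/base.py and pipelines/document_search.py are shown. From the evaluator's usage (each sample is passed to await pipeline(data), each result to asdict()), the base module plausibly looks roughly like the sketch below; this is inferred, not the actual file:

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Optional

from omegaconf import DictConfig


@dataclass
class EvaluationResult:
    # Inferred: must be a dataclass, since the evaluator calls asdict() on results.
    pass


class EvaluationPipeline(ABC):
    # Inferred: an async callable, invoked once per sample by the evaluator.

    def __init__(self, config: Optional[DictConfig] = None) -> None:
        self.config = config

    @abstractmethod
    async def __call__(self, data: Any) -> EvaluationResult:
        """Evaluate a single sample."""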