diff --git a/langkit/metrics/topic.py b/langkit/metrics/topic.py
index 8b79e56..ce27a33 100644
--- a/langkit/metrics/topic.py
+++ b/langkit/metrics/topic.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from functools import partial
+from functools import lru_cache, partial
 from typing import Any, Dict, List, Optional
 
 import pandas as pd
@@ -7,7 +7,6 @@
 from transformers import Pipeline, pipeline  # type: ignore
 
 from langkit.core.metric import MetricCreator, MultiMetric, MultiMetricResult
-from langkit.metrics.util import LazyInit
 
 __default_topics = [
     "medicine",
@@ -19,21 +18,35 @@
 _hypothesis_template = "This example is about {}"
 
 
-__classifier: LazyInit[Pipeline] = LazyInit(
-    lambda: pipeline(
+@lru_cache(maxsize=None)
+def _get_classifier(model_version: str) -> Pipeline:
+    return pipeline(
         "zero-shot-classification",
-        model="MoritzLaurer/xtremedistil-l6-h256-zeroshot-v1.1-all-33",
+        model=model_version,
         device="cuda" if torch.cuda.is_available() else "cpu",
     )
-)
+
+
+MODEL_SMALL = "MoritzLaurer/xtremedistil-l6-h256-zeroshot-v1.1-all-33"
+MODEL_BASE = "MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33"
+MODEL_LARGE = "MoritzLaurer/deberta-v3-large-zeroshot-v1.1-all-33"
 
 
 def __get_scores_per_label(
-    text: str, topics: List[str], hypothesis_template: str = _hypothesis_template, multi_label: bool = True
+    text: str,
+    topics: List[str],
+    hypothesis_template: str = _hypothesis_template,
+    multi_label: bool = True,
+    model_version: str = MODEL_SMALL,
 ) -> Optional[Dict[str, float]]:
     if not text:
         return None
-    result: Dict[str, [str, float]] = __classifier.value(text, topics, hypothesis_template=hypothesis_template, multi_label=multi_label)  # type: ignore
+    result: Dict[str, [str, float]] = _get_classifier(model_version)(  # type: ignore
+        text,
+        topics,
+        hypothesis_template=hypothesis_template,
+        multi_label=multi_label,  # type: ignore
+    )
     scores_per_label: Dict[str, float] = {label: score for label, score in zip(result["labels"], result["scores"])}  # type: ignore[reportUnknownVariableType]
     return scores_per_label
 
@@ -45,7 +58,9 @@ def _sanitize_metric_name(topic: str) -> str:
     return topic.replace(" ", "_").lower()
 
 
-def topic_metric(input_name: str, topics: List[str], hypothesis_template: Optional[str] = None) -> MultiMetric:
+def topic_metric(
+    input_name: str, topics: List[str], hypothesis_template: Optional[str] = None, model_version: str = MODEL_SMALL
+) -> MultiMetric:
     hypothesis_template = hypothesis_template or _hypothesis_template
 
     def udf(text: pd.DataFrame) -> MultiMetricResult:
@@ -53,7 +68,7 @@ def udf(text: pd.DataFrame) -> MultiMetricResult:
 
         def process_row(row: pd.DataFrame) -> Dict[str, List[Optional[float]]]:
             value: Any = row[input_name]  # type: ignore
-            scores = __get_scores_per_label(value, topics=topics, hypothesis_template=hypothesis_template)  # pyright: ignore[reportUnknownArgumentType]
+            scores = __get_scores_per_label(value, topics=topics, hypothesis_template=hypothesis_template, model_version=model_version)  # pyright: ignore[reportUnknownArgumentType]
             for topic in topics:
                 metrics[topic].append(scores[topic] if scores else None)
             return metrics
@@ -67,15 +82,15 @@ def process_row(row: pd.DataFrame) -> Dict[str, List[Optional[float]]]:
         return MultiMetricResult(metrics=all_metrics)
 
     def cache_assets():
-        __classifier.value
+        _get_classifier(model_version)
 
     metric_names = [f"{input_name}.topics.{_sanitize_metric_name(topic)}" for topic in topics]
     return MultiMetric(names=metric_names, input_name=input_name, evaluate=udf, cache_assets=cache_assets)
 
 
-prompt_topic_module = partial(topic_metric, "prompt", __default_topics, _hypothesis_template)
-response_topic_module = partial(topic_metric, "response", __default_topics, _hypothesis_template)
-prompt_response_topic_module = [prompt_topic_module, response_topic_module, _hypothesis_template]
+prompt_topic_module = partial(topic_metric, "prompt", __default_topics, _hypothesis_template, MODEL_SMALL)
+response_topic_module = partial(topic_metric, "response", __default_topics, _hypothesis_template, MODEL_SMALL)
+prompt_response_topic_module = [prompt_topic_module, response_topic_module, _hypothesis_template, MODEL_SMALL]
 
 
 @dataclass
@@ -85,9 +100,11 @@ class CustomTopicModules:
     prompt_response_topic_module: MetricCreator
 
 
-def get_custom_topic_modules(topics: List[str], template: str = _hypothesis_template) -> CustomTopicModules:
-    prompt_topic_module = partial(topic_metric, "prompt", topics, template)
-    response_topic_module = partial(topic_metric, "response", topics, template)
+def get_custom_topic_modules(
+    topics: List[str], template: str = _hypothesis_template, model_version: str = MODEL_SMALL
+) -> CustomTopicModules:
+    prompt_topic_module = partial(topic_metric, "prompt", topics, template, model_version)
+    response_topic_module = partial(topic_metric, "response", topics, template, model_version)
     return CustomTopicModules(
         prompt_topic_module=prompt_topic_module,
         response_topic_module=response_topic_module,
diff --git a/tests/langkit/metrics/test_topic.py b/tests/langkit/metrics/test_topic.py
index 1dfe4b3..54767be 100644
--- a/tests/langkit/metrics/test_topic.py
+++ b/tests/langkit/metrics/test_topic.py
@@ -1,4 +1,5 @@
 # pyright: reportUnknownMemberType=none
+from functools import partial
 from typing import Any
 
 import pandas as pd
@@ -7,7 +8,7 @@
 from langkit.core.metric import WorkflowMetricConfig, WorkflowMetricConfigBuilder
 from langkit.core.workflow import Workflow
 from langkit.metrics.library import lib
-from langkit.metrics.topic import get_custom_topic_modules, prompt_topic_module
+from langkit.metrics.topic import MODEL_BASE, get_custom_topic_modules, prompt_topic_module, topic_metric
 from langkit.metrics.whylogs_compat import create_whylogs_udf_schema
 
 expected_metrics = [
@@ -81,6 +82,21 @@ def test_topic():
     assert actual.index.tolist() == expected_columns
 
 
+def test_topic_base_model():
+    df = pd.DataFrame(
+        {
+            "prompt": [
+                "http://get-free-money-now.xyz/bank/details",
+            ],
+        }
+    )
+
+    custom_topic_module = partial(topic_metric, "prompt", ["phishing"], model_version=MODEL_BASE)
+    schema = WorkflowMetricConfigBuilder().add(custom_topic_module).build()
+    actual = _log(df, schema)
+    assert actual.loc["prompt.topics.phishing"]["distribution/mean"] > 0.80
+
+
 def test_topic_empty_input():
     df = pd.DataFrame(
         {
@@ -243,6 +259,25 @@ def test_custom_topic():
         assert actual.loc[column]["distribution/max"] >= 0.50
 
 
+def test_custom_topics_base_model():
+    df = pd.DataFrame(
+        {
+            "prompt": [
+                "http://get-free-money-now.xyz/bank/details",
+            ],
+            "response": [
+                "http://win-a-free-iphone-today.net",
+            ],
+        }
+    )
+
+    custom_topic_modules = get_custom_topic_modules(["phishing"], model_version=MODEL_BASE)
+    schema = WorkflowMetricConfigBuilder().add(custom_topic_modules.prompt_response_topic_module).build()
+    actual = _log(df, schema)
+    assert actual.loc["prompt.topics.phishing"]["distribution/mean"] > 0.80
+    assert actual.loc["response.topics.phishing"]["distribution/mean"] > 0.80
+
+
 def test_topic_name_sanitize():
     df = pd.DataFrame(
         {
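
Usage note (not part of the diff): a minimal sketch of how a caller might opt into the larger zero-shot model through the new model_version parameter, mirroring the pattern the added tests use. The "phishing" topic and the builder calls below are illustrative; only the names shown in the diff (topic_metric, get_custom_topic_modules, MODEL_BASE, WorkflowMetricConfigBuilder) are assumed.

    from functools import partial

    from langkit.core.metric import WorkflowMetricConfigBuilder
    from langkit.metrics.topic import MODEL_BASE, get_custom_topic_modules, topic_metric

    # Single-column metric with the base-size model (defaults to MODEL_SMALL when omitted).
    prompt_phishing = partial(topic_metric, "prompt", ["phishing"], model_version=MODEL_BASE)
    schema = WorkflowMetricConfigBuilder().add(prompt_phishing).build()

    # Or build prompt and response modules in one call.
    modules = get_custom_topic_modules(["phishing"], model_version=MODEL_BASE)
    schema = WorkflowMetricConfigBuilder().add(modules.prompt_response_topic_module).build()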