topic model path [workflow] #276

Open · wants to merge 1 commit into base: workflow
51 changes: 34 additions & 17 deletions langkit/metrics/topic.py
@@ -1,13 +1,12 @@
from dataclasses import dataclass
from functools import partial
from functools import lru_cache, partial
from typing import Any, Dict, List, Optional

import pandas as pd
import torch
from transformers import Pipeline, pipeline # type: ignore

from langkit.core.metric import MetricCreator, MultiMetric, MultiMetricResult
from langkit.metrics.util import LazyInit

__default_topics = [
"medicine",
@@ -19,21 +18,35 @@
_hypothesis_template = "This example is about {}"


__classifier: LazyInit[Pipeline] = LazyInit(
lambda: pipeline(
@lru_cache(maxsize=None)
def _get_classifier(model_version: str) -> Pipeline:
return pipeline(
"zero-shot-classification",
model="MoritzLaurer/xtremedistil-l6-h256-zeroshot-v1.1-all-33",
model=model_version,
device="cuda" if torch.cuda.is_available() else "cpu",
)
)


MODEL_SMALL = "MoritzLaurer/xtremedistil-l6-h256-zeroshot-v1.1-all-33"
MODEL_BASE = "MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33"
MODEL_LARGE = "MoritzLaurer/deberta-v3-large-zeroshot-v1.1-all-33"


def __get_scores_per_label(
text: str, topics: List[str], hypothesis_template: str = _hypothesis_template, multi_label: bool = True
text: str,
topics: List[str],
hypothesis_template: str = _hypothesis_template,
multi_label: bool = True,
model_version: str = MODEL_SMALL,
) -> Optional[Dict[str, float]]:
if not text:
return None
result: Dict[str, [str, float]] = __classifier.value(text, topics, hypothesis_template=hypothesis_template, multi_label=multi_label) # type: ignore
result: Dict[str, [str, float]] = _get_classifier(model_version)( # type: ignore
text,
topics,
hypothesis_template=hypothesis_template,
multi_label=multi_label, # type: ignore
)
scores_per_label: Dict[str, float] = {label: score for label, score in zip(result["labels"], result["scores"])} # type: ignore[reportUnknownVariableType]
return scores_per_label

@@ -45,15 +58,17 @@ def _sanitize_metric_name(topic: str) -> str:
return topic.replace(" ", "_").lower()


def topic_metric(input_name: str, topics: List[str], hypothesis_template: Optional[str] = None) -> MultiMetric:
def topic_metric(
input_name: str, topics: List[str], hypothesis_template: Optional[str] = None, model_version: str = MODEL_SMALL
) -> MultiMetric:
hypothesis_template = hypothesis_template or _hypothesis_template

def udf(text: pd.DataFrame) -> MultiMetricResult:
metrics: Dict[str, List[Optional[float]]] = {topic: [] for topic in topics}

def process_row(row: pd.DataFrame) -> Dict[str, List[Optional[float]]]:
value: Any = row[input_name] # type: ignore
scores = __get_scores_per_label(value, topics=topics, hypothesis_template=hypothesis_template) # pyright: ignore[reportUnknownArgumentType]
scores = __get_scores_per_label(value, topics=topics, hypothesis_template=hypothesis_template, model_version=model_version) # pyright: ignore[reportUnknownArgumentType]
for topic in topics:
metrics[topic].append(scores[topic] if scores else None)
return metrics
@@ -67,15 +82,15 @@ def process_row(row: pd.DataFrame) -> Dict[str, List[Optional[float]]]:
return MultiMetricResult(metrics=all_metrics)

def cache_assets():
__classifier.value
_get_classifier(model_version)

metric_names = [f"{input_name}.topics.{_sanitize_metric_name(topic)}" for topic in topics]
return MultiMetric(names=metric_names, input_name=input_name, evaluate=udf, cache_assets=cache_assets)


prompt_topic_module = partial(topic_metric, "prompt", __default_topics, _hypothesis_template)
response_topic_module = partial(topic_metric, "response", __default_topics, _hypothesis_template)
prompt_response_topic_module = [prompt_topic_module, response_topic_module, _hypothesis_template]
prompt_topic_module = partial(topic_metric, "prompt", __default_topics, _hypothesis_template, MODEL_SMALL)
response_topic_module = partial(topic_metric, "response", __default_topics, _hypothesis_template, MODEL_SMALL)
prompt_response_topic_module = [prompt_topic_module, response_topic_module, _hypothesis_template, MODEL_SMALL]


@dataclass
@@ -85,9 +100,11 @@ class CustomTopicModules:
prompt_response_topic_module: MetricCreator


def get_custom_topic_modules(topics: List[str], template: str = _hypothesis_template) -> CustomTopicModules:
prompt_topic_module = partial(topic_metric, "prompt", topics, template)
response_topic_module = partial(topic_metric, "response", topics, template)
def get_custom_topic_modules(
topics: List[str], template: str = _hypothesis_template, model_version: str = MODEL_SMALL
) -> CustomTopicModules:
prompt_topic_module = partial(topic_metric, "prompt", topics, template, model_version)
response_topic_module = partial(topic_metric, "response", topics, template, model_version)
return CustomTopicModules(
prompt_topic_module=prompt_topic_module,
response_topic_module=response_topic_module,
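Note on the refactor above: replacing the module-level LazyInit singleton with an lru_cache-decorated factory means each distinct model_version gets its own pipeline, built once per process and then reused by every metric that asks for it. A minimal sketch of that caching behavior follows, using a stand-in builder instead of the real transformers.pipeline so nothing is downloaded; the names and print statements are illustrative only.

from functools import lru_cache

@lru_cache(maxsize=None)
def get_classifier(model_version: str) -> str:
    # Stand-in for the pipeline(...) call in topic.py; returns a placeholder
    # object so the memoization can be shown without loading a model.
    print(f"building pipeline for {model_version}")
    return f"pipeline<{model_version}>"

get_classifier("MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33")       # builds once
get_classifier("MoritzLaurer/deberta-v3-base-zeroshot-v1.1-all-33")       # cache hit, same object reused
get_classifier("MoritzLaurer/xtremedistil-l6-h256-zeroshot-v1.1-all-33")  # different key, builds a second pipeline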
37 changes: 36 additions & 1 deletion tests/langkit/metrics/test_topic.py
@@ -1,4 +1,5 @@
# pyright: reportUnknownMemberType=none
from functools import partial
from typing import Any

import pandas as pd
@@ -7,7 +8,7 @@
from langkit.core.metric import WorkflowMetricConfig, WorkflowMetricConfigBuilder
from langkit.core.workflow import Workflow
from langkit.metrics.library import lib
from langkit.metrics.topic import get_custom_topic_modules, prompt_topic_module
from langkit.metrics.topic import MODEL_BASE, get_custom_topic_modules, prompt_topic_module, topic_metric
from langkit.metrics.whylogs_compat import create_whylogs_udf_schema

expected_metrics = [
@@ -81,6 +82,21 @@ def test_topic():
assert actual.index.tolist() == expected_columns


def test_topic_base_model():
df = pd.DataFrame(
{
"prompt": [
"http://get-free-money-now.xyz/bank/details",
],
}
)

custom_topic_module = partial(topic_metric, "prompt", ["phishing"], model_version=MODEL_BASE)
schema = WorkflowMetricConfigBuilder().add(custom_topic_module).build()
actual = _log(df, schema)
assert actual.loc["prompt.topics.phishing"]["distribution/mean"] > 0.80


def test_topic_empty_input():
df = pd.DataFrame(
{
@@ -243,6 +259,25 @@ def test_custom_topic():
assert actual.loc[column]["distribution/max"] >= 0.50


def test_custom_topics_base_model():
df = pd.DataFrame(
{
"prompt": [
"http://get-free-money-now.xyz/bank/details",
],
"response": [
"http://win-a-free-iphone-today.net",
],
}
)

custom_topic_modules = get_custom_topic_modules(["phishing"], model_version=MODEL_BASE)
schema = WorkflowMetricConfigBuilder().add(custom_topic_modules.prompt_response_topic_module).build()
actual = _log(df, schema)
assert actual.loc["prompt.topics.phishing"]["distribution/mean"] > 0.80
assert actual.loc["response.topics.phishing"]["distribution/mean"] > 0.80


def test_topic_name_sanitize():
df = pd.DataFrame(
{
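For readers skimming the diff, a short sketch of how the new model_version argument threads through the two public entry points; it uses only names that appear in this change, and the "phishing" topic and builder pattern mirror the tests above rather than prescribing an API.

from functools import partial

from langkit.core.metric import WorkflowMetricConfigBuilder
from langkit.metrics.topic import MODEL_BASE, get_custom_topic_modules, topic_metric

# Prompt/response topic metrics backed by the base-sized zero-shot checkpoint.
modules = get_custom_topic_modules(["phishing"], model_version=MODEL_BASE)
schema = WorkflowMetricConfigBuilder().add(modules.prompt_response_topic_module).build()

# The same selection for a single input column, as in test_topic_base_model.
prompt_phishing = partial(topic_metric, "prompt", ["phishing"], model_version=MODEL_BASE)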