diff --git a/docs/_src/api/api/query_classifier.md b/docs/_src/api/api/query_classifier.md
index f6630c0ec3..b92d8fa7cf 100644
--- a/docs/_src/api/api/query_classifier.md
+++ b/docs/_src/api/api/query_classifier.md
@@ -96,10 +96,11 @@ queries or statement vs question queries.
 class TransformersQueryClassifier(BaseQueryClassifier)
 ```

-A node to classify an incoming query into one of two categories using a (small) BERT transformer model.
+A node to classify an incoming query into categories using a transformer model.
 Depending on the result, the query flows to a different branch in your pipeline and the further processing
-can be customized. You can define this by connecting the further pipeline to either `output_1` or `output_2`
+can be customized. You can define this by connecting the further pipeline to `output_1`, `output_2`, ..., `output_n`
 from this node.
+This node also supports zero-shot-classification.

 **Example**:

@@ -120,7 +121,7 @@ from this node.

 Models:

-   Pass your own `Transformer` binary classification model from file/huggingface or use one of the following
+   Pass your own `Transformer` classification/zero-shot-classification model from file/huggingface or use one of the following
    pretrained ones hosted on Huggingface:
    1) Keywords vs. Questions/Statements (Default)
       model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection"
@@ -143,13 +144,20 @@ from this node.

 #### TransformersQueryClassifier.\_\_init\_\_

 ```python
-def __init__(model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection", use_gpu: bool = True, batch_size: int = 16, progress_bar: bool = True)
+def __init__(model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection", model_version: Optional[str] = None, tokenizer: Optional[str] = None, use_gpu: bool = True, task: str = "text-classification", labels: List[str] = DEFAULT_LABELS, batch_size: int = 16, progress_bar: bool = True)
 ```

 **Arguments**:

-- `model_name_or_path`: Transformer based fine tuned mini bert model for query classification
+- `model_name_or_path`: Directory of a saved model or the name of a public model, for example 'shahrukhx01/bert-mini-finetune-question-detection'.
+See [Hugging Face models](https://huggingface.co/models) for a full list of available models.
+- `model_version`: The version of the model to use from the Hugging Face model hub. This can be a tag name, a branch name, or a commit hash.
+- `tokenizer`: The name of the tokenizer (usually the same as the model).
 - `use_gpu`: Whether to use GPU (if available).
-- `batch_size`: Batch size for inference.
+- `task`: Specifies the type of classification. Possible values: 'text-classification' or 'zero-shot-classification'.
+- `labels`: If the task is 'text-classification' and an ordered list of labels is provided, the first label corresponds to output_1,
+the second label to output_2, and so on. The labels must match the model labels; only the order can differ.
+If the task is 'zero-shot-classification', these are the candidate labels.
+- `batch_size`: The number of queries to be processed at a time.
 - `progress_bar`: Whether to show a progress bar.
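To make the new n-way routing concrete, here is a minimal usage sketch (not part of the diff): the zero-shot model and the three labels are illustrative, and the commented-out downstream node is hypothetical.

```python
from haystack import Pipeline
from haystack.nodes import TransformersQueryClassifier

# Illustrative labels: in the order given, they map to output_1, output_2, output_3.
query_classifier = TransformersQueryClassifier(
    model_name_or_path="typeform/distilbert-base-uncased-mnli",  # example zero-shot model
    task="zero-shot-classification",
    labels=["happy", "unhappy", "neutral"],
)

pipeline = Pipeline()
pipeline.add_node(component=query_classifier, name="QueryClassifier", inputs=["Query"])
# Hypothetical downstream nodes attach to one branch each, e.g.:
# pipeline.add_node(component=some_node, name="HappyBranch", inputs=["QueryClassifier.output_1"])
```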
diff --git a/haystack/json-schemas/haystack-pipeline-master.schema.json b/haystack/json-schemas/haystack-pipeline-master.schema.json
index c02b798b0d..1656dd835a 100644
--- a/haystack/json-schemas/haystack-pipeline-master.schema.json
+++ b/haystack/json-schemas/haystack-pipeline-master.schema.json
@@ -4801,11 +4801,35 @@
             }
           ]
         },
+        "model_version": {
+          "title": "Model Version",
+          "type": "string"
+        },
+        "tokenizer": {
+          "title": "Tokenizer",
+          "type": "string"
+        },
         "use_gpu": {
           "title": "Use Gpu",
           "default": true,
           "type": "boolean"
        },
+        "task": {
+          "title": "Task",
+          "default": "text-classification",
+          "type": "string"
+        },
+        "labels": {
+          "title": "Labels",
+          "default": [
+            "LABEL_1",
+            "LABEL_0"
+          ],
+          "type": "array",
+          "items": {
+            "type": "string"
+          }
+        },
         "batch_size": {
           "title": "Batch Size",
           "default": 16,
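The schema entries above mean `task` and `labels` can also be set from a pipeline YAML. A minimal sketch, assuming a hypothetical `query_pipeline.yml`; the layout follows Haystack's standard component/pipeline config format.

```python
from pathlib import Path
from haystack import Pipeline

# Hypothetical query_pipeline.yml:
#
# version: ignore
# components:
#   - name: QueryClassifier
#     type: TransformersQueryClassifier
#     params:
#       model_name_or_path: typeform/distilbert-base-uncased-mnli
#       task: zero-shot-classification
#       labels: [happy, unhappy, neutral]
# pipelines:
#   - name: query
#     nodes:
#       - name: QueryClassifier
#         inputs: [Query]

pipeline = Pipeline.load_from_yaml(Path("query_pipeline.yml"), pipeline_name="query")
```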
diff --git a/haystack/nodes/query_classifier/transformers.py b/haystack/nodes/query_classifier/transformers.py
index ad397b4f82..4b92c840a6 100644
--- a/haystack/nodes/query_classifier/transformers.py
+++ b/haystack/nodes/query_classifier/transformers.py
@@ -1,22 +1,25 @@
 import logging
 from pathlib import Path
-from typing import Union, List, Optional, Dict
+from typing import Union, List, Optional, Dict, Any
+from transformers import pipeline
 from tqdm.auto import tqdm
-from transformers import AutoTokenizer, AutoModelForSequenceClassification, TextClassificationPipeline

 from haystack.nodes.query_classifier.base import BaseQueryClassifier
 from haystack.modeling.utils import initialize_device_settings
 from haystack.utils.torch_utils import ListDataset

 logger = logging.getLogger(__name__)

+DEFAULT_LABELS = ["LABEL_1", "LABEL_0"]
+

 class TransformersQueryClassifier(BaseQueryClassifier):
     """
-    A node to classify an incoming query into one of two categories using a (small) BERT transformer model.
+    A node to classify an incoming query into categories using a transformer model.
     Depending on the result, the query flows to a different branch in your pipeline and the further processing
-    can be customized. You can define this by connecting the further pipeline to either `output_1` or `output_2`
+    can be customized. You can define this by connecting the further pipeline to `output_1`, `output_2`, ..., `output_n`
     from this node.
+    This node also supports zero-shot-classification.

     Example:
     ```python
@@ -36,7 +39,7 @@ class TransformersQueryClassifier(BaseQueryClassifier):

     Models:

-    Pass your own `Transformer` binary classification model from file/huggingface or use one of the following
+    Pass your own `Transformer` classification/zero-shot-classification model from file/huggingface or use one of the following
     pretrained ones hosted on Huggingface:
     1) Keywords vs. Questions/Statements (Default)
        model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection"
@@ -58,56 +61,98 @@ class TransformersQueryClassifier(BaseQueryClassifier):
     def __init__(
         self,
         model_name_or_path: Union[Path, str] = "shahrukhx01/bert-mini-finetune-question-detection",
+        model_version: Optional[str] = None,
+        tokenizer: Optional[str] = None,
         use_gpu: bool = True,
+        task: str = "text-classification",
+        labels: List[str] = DEFAULT_LABELS,
         batch_size: int = 16,
         progress_bar: bool = True,
     ):
         """
-        :param model_name_or_path: Transformer based fine tuned mini bert model for query classification
+        :param model_name_or_path: Directory of a saved model or the name of a public model, for example 'shahrukhx01/bert-mini-finetune-question-detection'.
+        See [Hugging Face models](https://huggingface.co/models) for a full list of available models.
+        :param model_version: The version of the model to use from the Hugging Face model hub. This can be a tag name, a branch name, or a commit hash.
+        :param tokenizer: The name of the tokenizer (usually the same as the model).
         :param use_gpu: Whether to use GPU (if available).
-        :param batch_size: Batch size for inference.
+        :param task: Specifies the type of classification. Possible values: 'text-classification' or 'zero-shot-classification'.
+        :param labels: If the task is 'text-classification' and an ordered list of labels is provided, the first label corresponds to output_1,
+        the second label to output_2, and so on. The labels must match the model labels; only the order can differ.
+        If the task is 'zero-shot-classification', these are the candidate labels.
+        :param batch_size: The number of queries to be processed at a time.
         :param progress_bar: Whether to show a progress bar.
         """
         super().__init__()
-
-        self.devices, _ = initialize_device_settings(use_cuda=use_gpu)
+        if task not in ["text-classification", "zero-shot-classification"]:
+            raise ValueError(
+                f"Task not supported: {task}.\n"
+                f"Possible task values are: 'text-classification' or 'zero-shot-classification'"
+            )
+        devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)
+        device = 0 if devices[0].type == "cuda" else -1
+
+        self.model = pipeline(
+            task=task, model=model_name_or_path, tokenizer=tokenizer, device=device, revision=model_version
+        )
+
+        self.labels = labels
+        if task == "text-classification":
+            labels_from_model = list(self.model.model.config.id2label.values())
+            if set(labels) != set(labels_from_model):
+                raise ValueError(
+                    f"For text-classification, the provided labels must match the model labels; only the order can differ.\n"
+                    f"Provided labels: {labels}\n"
+                    f"Model labels: {labels_from_model}"
+                )
+        self.task = task
         self.batch_size = batch_size
-        device = 0 if self.devices[0].type == "cuda" else -1
         self.progress_bar = progress_bar
-        model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path)
-        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
-
-        self.query_classification_pipeline = TextClassificationPipeline(model=model, tokenizer=tokenizer, device=device)
-
-    def run(self, query):
-        is_question: bool = self.query_classification_pipeline(query)[0]["label"] == "LABEL_1"
-
-        if is_question:
-            return {}, "output_1"
-        else:
-            return {}, "output_2"
+
+    @classmethod
+    def _calculate_outgoing_edges(cls, component_params: Dict[str, Any]) -> int:
+        labels = component_params.get("labels", DEFAULT_LABELS)
+        if labels is None or len(labels) == 0:
+            raise ValueError("The labels must be provided")
+        return len(labels)
+
+    def _get_edge_number_from_label(self, label: str) -> int:
+        return self.labels.index(label) + 1
+
+    def run(self, query: str):  # type: ignore
+        if self.task == "zero-shot-classification":
+            prediction = self.model([query], candidate_labels=self.labels, truncation=True)
+            label = prediction[0]["labels"][0]
+        elif self.task == "text-classification":
+            prediction = self.model([query], truncation=True)
+            label = prediction[0]["label"]
+        return {}, f"output_{self._get_edge_number_from_label(label)}"

     def run_batch(self, queries: List[str], batch_size: Optional[int] = None):  # type: ignore
-        if batch_size is None:
-            batch_size = self.batch_size
-
-        split: Dict[str, Dict[str, List]] = {"output_1": {"queries": []}, "output_2": {"queries": []}}
-        # HF pb hack https://discuss.huggingface.co/t/progress-bar-for-hf-pipelines/20498/2
         queries_dataset = ListDataset(queries)
+        if batch_size is None:
+            batch_size = self.batch_size
         all_predictions = []
-        for predictions in tqdm(
-            self.query_classification_pipeline(queries_dataset, batch_size=batch_size),
-            disable=not self.progress_bar,
-            desc="Classifying queries",
-        ):
-            all_predictions.extend(predictions)
-
-        for query, pred in zip(queries, all_predictions):
-            if pred["label"] == "LABEL_1":
-                split["output_1"]["queries"].append(query)
-            else:
-                split["output_2"]["queries"].append(query)
-
-        return split, "split"
+        if self.task == "zero-shot-classification":
+            for predictions in tqdm(
+                self.model(queries_dataset, candidate_labels=self.labels, truncation=True, batch_size=batch_size),
+                disable=not self.progress_bar,
+                desc="Classifying queries",
+            ):
+                all_predictions.append(predictions)
+        elif self.task == "text-classification":
+            for predictions in tqdm(
+                self.model(queries_dataset, truncation=True, batch_size=batch_size),
+                disable=not self.progress_bar,
+                desc="Classifying queries",
+            ):
+                all_predictions.append(predictions)
+        results = {f"output_{self._get_edge_number_from_label(label)}": {"queries": []} for label in self.labels}  # type: ignore
+        for query, prediction in zip(queries, all_predictions):
+            if self.task == "zero-shot-classification":
+                label = prediction["labels"][0]
+            elif self.task == "text-classification":
+                label = prediction["label"]
+            results[f"output_{self._get_edge_number_from_label(label)}"]["queries"].append(query)
+
+        return results, "split"
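For context on the two unpacking branches in `run()` and `run_batch()` above: the underlying `transformers` pipelines return differently shaped results per query. A small sketch using the same models as the tests below; the printed values are indicative only.

```python
from transformers import pipeline

# zero-shot-classification returns candidate labels ranked by score:
zero_shot = pipeline("zero-shot-classification", model="typeform/distilbert-base-uncased-mnli")
print(zero_shot(["Would you be so kind to tell me the answer?"], candidate_labels=["happy", "unhappy", "neutral"])[0])
# e.g. {'sequence': '...', 'labels': ['happy', 'neutral', 'unhappy'], 'scores': [...]}
# -> top label is prediction["labels"][0]

# text-classification returns a single label with its score:
text_clf = pipeline("text-classification", model="shahrukhx01/bert-mini-finetune-question-detection")
print(text_clf(["How old is John?"])[0])
# e.g. {'label': 'LABEL_1', 'score': 0.99}
# -> label is prediction["label"]
```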
+ queries = ["morse code", "How old is John?"] + output = transformers_query_classifier.run_batch(queries=queries) + + assert output[0] == {"output_2": {"queries": ["morse code"]}, "output_1": {"queries": ["How old is John?"]}} + + +def test_zero_shot_transformers_query_classifier(zero_shot_transformers_query_classifier): + output = zero_shot_transformers_query_classifier.run(query="What's the answer?") + assert output == ({}, "output_3") + + output = zero_shot_transformers_query_classifier.run(query="Would you be so kind to tell me the answer?") + assert output == ({}, "output_1") + + output = zero_shot_transformers_query_classifier.run(query="Can you give me the right answer for once??") + assert output == ({}, "output_2") + + +def test_zero_shot_transformers_query_classifier_batch(zero_shot_transformers_query_classifier): + queries = [ + "What's the answer?", + "Would you be so kind to tell me the answer?", + "Can you give me the right answer for once??", + ] + + output = zero_shot_transformers_query_classifier.run_batch(queries=queries) + + assert output[0] == { + "output_3": {"queries": ["What's the answer?"]}, + "output_1": {"queries": ["Would you be so kind to tell me the answer?"]}, + "output_2": {"queries": ["Can you give me the right answer for once??"]}, + } + + +def test_transformers_query_classifier_wrong_labels(): + with pytest.raises(ValueError, match="For text-classification, the provided labels must match the model labels"): + query_classifier = TransformersQueryClassifier( + model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection", + use_gpu=False, + task="text-classification", + labels=["WRONG_LABEL_1", "WRONG_LABEL_2", "WRONG_LABEL_3"], + ) + + +def test_transformers_query_classifier_no_labels(): + with pytest.raises(ValueError, match="The labels must be provided"): + query_classifier = TransformersQueryClassifier( + model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection", + use_gpu=False, + task="text-classification", + labels=None, + ) + + +def test_transformers_query_classifier_unsupported_task(): + with pytest.raises(ValueError, match="Task not supported"): + query_classifier = TransformersQueryClassifier( + model_name_or_path="shahrukhx01/bert-mini-finetune-question-detection", + use_gpu=False, + task="summarization", + labels=["LABEL_1", "LABEL_0"], + )