From 3bc899beb81cd9a50028e626db36852b130855c4 Mon Sep 17 00:00:00 2001 From: lvliang-intel Date: Mon, 6 May 2024 21:35:39 +0800 Subject: [PATCH] Add reranking microservice (#21) * Add reranking microservice Signed-off-by: lvliang-intel --- comps/__init__.py | 10 ++++++- comps/proto/docarray.py | 13 +++++++- comps/reranks/langchain/README.md | 0 comps/reranks/langchain/__init__.py | 13 -------- comps/reranks/local_reranking.py | 37 +++++++++++++++++++++++ comps/reranks/requirements.txt | 2 ++ comps/reranks/reranking_tei_gaudi.py | 44 ++++++++++++++++++++++++++++ 7 files changed, 104 insertions(+), 15 deletions(-) delete mode 100644 comps/reranks/langchain/README.md delete mode 100644 comps/reranks/langchain/__init__.py create mode 100644 comps/reranks/local_reranking.py create mode 100644 comps/reranks/requirements.txt create mode 100644 comps/reranks/reranking_tei_gaudi.py diff --git a/comps/__init__.py b/comps/__init__.py index bc2f1aa6da..ba441ad501 100644 --- a/comps/__init__.py +++ b/comps/__init__.py @@ -13,7 +13,15 @@ # limitations under the License. # Document -from comps.proto.docarray import TextDoc, EmbedDoc768, EmbedDoc1024, GeneratedDoc, LLMParamsDoc +from comps.proto.docarray import ( + TextDoc, + EmbedDoc768, + EmbedDoc1024, + GeneratedDoc, + LLMParamsDoc, + RerankingInputDoc, + RerankingOutputDoc, +) # Microservice from comps.mega.orchestrator import ServiceOrchestrator diff --git a/comps/proto/docarray.py b/comps/proto/docarray.py index 5fa3d81620..5c47c2b864 100644 --- a/comps/proto/docarray.py +++ b/comps/proto/docarray.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from docarray import BaseDoc
+from docarray import BaseDoc, DocList
 from docarray.typing import NdArray
 
 
@@ -41,3 +41,14 @@ class LLMParamsDoc(BaseDoc):
     temperature: float = 0.01
     repetition_penalty: float = 1.03
     streaming: bool = True
+
+
+class RerankingInputDoc(BaseDoc):  # input_datatype of the /v1/reranking services added by this patch
+    query: str  # text that every passage is scored against
+    passages: DocList[TextDoc]  # candidate passages; the services read each item's .text
+    top_n: int = 3  # NOTE(review): not read by either service in this patch; each returns a single doc
+
+
+class RerankingOutputDoc(BaseDoc):  # output_datatype of the /v1/reranking services added by this patch
+    query: str  # copied from the request's query
+    doc: TextDoc  # the one passage that received the highest score
diff --git a/comps/reranks/langchain/README.md b/comps/reranks/langchain/README.md
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/comps/reranks/langchain/__init__.py b/comps/reranks/langchain/__init__.py
deleted file mode 100644
index 28f108cb63..0000000000
--- a/comps/reranks/langchain/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright (c) 2024 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
diff --git a/comps/reranks/local_reranking.py b/comps/reranks/local_reranking.py
new file mode 100644
index 0000000000..509ecc5fcf
--- /dev/null
+++ b/comps/reranks/local_reranking.py
@@ -0,0 +1,37 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from sentence_transformers import CrossEncoder
+
+from comps import RerankingInputDoc, RerankingOutputDoc, opea_microservices, register_microservice
+
+
+@register_microservice(
+    name="opea_service@local_reranking",
+    expose_endpoint="/v1/reranking",
+    port=8040,
+    input_datatype=RerankingInputDoc,
+    output_datatype=RerankingOutputDoc,
+)
+def reranking(input: RerankingInputDoc) -> RerankingOutputDoc:
+    """Return the passage in ``input.passages`` that the cross-encoder scores highest for ``input.query``."""
+    scores = reranker_model.predict([(input.query, doc.text) for doc in input.passages])
+    # max() keeps the first of any tied passages, exactly as the former stable descending sort did.
+    best_passage, _ = max(zip(input.passages, scores), key=lambda pair: pair[1])
+    return RerankingOutputDoc(query=input.query, doc=best_passage)
+
+
+if __name__ == "__main__":  # binds the reranker_model global that reranking() reads; only when run as a script
+    reranker_model = CrossEncoder(model_name="BAAI/bge-reranker-large", max_length=512)
+    opea_microservices["opea_service@local_reranking"].start()
diff --git a/comps/reranks/requirements.txt b/comps/reranks/requirements.txt
new file mode 100644
index 0000000000..1fa8840aec
--- /dev/null
+++ b/comps/reranks/requirements.txt
@@ -0,0 +1,2 @@
+docarray[full]
+sentence_transformers
diff --git a/comps/reranks/reranking_tei_gaudi.py b/comps/reranks/reranking_tei_gaudi.py
new file mode 100644
index 0000000000..e5a2fb9bf4
--- /dev/null
+++ b/comps/reranks/reranking_tei_gaudi.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+
+import requests
+
+from comps import RerankingInputDoc, RerankingOutputDoc, opea_microservices, register_microservice
+
+
+@register_microservice(
+    name="opea_service@reranking_tgi_gaudi",  # NOTE(review): "tgi" here vs. "tei" in the file name; confirm
+    expose_endpoint="/v1/reranking",
+    port=8040,
+    input_datatype=RerankingInputDoc,
+    output_datatype=RerankingOutputDoc,
+)
+def reranking(input: RerankingInputDoc) -> RerankingOutputDoc:
+    """Return the passage that the TEI ``/rerank`` endpoint scores highest for ``input.query``."""
+    payload = json.dumps({"query": input.query, "texts": [doc.text for doc in input.passages]})
+    headers = {"Content-Type": "application/json"}
+    # Bound the wait: requests has no default timeout, so a stalled TEI server would hang this call.
+    response = requests.post(tei_reranking_endpoint + "/rerank", data=payload, headers=headers, timeout=60)
+    # Raise on HTTP errors rather than trying to rank an error body.
+    response.raise_for_status()
+    best = max(response.json(), key=lambda item: item["score"])
+    return RerankingOutputDoc(query=input.query, doc=input.passages[best["index"]])
+
+
+if __name__ == "__main__":
+    tei_reranking_endpoint = os.getenv("TEI_RERANKING_ENDPOINT", "http://localhost:8080")
+    opea_microservices["opea_service@reranking_tgi_gaudi"].start()