Skip to content

Commit

Permalink
Add reranking microservice (opea-project#21)
Browse files Browse the repository at this point in the history
* Add reranking microservice

Signed-off-by: lvliang-intel <[email protected]>
  • Loading branch information
lvliang-intel authored May 6, 2024
1 parent 1f6c1a5 commit 3bc899b
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 15 deletions.
10 changes: 9 additions & 1 deletion comps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,15 @@
# limitations under the License.

# Document
from comps.proto.docarray import TextDoc, EmbedDoc768, EmbedDoc1024, GeneratedDoc, LLMParamsDoc
from comps.proto.docarray import (
TextDoc,
EmbedDoc768,
EmbedDoc1024,
GeneratedDoc,
LLMParamsDoc,
RerankingInputDoc,
RerankingOutputDoc,
)

# Microservice
from comps.mega.orchestrator import ServiceOrchestrator
Expand Down
13 changes: 12 additions & 1 deletion comps/proto/docarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from docarray import BaseDoc
from docarray import BaseDoc, DocList
from docarray.typing import NdArray


Expand Down Expand Up @@ -41,3 +41,14 @@ class LLMParamsDoc(BaseDoc):
temperature: float = 0.01
repetition_penalty: float = 1.03
streaming: bool = True


class RerankingInputDoc(BaseDoc):
    """Request payload for a reranking service: one query plus its candidate passages."""

    # Query text that every candidate passage is scored against.
    query: str
    # Candidate passages; the reranking services in this change read each passage's `text`.
    passages: DocList[TextDoc]
    # Defaults to 3.
    # NOTE(review): neither reranking service added in this change reads this field;
    # both return only the single best passage — confirm the intended semantics.
    top_n: int = 3


class RerankingOutputDoc(BaseDoc):
    """Response payload of a reranking service: the query and one selected passage."""

    # The query, copied through from the request by the reranking services.
    query: str
    # The passage the service selected (the highest-scoring one in the visible services).
    doc: TextDoc
Empty file removed comps/reranks/langchain/README.md
Empty file.
13 changes: 0 additions & 13 deletions comps/reranks/langchain/__init__.py

This file was deleted.

37 changes: 37 additions & 0 deletions comps/reranks/local_reranking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from sentence_transformers import CrossEncoder

from comps import RerankingInputDoc, RerankingOutputDoc, opea_microservices, register_microservice


@register_microservice(
    name="opea_service@local_reranking",
    expose_endpoint="/v1/reranking",
    port=8040,
    input_datatype=RerankingInputDoc,
    output_datatype=RerankingOutputDoc,
)
def reranking(input: RerankingInputDoc) -> RerankingOutputDoc:
    """Return the passage that the local cross-encoder scores highest for the query.

    Every passage in ``input.passages`` is paired with ``input.query`` and scored
    by the module-level ``reranker_model``. Only the single best passage is
    returned; ``input.top_n`` is not consulted.

    Args:
        input: The query and the candidate passages to rank.

    Returns:
        A ``RerankingOutputDoc`` carrying the original query and the
        highest-scoring passage. When several passages tie for the best score,
        the one that appears first in ``input.passages`` is returned.

    Raises:
        ValueError: If ``input.passages`` is empty, since there is nothing to rank.
    """
    # Fail early with a clear message instead of an opaque IndexError later on,
    # and avoid invoking the model on an empty batch.
    if not input.passages:
        raise ValueError("RerankingInputDoc.passages must contain at least one passage")

    # NOTE: `reranker_model` is a module global that is only assigned in the
    # `if __name__ == "__main__":` block of this file.
    query_and_docs = [(input.query, doc.text) for doc in input.passages]
    scores = reranker_model.predict(query_and_docs)

    # Only the top item is needed, so a single max() pass replaces a full sort.
    # max() returns the first maximal element, matching the stable descending sort.
    best_passage, _ = max(zip(input.passages, scores), key=lambda pair: pair[1])
    return RerankingOutputDoc(query=input.query, doc=best_passage)


if __name__ == "__main__":
    # Bind the cross-encoder as a module global before the service starts:
    # `reranking` reads `reranker_model` by name at call time.
    reranker_model = CrossEncoder(model_name="BAAI/bge-reranker-large", max_length=512)
    # Start the microservice registered above under this name.
    opea_microservices["opea_service@local_reranking"].start()
2 changes: 2 additions & 0 deletions comps/reranks/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
docarray[full]
sentence_transformers
44 changes: 44 additions & 0 deletions comps/reranks/reranking_tei_gaudi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) 2024 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import requests

from comps import RerankingInputDoc, RerankingOutputDoc, opea_microservices, register_microservice


@register_microservice(
    name="opea_service@reranking_tgi_gaudi",
    expose_endpoint="/v1/reranking",
    port=8040,
    input_datatype=RerankingInputDoc,
    output_datatype=RerankingOutputDoc,
)
def reranking(input: RerankingInputDoc) -> RerankingOutputDoc:
    """Return the passage that the remote rerank endpoint scores highest for the query.

    The passage texts are posted, together with the query, to
    ``<tei_reranking_endpoint>/rerank``. The response is read as a list of
    entries with ``"index"`` and ``"score"`` keys, and the passage at the index
    of the highest score is returned. ``input.top_n`` is not consulted.

    Args:
        input: The query and the candidate passages to rank.

    Returns:
        A ``RerankingOutputDoc`` carrying the original query and the
        highest-scoring passage.

    Raises:
        ValueError: If ``input.passages`` is empty, since there is nothing to rank.
        requests.HTTPError: If the rerank endpoint answers with an error status.
    """
    # NOTE(review): the registered name says "tgi" while the endpoint variable is
    # TEI; it is left as-is because the name is the registry key used at startup.

    # Fail early instead of sending a request that cannot produce a ranking.
    if not input.passages:
        raise ValueError("RerankingInputDoc.passages must contain at least one passage")

    docs = [doc.text for doc in input.passages]
    # NOTE: `tei_reranking_endpoint` is a module global that is only assigned in
    # the `if __name__ == "__main__":` block of this file.
    url = tei_reranking_endpoint + "/rerank"
    data = {"query": input.query, "texts": docs}
    headers = {"Content-Type": "application/json"}
    # NOTE(review): no timeout is set, so a stalled endpoint blocks this call
    # indefinitely — choose a suitable value for the deployment.
    response = requests.post(url, data=json.dumps(data), headers=headers)
    # Surface HTTP errors directly; otherwise an error body would be fed to
    # max() below and fail with a misleading TypeError/KeyError.
    response.raise_for_status()
    response_data = response.json()
    best_response = max(response_data, key=lambda item: item["score"])
    return RerankingOutputDoc(query=input.query, doc=input.passages[best_response["index"]])


if __name__ == "__main__":
    # Base URL of the rerank backend, read from the environment with a localhost
    # default. Bound as a module global because `reranking` reads it by name.
    tei_reranking_endpoint = os.getenv("TEI_RERANKING_ENDPOINT", "http://localhost:8080")
    # Start the microservice registered above under this name.
    opea_microservices["opea_service@reranking_tgi_gaudi"].start()

0 comments on commit 3bc899b

Please sign in to comment.