Skip to content

Commit

Permalink
Add CodeGen and CodeTrans gateway (opea-project#67)
Browse files Browse the repository at this point in the history
* Add CodeGen and CodeTrans gateway

Signed-off-by: lvliang-intel <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: lvliang-intel <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
lvliang-intel and pre-commit-ci[bot] authored May 17, 2024
1 parent 3b236d7 commit 5d56d7f
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 10 deletions.
2 changes: 1 addition & 1 deletion comps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from comps.cores.mega.orchestrator import ServiceOrchestrator
from comps.cores.mega.orchestrator_with_yaml import ServiceOrchestratorWithYaml
from comps.cores.mega.micro_service import MicroService, register_microservice, opea_microservices
from comps.cores.mega.gateway import Gateway, ChatQnAGateway
from comps.cores.mega.gateway import Gateway, ChatQnAGateway, CodeGenGateway, CodeTransGateway

# Telemetry
from comps.cores.telemetry.opea_telemetry import opea_telemetry
Expand Down
91 changes: 82 additions & 9 deletions comps/cores/mega/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,23 @@ def define_routes(self):
str(MegaServiceEndpoint.LIST_PARAMETERS), self.list_parameter, methods=["GET"]
)

def add_route(self, endpoint, handler, methods=None):
    """Register an additional API route on the gateway's FastAPI app.

    Args:
        endpoint: URL path for the new route.
        handler: Callable invoked for requests hitting the route.
        methods: HTTP methods to accept; defaults to ["POST"].
    """
    # Avoid a mutable default argument; None stands in for ["POST"].
    if methods is None:
        methods = ["POST"]
    self.service.app.router.add_api_route(endpoint, handler, methods=methods)

def stop(self):
    """Stop the underlying microservice server hosting this gateway."""
    self.service.stop()

async def handle_request(self, request: Request):
    """Handle an inbound gateway request; subclasses must override.

    Raises:
        NotImplementedError: always, in this base class.
    """
    raise NotImplementedError("Subclasses must implement this method")

def list_service(self):
    """Return a mapping of each leaf microservice's description to its endpoint path.

    Returns:
        dict: {service description: endpoint path} for every leaf node.
    """
    response = {}
    for node in self.all_leaves():
        # Accumulate entries instead of rebinding the dict each iteration,
        # which would keep only the last leaf's entry.
        response[self.services[node].description] = self.services[node].endpoint_path
    return response

def list_parameter(self):
    """List configurable parametersters of the gateway's megaservice.

    Currently a placeholder: the GET route is registered but no
    parameters are exposed yet.
    """
    # TODO: expose tunable parameters of the underlying services.
    pass


class ChatQnAGateway(Gateway):
Expand Down Expand Up @@ -111,11 +117,78 @@ async def handle_request(self, request: Request):
)
return ChatCompletionResponse(model="chatqna", choices=choices, usage=usage)

def list_service(self):
    """Return {description: endpoint_path} for every leaf microservice.

    Returns:
        dict: one entry per leaf node of the megaservice graph.
    """
    response = {}
    for node in self.all_leaves():
        # Accumulate into the dict; rebinding it each iteration would
        # drop all but the last leaf's entry.
        response[self.services[node].description] = self.services[node].endpoint_path
    return response

def list_parameter(self):
    """Placeholder; parameter listing is not implemented for this gateway."""
    pass
class CodeGenGateway(Gateway):
    """Gateway exposing a code-generation megaservice at the CODE_GEN endpoint."""

    def __init__(self, megaservice, host="0.0.0.0", port=8888):
        super().__init__(
            megaservice, host, port, str(MegaServiceEndpoint.CODE_GEN), ChatCompletionRequest, ChatCompletionResponse
        )

    async def handle_request(self, request: Request):
        """Translate an OpenAI-style chat request into a megaservice call.

        Extracts a text prompt from the request, schedules the megaservice
        pipeline on it, then returns either the LLM node's streaming
        response (passed straight through) or a ChatCompletionResponse
        built from the final node's text output.

        Args:
            request: Incoming HTTP request carrying a JSON chat payload.

        Returns:
            StreamingResponse from the LLM node, or a ChatCompletionResponse.
        """
        data = await request.json()
        chat_request = ChatCompletionRequest.parse_obj(data)
        # Initialize so an empty messages list cannot leave `prompt` unbound.
        prompt = ""
        if isinstance(chat_request.messages, str):
            prompt = chat_request.messages
        else:
            # NOTE(review): each iteration rebinds `prompt`, so only the last
            # message's text parts reach the pipeline — confirm this is intended.
            for message in chat_request.messages:
                text_list = [item["text"] for item in message["content"] if item["type"] == "text"]
                prompt = "\n".join(text_list)
        await self.megaservice.schedule(initial_inputs={"text": prompt})
        for node, response in self.megaservice.result_dict.items():
            # This assumes the last microservice in the megaservice is the LLM:
            # if it produced a StreamingResponse, stream it back unchanged.
            if (
                isinstance(response, StreamingResponse)
                and node == list(self.megaservice.services.keys())[-1]
                and self.megaservice.services[node].service_type == ServiceType.LLM
            ):
                return response
        # Non-streaming path: wrap the final node's text in a chat completion.
        last_node = self.megaservice.all_leaves()[-1]
        response = self.megaservice.result_dict[last_node]["text"]
        choices = [
            ChatCompletionResponseChoice(
                index=0,
                message=ChatMessage(role="assistant", content=response),
                finish_reason="stop",
            )
        ]
        usage = UsageInfo()
        return ChatCompletionResponse(model="codegen", choices=choices, usage=usage)


class CodeTransGateway(Gateway):
    """Gateway exposing a code-translation megaservice at the CODE_TRANS endpoint."""

    def __init__(self, megaservice, host="0.0.0.0", port=8888):
        super().__init__(
            megaservice, host, port, str(MegaServiceEndpoint.CODE_TRANS), ChatCompletionRequest, ChatCompletionResponse
        )

    async def handle_request(self, request: Request):
        """Translate an OpenAI-style chat request into a megaservice call.

        Extracts a text prompt from the request, schedules the megaservice
        pipeline on it, then returns either the LLM node's streaming
        response (passed straight through) or a ChatCompletionResponse
        built from the final node's text output.

        Args:
            request: Incoming HTTP request carrying a JSON chat payload.

        Returns:
            StreamingResponse from the LLM node, or a ChatCompletionResponse.
        """
        data = await request.json()
        chat_request = ChatCompletionRequest.parse_obj(data)
        # Initialize so an empty messages list cannot leave `prompt` unbound.
        prompt = ""
        if isinstance(chat_request.messages, str):
            prompt = chat_request.messages
        else:
            # NOTE(review): each iteration rebinds `prompt`, so only the last
            # message's text parts reach the pipeline — confirm this is intended.
            for message in chat_request.messages:
                text_list = [item["text"] for item in message["content"] if item["type"] == "text"]
                prompt = "\n".join(text_list)
        await self.megaservice.schedule(initial_inputs={"text": prompt})
        for node, response in self.megaservice.result_dict.items():
            # This assumes the last microservice in the megaservice is the LLM:
            # if it produced a StreamingResponse, stream it back unchanged.
            if (
                isinstance(response, StreamingResponse)
                and node == list(self.megaservice.services.keys())[-1]
                and self.megaservice.services[node].service_type == ServiceType.LLM
            ):
                return response
        # Non-streaming path: wrap the final node's text in a chat completion.
        last_node = self.megaservice.all_leaves()[-1]
        response = self.megaservice.result_dict[last_node]["text"]
        choices = [
            ChatCompletionResponseChoice(
                index=0,
                message=ChatMessage(role="assistant", content=response),
                finish_reason="stop",
            )
        ]
        usage = UsageInfo()
        return ChatCompletionResponse(model="codetrans", choices=choices, usage=usage)

0 comments on commit 5d56d7f

Please sign in to comment.