diff --git a/app/llm/langchain/iris_langchain_chat_model.py b/app/llm/langchain/iris_langchain_chat_model.py index d0e558fa..ec78e963 100644 --- a/app/llm/langchain/iris_langchain_chat_model.py +++ b/app/llm/langchain/iris_langchain_chat_model.py @@ -9,7 +9,7 @@ from langchain_core.outputs.chat_generation import ChatGeneration from domain import IrisMessage, IrisMessageRole -from llm import RequestHandlerInterface, CompletionArguments +from llm import RequestHandler, CompletionArguments def convert_iris_message_to_base_message(iris_message: IrisMessage) -> BaseMessage: @@ -35,9 +35,9 @@ def convert_base_message_to_iris_message(base_message: BaseMessage) -> IrisMessa class IrisLangchainChatModel(BaseChatModel): """Custom langchain chat model for our own request handler""" - request_handler: RequestHandlerInterface + request_handler: RequestHandler - def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None: + def __init__(self, request_handler: RequestHandler, **kwargs: Any) -> None: super().__init__(request_handler=request_handler, **kwargs) def _generate( @@ -48,7 +48,7 @@ def _generate( **kwargs: Any ) -> ChatResult: iris_messages = [convert_base_message_to_iris_message(m) for m in messages] - iris_message = self.request_handler.chat_completion( + iris_message = self.request_handler.chat( iris_messages, CompletionArguments(stop=stop) ) base_message = convert_iris_message_to_base_message(iris_message) diff --git a/app/llm/langchain/iris_langchain_completion_model.py b/app/llm/langchain/iris_langchain_completion_model.py index 8a8b2e95..1dc54c6b 100644 --- a/app/llm/langchain/iris_langchain_completion_model.py +++ b/app/llm/langchain/iris_langchain_completion_model.py @@ -5,15 +5,15 @@ from langchain_core.outputs import LLMResult from langchain_core.outputs.generation import Generation -from llm import RequestHandlerInterface, CompletionArguments +from llm import RequestHandler, CompletionArguments class IrisLangchainCompletionModel(BaseLLM): """Custom langchain chat model for our own request handler""" - request_handler: RequestHandlerInterface + request_handler: RequestHandler - def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None: + def __init__(self, request_handler: RequestHandler, **kwargs: Any) -> None: super().__init__(request_handler=request_handler, **kwargs) def _generate( @@ -26,8 +26,8 @@ def _generate( generations = [] args = CompletionArguments(stop=stop) for prompt in prompts: - completion = self.request_handler.completion( - prompt=prompt, arguments=args, **kwargs + completion = self.request_handler.complete( + prompt=prompt, arguments=args ) generations.append([Generation(text=completion)]) return LLMResult(generations=generations) diff --git a/app/llm/langchain/iris_langchain_embedding.py b/app/llm/langchain/iris_langchain_embedding.py index 01a840ff..504fe46f 100644 --- a/app/llm/langchain/iris_langchain_embedding.py +++ b/app/llm/langchain/iris_langchain_embedding.py @@ -2,19 +2,19 @@ from langchain_core.embeddings import Embeddings -from llm import RequestHandlerInterface +from llm import RequestHandler class IrisLangchainEmbeddingModel(Embeddings): """Custom langchain embedding for our own request handler""" - request_handler: RequestHandlerInterface + request_handler: RequestHandler - def __init__(self, request_handler: RequestHandlerInterface, **kwargs: Any) -> None: + def __init__(self, request_handler: RequestHandler, **kwargs: Any) -> None: super().__init__(request_handler=request_handler, **kwargs) def embed_documents(self, texts: List[str]) -> List[List[float]]: return [self.embed_query(text) for text in texts] def embed_query(self, text: str) -> List[float]: - return self.request_handler.create_embedding(text) + return self.request_handler.embed(text)