diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 87cc5d54..3cd3a03c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,17 +18,3 @@ repos: language: system types: [ python ] stages: [ commit, push ] - - id: mypy - name: Mypy Type Check - entry: mypy . - language: system - types: [ python ] - stages: [ commit, push ] - pass_filenames: false - args: [ - --strict, - --ignore-missing-imports, - --allow-untyped-calls, - --allow-subclassing-any, - --exclude='./docs/' - ] diff --git a/src/neo4j_graphrag/generation/graphrag.py b/src/neo4j_graphrag/generation/graphrag.py index 3d65fcb5..9430679e 100644 --- a/src/neo4j_graphrag/generation/graphrag.py +++ b/src/neo4j_graphrag/generation/graphrag.py @@ -83,6 +83,7 @@ def __init__( def search( self, query_text: str = "", + chat_history: Optional[list[str]] = None, examples: str = "", retriever_config: Optional[dict[str, Any]] = None, return_context: bool | None = None, @@ -100,6 +101,7 @@ def search( Args: query_text (str): The user question + chat_history: (Optional[list]): A list of previous messages in the conversation examples (str): Examples added to the LLM prompt. retriever_config (Optional[dict]): Parameters passed to the retriever search method; e.g.: top_k @@ -134,7 +136,10 @@ def search( ) logger.debug(f"RAG: retriever_result={retriever_result}") logger.debug(f"RAG: prompt={prompt}") - answer = self.llm.invoke(prompt) + if chat_history is not None: + answer = self.llm.chat(prompt, chat_history) + else: + answer = self.llm.invoke(prompt) result: dict[str, Any] = {"answer": answer.content} if return_context: result["retriever_result"] = retriever_result diff --git a/src/neo4j_graphrag/llm/anthropic_llm.py b/src/neo4j_graphrag/llm/anthropic_llm.py index fd076c0f..e1c9c2db 100644 --- a/src/neo4j_graphrag/llm/anthropic_llm.py +++ b/src/neo4j_graphrag/llm/anthropic_llm.py @@ -110,3 +110,6 @@ async def ainvoke(self, input: str) -> LLMResponse: return LLMResponse(content=response.content) except self.anthropic.APIError as e: raise LLMGenerationError(e) + + def chat(self, input: str, chat_history: list[str]) -> LLMResponse: + pass diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py index 3d98423b..ffa4bdd6 100644 --- a/src/neo4j_graphrag/llm/base.py +++ b/src/neo4j_graphrag/llm/base.py @@ -33,10 +33,12 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, + system_instruction: Optional[str] = None, **kwargs: Any, ): self.model_name = model_name self.model_params = model_params or {} + self.system_instruction = system_instruction @abstractmethod def invoke(self, input: str) -> LLMResponse: @@ -52,6 +54,21 @@ def invoke(self, input: str) -> LLMResponse: LLMGenerationError: If anything goes wrong. """ + @abstractmethod + def chat(self, input: str, chat_history: list[str]) -> LLMResponse: + """Sends a text input and a converstion history to the LLM and retrieves a response. + + Args: + input (str): Text sent to the LLM + chat_history (list[str]]): A list of previous messages in the conversation + + Returns: + LLMResponse: The response from the LLM. + + Raises: + LLMGenerationError: If anything goes wrong. + """ + @abstractmethod async def ainvoke(self, input: str) -> LLMResponse: """Asynchronously sends a text input to the LLM and retrieves a response. diff --git a/src/neo4j_graphrag/llm/cohere_llm.py b/src/neo4j_graphrag/llm/cohere_llm.py index 6a6aa036..9c11bf54 100644 --- a/src/neo4j_graphrag/llm/cohere_llm.py +++ b/src/neo4j_graphrag/llm/cohere_llm.py @@ -102,3 +102,6 @@ async def ainvoke(self, input: str) -> LLMResponse: return LLMResponse( content=res.text, ) + + def chat(self, input: str, chat_history: list[str]) -> LLMResponse: + pass diff --git a/src/neo4j_graphrag/llm/mistralai_llm.py b/src/neo4j_graphrag/llm/mistralai_llm.py index bbdcc4db..34881c22 100644 --- a/src/neo4j_graphrag/llm/mistralai_llm.py +++ b/src/neo4j_graphrag/llm/mistralai_llm.py @@ -118,3 +118,6 @@ async def ainvoke(self, input: str) -> LLMResponse: return LLMResponse(content=content) except SDKError as e: raise LLMGenerationError(e) + + def chat(self, input: str, chat_history: list[str]) -> LLMResponse: + pass diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py index 04f3f2bf..03b19250 100644 --- a/src/neo4j_graphrag/llm/openai_llm.py +++ b/src/neo4j_graphrag/llm/openai_llm.py @@ -36,6 +36,7 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, + system_instruction: Optional[str] = None, ): """ Base class for OpenAI LLM. @@ -54,7 +55,7 @@ def __init__( "Please install it with `pip install openai`." ) self.openai = openai - super().__init__(model_name, model_params) + super().__init__(model_name, model_params, system_instruction) def get_messages( self, @@ -64,6 +65,32 @@ def get_messages( {"role": "system", "content": input}, ] + def get_conversation_history( + self, + input: str, + chat_history: list[str], + ) -> Iterable[ChatCompletionMessageParam]: + messages = [{"role": "system", "content": self.system_instruction}] + for i, message in enumerate(chat_history): + if i % 2 == 0: + messages.append({"role": "user", "content": message}) + else: + messages.append({"role": "assistant", "content": message}) + messages.append({"role": "user", "content": input}) + return messages + + def chat(self, input: str, chat_history: list[str]) -> LLMResponse: + try: + response = self.client.chat.completions.create( + messages=self.get_conversation_history(input, chat_history), + model=self.model_name, + **self.model_params, + ) + content = response.choices[0].message.content or "" + return LLMResponse(content=content) + except self.openai.OpenAIError as e: + raise LLMGenerationError(e) + def invoke(self, input: str) -> LLMResponse: """Sends a text input to the OpenAI chat completion model and returns the response's content. @@ -118,6 +145,7 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, + system_instruction: Optional[str] = None, **kwargs: Any, ): """OpenAI LLM @@ -129,7 +157,7 @@ def __init__( model_params (str): Parameters like temperature that will be passed to the model when text is sent to it kwargs: All other parameters will be passed to the openai.OpenAI init. """ - super().__init__(model_name, model_params) + super().__init__(model_name, model_params, system_instruction) self.client = self.openai.OpenAI(**kwargs) self.async_client = self.openai.AsyncOpenAI(**kwargs) @@ -139,6 +167,7 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, + system_instruction: Optional[str] = None, **kwargs: Any, ): """Azure OpenAI LLM. Use this class when using an OpenAI model @@ -149,6 +178,6 @@ def __init__( model_params (str): Parameters like temperature that will be passed to the model when text is sent to it kwargs: All other parameters will be passed to the openai.OpenAI init. """ - super().__init__(model_name, model_params) + super().__init__(model_name, model_params, system_instruction) self.client = self.openai.AzureOpenAI(**kwargs) self.async_client = self.openai.AsyncAzureOpenAI(**kwargs) diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py index b0820e62..c71b5e09 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -20,7 +20,12 @@ from neo4j_graphrag.llm.types import LLMResponse try: - from vertexai.generative_models import GenerativeModel, ResponseValidationError + from vertexai.generative_models import ( + GenerativeModel, + ResponseValidationError, + Part, + Content, + ) except ImportError: GenerativeModel = None ResponseValidationError = None @@ -55,6 +60,7 @@ def __init__( self, model_name: str = "gemini-1.5-flash-001", model_params: Optional[dict[str, Any]] = None, + system_instruction: Optional[str] = None, **kwargs: Any, ): if GenerativeModel is None or ResponseValidationError is None: @@ -62,8 +68,10 @@ def __init__( "Could not import Vertex AI Python client. " "Please install it with `pip install google-cloud-aiplatform`." ) - super().__init__(model_name, model_params) - self.model = GenerativeModel(model_name=model_name, **kwargs) + super().__init__(model_name, model_params, system_instruction) + self.model = GenerativeModel( + model_name=model_name, system_instruction=[system_instruction], **kwargs + ) def invoke(self, input: str) -> LLMResponse: """Sends text to the LLM and returns a response. @@ -80,7 +88,25 @@ def invoke(self, input: str) -> LLMResponse: except ResponseValidationError as e: raise LLMGenerationError(e) - async def ainvoke(self, input: str) -> LLMResponse: + def chat(self, input: str, chat_history: list[str] = []) -> LLMResponse: + """Sends text to the LLM and returns a response. + + Args: + input (str): The text to send to the LLM. + + Returns: + LLMResponse: The response from the LLM. + """ + try: + messages = self.get_conversation_history(input, chat_history) + response = self.model.generate_content(messages, **self.model_params) + return LLMResponse(content=response.text) + except ResponseValidationError as e: + raise LLMGenerationError(e) + + async def ainvoke( + self, input: str, chat_history: Optional[list[str]] = [] + ) -> LLMResponse: """Asynchronously sends text to the LLM and returns a response. Args: @@ -96,3 +122,17 @@ async def ainvoke(self, input: str) -> LLMResponse: return LLMResponse(content=response.text) except ResponseValidationError as e: raise LLMGenerationError(e) + + def get_conversation_history( + self, + input: str, + chat_history: list[str], + ) -> list[Content]: + messages = [] + for i, message in enumerate(chat_history): + if i % 2 == 0: + messages.append(Content(role="user", parts=[Part.from_text(message)])) + else: + messages.append(Content(role="model", parts=[Part.from_text(message)])) + messages.append(Content(role="user", parts=[Part.from_text(input)])) + return messages diff --git a/tests/unit/llm/test_openai_llm.py b/tests/unit/llm/test_openai_llm.py index 546d4e39..412789ce 100644 --- a/tests/unit/llm/test_openai_llm.py +++ b/tests/unit/llm/test_openai_llm.py @@ -33,7 +33,7 @@ def test_openai_llm_missing_dependency(mock_import: Mock) -> None: @patch("builtins.__import__") -def test_openai_llm_happy_path(mock_import: Mock) -> None: +def test_openai_llm_invoke_happy_path(mock_import: Mock) -> None: mock_openai = get_mock_openai() mock_import.return_value = mock_openai mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock( @@ -46,6 +46,20 @@ def test_openai_llm_happy_path(mock_import: Mock) -> None: assert res.content == "openai chat response" +@patch("builtins.__import__") +def test_openai_llm_chat_happy_path(mock_import: Mock) -> None: + mock_openai = get_mock_openai() + mock_import.return_value = mock_openai + mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock( + choices=[MagicMock(message=MagicMock(content="openai chat response"))], + ) + llm = OpenAILLM(api_key="my key", model_name="gpt") + + res = llm.chat("my question", ["user message", "assistant message"]) + assert isinstance(res, LLMResponse) + assert res.content == "openai chat response" + + @patch("builtins.__import__", side_effect=ImportError) def test_azure_openai_llm_missing_dependency(mock_import: Mock) -> None: with pytest.raises(ImportError): @@ -71,3 +85,28 @@ def test_azure_openai_llm_happy_path(mock_import: Mock) -> None: res = llm.invoke("my text") assert isinstance(res, LLMResponse) assert res.content == "openai chat response" + + +def test_openai_llm_get_conversation_history() -> None: + system_instruction = "You are a helpful assistant." + question = "When does it set?" + chat_history = [ + "When does the sun come up in the summer?", + "Usually around 6am.", + "What about next season?", + "Around 8am.", + ] + expected_response = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + {"role": "user", "content": "What about next season?"}, + {"role": "assistant", "content": "Around 8am."}, + {"role": "user", "content": "When does it set?"}, + ] + + llm = OpenAILLM( + api_key="my key", model_name="gpt", system_instruction=system_instruction + ) + response = llm.get_conversation_history(question, chat_history) + assert response == expected_response diff --git a/tests/unit/llm/test_vertexai_llm.py b/tests/unit/llm/test_vertexai_llm.py index adffeb1d..209548f0 100644 --- a/tests/unit/llm/test_vertexai_llm.py +++ b/tests/unit/llm/test_vertexai_llm.py @@ -16,7 +16,7 @@ from unittest.mock import AsyncMock, MagicMock, Mock, patch import pytest -from neo4j_graphrag.llm.vertexai_llm import VertexAILLM +from neo4j_graphrag.llm.vertexai_llm import VertexAILLM, Part, Content @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel", None) @@ -52,3 +52,36 @@ async def test_vertexai_ainvoke_happy_path(GenerativeModelMock: MagicMock) -> No response = await llm.ainvoke(input_text) assert response.content == "Return text" llm.model.generate_content_async.assert_called_once_with(input_text, **model_params) + + +@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel") +def test_vertexai_get_conversation_history(GenerativeModelMock: MagicMock) -> None: + system_instruction = "You are a helpful assistant." + question = "When does it set?" + chat_history = [ + "When does the sun come up in the summer?", + "Usually around 6am.", + "What about next season?", + "Around 8am.", + ] + expected_response = [ + Content( + role="user", + parts=[Part.from_text("When does the sun come up in the summer?")], + ), + Content(role="model", parts=[Part.from_text("Usually around 6am.")]), + Content(role="user", parts=[Part.from_text("What about next season?")]), + Content(role="model", parts=[Part.from_text("Around 8am.")]), + Content(role="user", parts=[Part.from_text("When does it set?")]), + ] + + llm = VertexAILLM( + model_name="gemini-1.5-flash-001", system_instruction=system_instruction + ) + response = llm.get_conversation_history(question, chat_history) + + assert llm.system_instruction == system_instruction + assert len(response) == len(expected_response) + for actual, expected in zip(response, expected_response): + assert actual.role == expected.role + assert actual.parts[0].text == expected.parts[0].text