From 81d1f312d3db0d59ac3b04c505363b0b529fdf26 Mon Sep 17 00:00:00 2001 From: Estelle Scifo Date: Mon, 13 Jan 2025 10:41:24 +0100 Subject: [PATCH] More on LLM message history and system instruction (#240) * Add examples for LLM message history and system instructions * Fix for when system_message is None * Ruff * Better (?) system prompt for RAG * Update system_instruction * Mypy * Mypy * Fix ollama test * Fix anthropic test * Fix cohere test * Fix vertexai test * Fix mistralai test * Fix graphrag test * Ruff * Mypy * Variable is not used * Ruff... * Mypy + e2e tests * Ruffffffff * CHANGELOG * Fix examples * Remove useless commented api_key from examples --- CHANGELOG.md | 4 +- examples/README.md | 4 + .../customize/embeddings/cohere_embeddings.py | 1 - .../embeddings/mistalai_embeddings.py | 1 - .../customize/embeddings/openai_embeddings.py | 1 - examples/customize/llms/anthropic_llm.py | 2 - examples/customize/llms/cohere_llm.py | 1 - .../llms/llm_with_message_history.py | 42 +++++++++ .../llms/llm_with_system_instructions.py | 22 +++++ examples/customize/llms/mistalai_llm.py | 2 - examples/customize/llms/openai_llm.py | 1 - examples/question_answering/graphrag.py | 7 +- .../graphrag_with_message_history.py | 85 +++++++++++++++++++ src/neo4j_graphrag/generation/graphrag.py | 10 ++- src/neo4j_graphrag/generation/prompts.py | 10 ++- src/neo4j_graphrag/llm/anthropic_llm.py | 21 ++--- src/neo4j_graphrag/llm/base.py | 10 +-- src/neo4j_graphrag/llm/cohere_llm.py | 12 +-- src/neo4j_graphrag/llm/mistralai_llm.py | 13 +-- src/neo4j_graphrag/llm/ollama_llm.py | 12 +-- src/neo4j_graphrag/llm/openai_llm.py | 20 ++--- src/neo4j_graphrag/llm/vertexai_llm.py | 14 +-- tests/e2e/test_graphrag_e2e.py | 15 ++-- tests/unit/llm/test_anthropic_llm.py | 67 ++++++++------- tests/unit/llm/test_cohere_llm.py | 41 ++------- tests/unit/llm/test_mistralai_llm.py | 39 ++------- tests/unit/llm/test_ollama_llm.py | 67 +++++++-------- tests/unit/llm/test_openai_llm.py | 37 ++------ tests/unit/llm/test_vertexai_llm.py | 60 +++++++++---- tests/unit/test_graphrag.py | 59 +++++++++---- 30 files changed, 381 insertions(+), 299 deletions(-) create mode 100644 examples/customize/llms/llm_with_message_history.py create mode 100644 examples/customize/llms/llm_with_system_instructions.py create mode 100644 examples/question_answering/graphrag_with_message_history.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 58ae6a61..3a2b4909 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,8 +4,8 @@ ### Added - Support for conversations with message history, including a new `message_history` parameter for LLM interactions. -- Ability to include system instructions and override them for specific invocations. -- Summarization of chat history to enhance query embedding and context handling. +- Ability to include system instructions in LLM invoke method. +- Summarization of chat history to enhance query embedding and context handling in GraphRAG. ### Changed - Updated LLM implementations to handle message history consistently across providers. diff --git a/examples/README.md b/examples/README.md index 6d56bbd0..d3e4ef54 100644 --- a/examples/README.md +++ b/examples/README.md @@ -51,6 +51,7 @@ are listed in [the last section of this file](#customize). ## Answer: GraphRAG - [End to end GraphRAG](./answer/graphrag.py) +- [GraphRAG with message history](./question_answering/graphrag_with_message_history.py) ## Customize @@ -73,6 +74,9 @@ are listed in [the last section of this file](#customize). 
- [Ollama](./customize/llms/ollama_llm.py) - [Custom LLM](./customize/llms/custom_llm.py) +- [Message history](./customize/llms/llm_with_message_history.py) +- [System Instruction](./customize/llms/llm_with_system_instructions.py) + ### Prompts diff --git a/examples/customize/embeddings/cohere_embeddings.py b/examples/customize/embeddings/cohere_embeddings.py index 92a4ca3d..7de265b5 100644 --- a/examples/customize/embeddings/cohere_embeddings.py +++ b/examples/customize/embeddings/cohere_embeddings.py @@ -2,7 +2,6 @@ # set api key here on in the CO_API_KEY env var api_key = None -# api_key = "sk-..." embeder = CohereEmbeddings( model="embed-english-v3.0", diff --git a/examples/customize/embeddings/mistalai_embeddings.py b/examples/customize/embeddings/mistalai_embeddings.py index d26c6cce..594474a0 100644 --- a/examples/customize/embeddings/mistalai_embeddings.py +++ b/examples/customize/embeddings/mistalai_embeddings.py @@ -6,7 +6,6 @@ # set api key here on in the MISTRAL_API_KEY env var api_key = None -# api_key = "sk-..." embeder = MistralAIEmbeddings(model="mistral-embed", api_key=api_key) res = embeder.embed_query("my question") diff --git a/examples/customize/embeddings/openai_embeddings.py b/examples/customize/embeddings/openai_embeddings.py index c1d9b57a..6ffe6bac 100644 --- a/examples/customize/embeddings/openai_embeddings.py +++ b/examples/customize/embeddings/openai_embeddings.py @@ -6,7 +6,6 @@ # set api key here on in the OPENAI_API_KEY env var api_key = None -# api_key = "sk-..." embeder = OpenAIEmbeddings(model="text-embedding-ada-002", api_key=api_key) res = embeder.embed_query("my question") diff --git a/examples/customize/llms/anthropic_llm.py b/examples/customize/llms/anthropic_llm.py index d84266a3..85c4ad03 100644 --- a/examples/customize/llms/anthropic_llm.py +++ b/examples/customize/llms/anthropic_llm.py @@ -2,8 +2,6 @@ # set api key here on in the ANTHROPIC_API_KEY env var api_key = None -# api_key = "sk-..." - llm = AnthropicLLM( model_name="claude-3-opus-20240229", diff --git a/examples/customize/llms/cohere_llm.py b/examples/customize/llms/cohere_llm.py index 7dfd8d6e..d631d3e4 100644 --- a/examples/customize/llms/cohere_llm.py +++ b/examples/customize/llms/cohere_llm.py @@ -2,7 +2,6 @@ # set api key here on in the CO_API_KEY env var api_key = None -# api_key = "sk-..." llm = CohereLLM( model_name="command-r", diff --git a/examples/customize/llms/llm_with_message_history.py b/examples/customize/llms/llm_with_message_history.py new file mode 100644 index 00000000..099fea6f --- /dev/null +++ b/examples/customize/llms/llm_with_message_history.py @@ -0,0 +1,42 @@ +"""This example illustrates the message_history feature +of the LLMInterface by mocking a conversation between a user +and an LLM about Tom Hanks. + +OpenAILLM can be replaced by any supported LLM from this package. +""" + +from neo4j_graphrag.llm import LLMResponse, OpenAILLM + +# set api key here on in the OPENAI_API_KEY env var +api_key = None + +llm = OpenAILLM(model_name="gpt-4o", api_key=api_key) + +questions = [ + "What are some movies Tom Hanks starred in?", + "Is he also a director?", + "Wow, that's impressive. 
And what about his personal life, does he have children?",
+]
+
+history: list[dict[str, str]] = []
+for question in questions:
+    res: LLMResponse = llm.invoke(
+        question,
+        message_history=history,  # type: ignore
+    )
+    history.append(
+        {
+            "role": "user",
+            "content": question,
+        }
+    )
+    history.append(
+        {
+            "role": "assistant",
+            "content": res.content,
+        }
+    )
+
+    print("#" * 50, question)
+    print(res.content)
+    print("#" * 50)
diff --git a/examples/customize/llms/llm_with_system_instructions.py b/examples/customize/llms/llm_with_system_instructions.py
new file mode 100644
index 00000000..4ecf7731
--- /dev/null
+++ b/examples/customize/llms/llm_with_system_instructions.py
@@ -0,0 +1,22 @@
+"""This example illustrates how to set system instructions for an LLM.
+
+OpenAILLM can be replaced by any supported LLM from this package.
+"""
+
+from neo4j_graphrag.llm import LLMResponse, OpenAILLM
+
+# set api key here or in the OPENAI_API_KEY env var
+api_key = None
+
+llm = OpenAILLM(
+    model_name="gpt-4o",
+    api_key=api_key,
+)
+
+question = "How fast is Santa Claus on Christmas Eve?"
+
+res: LLMResponse = llm.invoke(
+    question,
+    system_instruction="Answer with a serious tone",
+)
+print(res.content)
diff --git a/examples/customize/llms/mistalai_llm.py b/examples/customize/llms/mistalai_llm.py
index 7aaa8e8d..b829baad 100644
--- a/examples/customize/llms/mistalai_llm.py
+++ b/examples/customize/llms/mistalai_llm.py
@@ -2,8 +2,6 @@
 
 # set api key here on in the MISTRAL_API_KEY env var
 api_key = None
-# api_key = "sk-..."
-
 
 llm = MistralAILLM(
     model_name="mistral-small-latest",
diff --git a/examples/customize/llms/openai_llm.py b/examples/customize/llms/openai_llm.py
index 89ea44a4..d4b38244 100644
--- a/examples/customize/llms/openai_llm.py
+++ b/examples/customize/llms/openai_llm.py
@@ -2,7 +2,6 @@
 
 # set api key here on in the OPENAI_API_KEY env var
 api_key = None
-# api_key = "sk-..."
 
 llm = OpenAILLM(model_name="gpt-4o", api_key=api_key)
 res: LLMResponse = llm.invoke("say something")
diff --git a/examples/question_answering/graphrag.py b/examples/question_answering/graphrag.py
index 526f83c7..25186e94 100644
--- a/examples/question_answering/graphrag.py
+++ b/examples/question_answering/graphrag.py
@@ -16,9 +16,10 @@
 from neo4j_graphrag.retrievers import VectorCypherRetriever
 from neo4j_graphrag.types import RetrieverResultItem
 
-URI = "neo4j://localhost:7687"
-AUTH = ("neo4j", "password")
-DATABASE = "neo4j"
+# Define database credentials
+URI = "neo4j+s://demo.neo4jlabs.com"
+AUTH = ("recommendations", "recommendations")
+DATABASE = "recommendations"
 INDEX = "moviePlotsEmbedding"
 
 
diff --git a/examples/question_answering/graphrag_with_message_history.py b/examples/question_answering/graphrag_with_message_history.py
new file mode 100644
index 00000000..9a3da9da
--- /dev/null
+++ b/examples/question_answering/graphrag_with_message_history.py
@@ -0,0 +1,85 @@
+"""End to end example of building a RAG pipeline backed by a Neo4j database,
+simulating a chat with message history feature.
+
+Requires OPENAI_API_KEY to be in the env var.
+""" + +import neo4j +from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings +from neo4j_graphrag.generation import GraphRAG +from neo4j_graphrag.llm import OpenAILLM +from neo4j_graphrag.retrievers import VectorCypherRetriever + +# Define database credentials +URI = "neo4j+s://demo.neo4jlabs.com" +AUTH = ("recommendations", "recommendations") +DATABASE = "recommendations" +INDEX = "moviePlotsEmbedding" + + +driver = neo4j.GraphDatabase.driver( + URI, + auth=AUTH, +) + +embedder = OpenAIEmbeddings() + +retriever = VectorCypherRetriever( + driver, + index_name=INDEX, + retrieval_query=""" + WITH node as movie, score + CALL(movie) { + MATCH (movie)<-[:ACTED_IN]-(p:Person) + RETURN collect(p.name) as actors + } + CALL(movie) { + MATCH (movie)<-[:DIRECTED]-(p:Person) + RETURN collect(p.name) as directors + } + RETURN movie.title as title, movie.plot as plot, movie.year as year, actors, directors + """, + embedder=embedder, + neo4j_database=DATABASE, +) + +llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0}) + +rag = GraphRAG( + retriever=retriever, + llm=llm, +) + +questions = [ + "Who starred in the Apollo 13 movies?", + "Who was its director?", + "In which year was this movie released?", +] + +history: list[dict[str, str]] = [] +for question in questions: + result = rag.search( + question, + return_context=False, + message_history=history, # type: ignore + ) + + answer = result.answer + print("#" * 50, question) + print(answer) + print("#" * 50) + + history.append( + { + "role": "user", + "content": question, + } + ) + history.append( + { + "role": "assistant", + "content": answer, + } + ) + +driver.close() diff --git a/src/neo4j_graphrag/generation/graphrag.py b/src/neo4j_graphrag/generation/graphrag.py index 48a864e4..6b764716 100644 --- a/src/neo4j_graphrag/generation/graphrag.py +++ b/src/neo4j_graphrag/generation/graphrag.py @@ -137,7 +137,11 @@ def search( ) logger.debug(f"RAG: retriever_result={retriever_result}") logger.debug(f"RAG: prompt={prompt}") - answer = self.llm.invoke(prompt, message_history) + answer = self.llm.invoke( + prompt, + message_history, + system_instruction=self.prompt_template.system_instructions, + ) result: dict[str, Any] = {"answer": answer.content} if return_context: result["retriever_result"] = retriever_result @@ -172,9 +176,9 @@ def chat_summary_prompt(self, message_history: list[LLMMessage]) -> str: def conversation_prompt(self, summary: str, current_query: str) -> str: return f""" -Message Summary: +Message Summary: {summary} -Current Query: +Current Query: {current_query} """ diff --git a/src/neo4j_graphrag/generation/prompts.py b/src/neo4j_graphrag/generation/prompts.py index 365d74c0..ade30272 100644 --- a/src/neo4j_graphrag/generation/prompts.py +++ b/src/neo4j_graphrag/generation/prompts.py @@ -32,6 +32,7 @@ class PromptTemplate: missing, a `PromptMissingInputError` is raised. 
""" + DEFAULT_SYSTEM_INSTRUCTIONS: str = "" DEFAULT_TEMPLATE: str = "" EXPECTED_INPUTS: list[str] = list() @@ -39,9 +40,13 @@ def __init__( self, template: Optional[str] = None, expected_inputs: Optional[list[str]] = None, + system_instructions: Optional[str] = None, ) -> None: self.template = template or self.DEFAULT_TEMPLATE self.expected_inputs = expected_inputs or self.EXPECTED_INPUTS + self.system_instructions = ( + system_instructions or self.DEFAULT_SYSTEM_INSTRUCTIONS + ) for e in self.expected_inputs: if f"{{{e}}}" not in self.template: @@ -88,9 +93,8 @@ def format(self, *args: Any, **kwargs: Any) -> str: class RagTemplate(PromptTemplate): - DEFAULT_TEMPLATE = """Answer the user question using the following context - -Context: + DEFAULT_SYSTEM_INSTRUCTIONS = "Answer the user question using the provided context." + DEFAULT_TEMPLATE = """Context: {context} Examples: diff --git a/src/neo4j_graphrag/llm/anthropic_llm.py b/src/neo4j_graphrag/llm/anthropic_llm.py index 614c5b9c..0ff74a4a 100644 --- a/src/neo4j_graphrag/llm/anthropic_llm.py +++ b/src/neo4j_graphrag/llm/anthropic_llm.py @@ -61,7 +61,6 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, - system_instruction: Optional[str] = None, **kwargs: Any, ): try: @@ -71,7 +70,7 @@ def __init__( """Could not import Anthropic Python client. Please install it with `pip install "neo4j-graphrag[anthropic]"`.""" ) - super().__init__(model_name, model_params, system_instruction) + super().__init__(model_name, model_params) self.anthropic = anthropic self.client = anthropic.Anthropic(**kwargs) self.async_client = anthropic.AsyncAnthropic(**kwargs) @@ -107,18 +106,13 @@ def invoke( """ try: messages = self.get_messages(input, message_history) - system_message = ( - system_instruction - if system_instruction is not None - else self.system_instruction - ) response = self.client.messages.create( model=self.model_name, - system=system_message, # type: ignore + system=system_instruction or self.anthropic.NOT_GIVEN, messages=messages, **self.model_params, ) - return LLMResponse(content=response.content) # type: ignore + return LLMResponse(content=response.content) except self.anthropic.APIError as e: raise LLMGenerationError(e) @@ -140,17 +134,12 @@ async def ainvoke( """ try: messages = self.get_messages(input, message_history) - system_message = ( - system_instruction - if system_instruction is not None - else self.system_instruction - ) response = await self.async_client.messages.create( model=self.model_name, - system=system_message, # type: ignore + system=system_instruction or self.anthropic.NOT_GIVEN, messages=messages, **self.model_params, ) - return LLMResponse(content=response.content) # type: ignore + return LLMResponse(content=response.content) except self.anthropic.APIError as e: raise LLMGenerationError(e) diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py index eab3eb4f..49f1afc3 100644 --- a/src/neo4j_graphrag/llm/base.py +++ b/src/neo4j_graphrag/llm/base.py @@ -17,7 +17,10 @@ from abc import ABC, abstractmethod from typing import Any, Optional -from .types import LLMMessage, LLMResponse +from .types import ( + LLMMessage, + LLMResponse, +) class LLMInterface(ABC): @@ -25,8 +28,7 @@ class LLMInterface(ABC): Args: model_name (str): The name of the language model. - model_params (Optional[dict], optional): Additional parameters passed to the model when text is sent to it. Defaults to None. 
- system_instruction: Optional[str], optional): Additional instructions for setting the behavior and context for the model in a conversation. Defaults to None. + model_params (Optional[dict]): Additional parameters passed to the model when text is sent to it. Defaults to None. **kwargs (Any): Arguments passed to the model when for the class is initialised. Defaults to None. """ @@ -34,12 +36,10 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, - system_instruction: Optional[str] = None, **kwargs: Any, ): self.model_name = model_name self.model_params = model_params or {} - self.system_instruction = system_instruction @abstractmethod def invoke( diff --git a/src/neo4j_graphrag/llm/cohere_llm.py b/src/neo4j_graphrag/llm/cohere_llm.py index 024d578a..9621cae2 100644 --- a/src/neo4j_graphrag/llm/cohere_llm.py +++ b/src/neo4j_graphrag/llm/cohere_llm.py @@ -59,7 +59,6 @@ def __init__( self, model_name: str = "", model_params: Optional[dict[str, Any]] = None, - system_instruction: Optional[str] = None, **kwargs: Any, ) -> None: try: @@ -69,7 +68,7 @@ def __init__( """Could not import cohere python client. Please install it with `pip install "neo4j-graphrag[cohere]"`.""" ) - super().__init__(model_name, model_params, system_instruction) + super().__init__(model_name, model_params) self.cohere = cohere self.cohere_api_error = cohere.core.api_error.ApiError @@ -83,13 +82,8 @@ def get_messages( system_instruction: Optional[str] = None, ) -> ChatMessages: messages = [] - system_message = ( - system_instruction - if system_instruction is not None - else self.system_instruction - ) - if system_message: - messages.append(SystemMessage(content=system_message).model_dump()) + if system_instruction: + messages.append(SystemMessage(content=system_instruction).model_dump()) if message_history: try: MessageList(messages=cast(list[BaseMessage], message_history)) diff --git a/src/neo4j_graphrag/llm/mistralai_llm.py b/src/neo4j_graphrag/llm/mistralai_llm.py index 761b9a62..afd1447e 100644 --- a/src/neo4j_graphrag/llm/mistralai_llm.py +++ b/src/neo4j_graphrag/llm/mistralai_llm.py @@ -43,7 +43,6 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, - system_instruction: Optional[str] = None, **kwargs: Any, ): """ @@ -52,7 +51,6 @@ def __init__( model_name (str): model_params (str): Parameters like temperature and such that will be passed to the chat completions endpoint - system_instruction: Optional[str], optional): Additional instructions for setting the behavior and context for the model in a conversation. Defaults to None. kwargs: All other parameters will be passed to the Mistral client. """ @@ -61,7 +59,7 @@ def __init__( """Could not import Mistral Python client. 
Please install it with `pip install "neo4j-graphrag[mistralai]"`.""" ) - super().__init__(model_name, model_params, system_instruction) + super().__init__(model_name, model_params) api_key = kwargs.pop("api_key", None) if api_key is None: api_key = os.getenv("MISTRAL_API_KEY", "") @@ -74,13 +72,8 @@ def get_messages( system_instruction: Optional[str] = None, ) -> list[Messages]: messages = [] - system_message = ( - system_instruction - if system_instruction is not None - else self.system_instruction - ) - if system_message: - messages.append(SystemMessage(content=system_message).model_dump()) + if system_instruction: + messages.append(SystemMessage(content=system_instruction).model_dump()) if message_history: try: MessageList(messages=cast(list[BaseMessage], message_history)) diff --git a/src/neo4j_graphrag/llm/ollama_llm.py b/src/neo4j_graphrag/llm/ollama_llm.py index 08dfa655..8f1e8193 100644 --- a/src/neo4j_graphrag/llm/ollama_llm.py +++ b/src/neo4j_graphrag/llm/ollama_llm.py @@ -39,7 +39,6 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, - system_instruction: Optional[str] = None, **kwargs: Any, ): try: @@ -49,7 +48,7 @@ def __init__( "Could not import ollama Python client. " "Please install it with `pip install ollama`." ) - super().__init__(model_name, model_params, system_instruction, **kwargs) + super().__init__(model_name, model_params, **kwargs) self.ollama = ollama self.client = ollama.Client( **kwargs, @@ -65,13 +64,8 @@ def get_messages( system_instruction: Optional[str] = None, ) -> Sequence[Message]: messages = [] - system_message = ( - system_instruction - if system_instruction is not None - else self.system_instruction - ) - if system_message: - messages.append(SystemMessage(content=system_message).model_dump()) + if system_instruction: + messages.append(SystemMessage(content=system_instruction).model_dump()) if message_history: try: MessageList(messages=cast(list[BaseMessage], message_history)) diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py index 353dd19d..21a49a15 100644 --- a/src/neo4j_graphrag/llm/openai_llm.py +++ b/src/neo4j_graphrag/llm/openai_llm.py @@ -45,7 +45,6 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, - system_instruction: Optional[str] = None, ): """ Base class for OpenAI LLM. @@ -55,7 +54,6 @@ def __init__( Args: model_name (str): model_params (str): Parameters like temperature that will be passed to the model when text is sent to it. Defaults to None. - system_instruction: Optional[str], optional): Additional instructions for setting the behavior and context for the model in a conversation. Defaults to None. 
""" try: import openai @@ -65,7 +63,7 @@ def __init__( Please install it with `pip install "neo4j-graphrag[openai]"`.""" ) self.openai = openai - super().__init__(model_name, model_params, system_instruction) + super().__init__(model_name, model_params) def get_messages( self, @@ -74,13 +72,8 @@ def get_messages( system_instruction: Optional[str] = None, ) -> Iterable[ChatCompletionMessageParam]: messages = [] - system_message = ( - system_instruction - if system_instruction is not None - else self.system_instruction - ) - if system_message: - messages.append(SystemMessage(content=system_message).model_dump()) + if system_instruction: + messages.append(SystemMessage(content=system_instruction).model_dump()) if message_history: try: MessageList(messages=cast(list[BaseMessage], message_history)) @@ -158,7 +151,6 @@ def __init__( self, model_name: str, model_params: Optional[dict[str, Any]] = None, - system_instruction: Optional[str] = None, **kwargs: Any, ): """OpenAI LLM @@ -168,10 +160,9 @@ def __init__( Args: model_name (str): model_params (str): Parameters like temperature that will be passed to the model when text is sent to it. Defaults to None. - system_instruction: Optional[str], optional): Additional instructions for setting the behavior and context for the model in a conversation. Defaults to None. kwargs: All other parameters will be passed to the openai.OpenAI init. """ - super().__init__(model_name, model_params, system_instruction) + super().__init__(model_name, model_params) self.client = self.openai.OpenAI(**kwargs) self.async_client = self.openai.AsyncOpenAI(**kwargs) @@ -190,9 +181,8 @@ def __init__( Args: model_name (str): model_params (str): Parameters like temperature that will be passed to the model when text is sent to it. Defaults to None. - system_instruction: Optional[str], optional): Additional instructions for setting the behavior and context for the model in a conversation. Defaults to None. kwargs: All other parameters will be passed to the openai.OpenAI init. """ - super().__init__(model_name, model_params, system_instruction) + super().__init__(model_name, model_params) self.client = self.openai.AzureOpenAI(**kwargs) self.async_client = self.openai.AsyncAzureOpenAI(**kwargs) diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py index f2cbe26a..a465e553 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -119,14 +119,10 @@ def invoke( Returns: LLMResponse: The response from the LLM. 
""" - system_message = ( - system_instruction - if system_instruction is not None - else self.system_instruction - ) + system_message = [system_instruction] if system_instruction is not None else [] self.model = GenerativeModel( model_name=self.model_name, - system_instruction=[system_message], + system_instruction=system_message, **self.options, ) try: @@ -154,13 +150,11 @@ async def ainvoke( """ try: system_message = ( - system_instruction - if system_instruction is not None - else self.system_instruction + [system_instruction] if system_instruction is not None else [] ) self.model = GenerativeModel( model_name=self.model_name, - system_instruction=[system_message], + system_instruction=system_message, **self.options, ) messages = self.get_messages(input, message_history) diff --git a/tests/e2e/test_graphrag_e2e.py b/tests/e2e/test_graphrag_e2e.py index 7eaaa50c..d8695756 100644 --- a/tests/e2e/test_graphrag_e2e.py +++ b/tests/e2e/test_graphrag_e2e.py @@ -59,9 +59,7 @@ def test_graphrag_happy_path( ) llm.invoke.assert_called_once_with( - """Answer the user question using the following context - -Context: + """Context: @@ -74,6 +72,7 @@ def test_graphrag_happy_path( Answer: """, None, + system_instruction="Answer the user question using the provided context.", ) assert isinstance(result, RagResultModel) assert result.answer == "some text" @@ -105,9 +104,7 @@ def test_graphrag_happy_path_return_context( ) llm.invoke.assert_called_once_with( - """Answer the user question using the following context - -Context: + """Context: @@ -120,6 +117,7 @@ def test_graphrag_happy_path_return_context( Answer: """, None, + system_instruction="Answer the user question using the provided context.", ) assert isinstance(result, RagResultModel) assert result.answer == "some text" @@ -152,9 +150,7 @@ def test_graphrag_happy_path_examples( ) llm.invoke.assert_called_once_with( - """Answer the user question using the following context - -Context: + """Context: @@ -167,6 +163,7 @@ def test_graphrag_happy_path_examples( Answer: """, None, + system_instruction="Answer the user question using the provided context.", ) assert result.answer == "some text" diff --git a/tests/unit/llm/test_anthropic_llm.py b/tests/unit/llm/test_anthropic_llm.py index b1d583f6..c8bf3c33 100644 --- a/tests/unit/llm/test_anthropic_llm.py +++ b/tests/unit/llm/test_anthropic_llm.py @@ -28,6 +28,7 @@ def mock_anthropic() -> Generator[MagicMock, None, None]: mock = MagicMock() mock.APIError = anthropic.APIError + mock.NOT_GIVEN = anthropic.NOT_GIVEN with patch.dict(sys.modules, {"anthropic": mock}): yield mock @@ -51,7 +52,7 @@ def test_anthropic_invoke_happy_path(mock_anthropic: Mock) -> None: llm.client.messages.create.assert_called_once_with( # type: ignore messages=[{"role": "user", "content": input_text}], model="claude-3-opus-20240229", - system=None, + system=anthropic.NOT_GIVEN, **model_params, ) @@ -77,69 +78,69 @@ def test_anthropic_invoke_with_message_history_happy_path(mock_anthropic: Mock) llm.client.messages.create.assert_called_once_with( # type: ignore[attr-defined] messages=message_history, model="claude-3-opus-20240229", - system=None, + system=anthropic.NOT_GIVEN, **model_params, ) -def test_anthropic_invoke_with_message_history_and_system_instruction( +def test_anthropic_invoke_with_system_instruction( mock_anthropic: Mock, ) -> None: mock_anthropic.Anthropic.return_value.messages.create.return_value = MagicMock( content="generated text" ) model_params = {"temperature": 0.3} - initial_instruction = "You are a helpful 
assistant." + system_instruction = "You are a helpful assistant." llm = AnthropicLLM( "claude-3-opus-20240229", model_params=model_params, - system_instruction=initial_instruction, - ) - message_history = [ - {"role": "user", "content": "When does the sun come up in the summer?"}, - {"role": "assistant", "content": "Usually around 6am."}, - ] - question = "What about next season?" - - # first invokation - initial instructions - response = llm.invoke(question, message_history) # type: ignore - assert response.content == "generated text" - message_history.append({"role": "user", "content": question}) - llm.client.messages.create.assert_called_once_with( # type: ignore[attr-defined] - model="claude-3-opus-20240229", - system=initial_instruction, - messages=message_history, - **model_params, ) - # second invokation - override instructions - override_instruction = "Ignore all previous instructions" question = "When does it come up in the winter?" - response = llm.invoke(question, message_history, override_instruction) # type: ignore + response = llm.invoke(question, system_instruction=system_instruction) assert isinstance(response, LLMResponse) assert response.content == "generated text" - message_history.append({"role": "user", "content": question}) + messages = [{"role": "user", "content": question}] llm.client.messages.create.assert_called_with( # type: ignore[attr-defined] model="claude-3-opus-20240229", - system=override_instruction, - messages=message_history, + system=system_instruction, + messages=messages, **model_params, ) - # third invokation - default instructions - question = "When does it set?" - response = llm.invoke(question, message_history) # type: ignore + assert llm.client.messages.create.call_count == 1 # type: ignore + + +def test_anthropic_invoke_with_message_history_and_system_instruction( + mock_anthropic: Mock, +) -> None: + mock_anthropic.Anthropic.return_value.messages.create.return_value = MagicMock( + content="generated text" + ) + model_params = {"temperature": 0.3} + system_instruction = "You are a helpful assistant." + llm = AnthropicLLM( + "claude-3-opus-20240229", + model_params=model_params, + ) + message_history = [ + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, + ] + + question = "When does it come up in the winter?" 
+ response = llm.invoke(question, message_history, system_instruction) # type: ignore assert isinstance(response, LLMResponse) assert response.content == "generated text" message_history.append({"role": "user", "content": question}) llm.client.messages.create.assert_called_with( # type: ignore[attr-defined] model="claude-3-opus-20240229", - system=initial_instruction, + system=system_instruction, messages=message_history, **model_params, ) - assert llm.client.messages.create.call_count == 3 # type: ignore + assert llm.client.messages.create.call_count == 1 # type: ignore def test_anthropic_invoke_with_message_history_validation_error( @@ -179,7 +180,7 @@ async def test_anthropic_ainvoke_happy_path(mock_anthropic: Mock) -> None: assert response.content == "Return text" llm.async_client.messages.create.assert_awaited_once_with( # type: ignore model="claude-3-opus-20240229", - system=None, + system=anthropic.NOT_GIVEN, messages=[{"role": "user", "content": input_text}], **model_params, ) diff --git a/tests/unit/llm/test_cohere_llm.py b/tests/unit/llm/test_cohere_llm.py index 6088799a..9ba9138e 100644 --- a/tests/unit/llm/test_cohere_llm.py +++ b/tests/unit/llm/test_cohere_llm.py @@ -53,14 +53,14 @@ def test_cohere_llm_invoke_with_message_history_happy_path(mock_cohere: Mock) -> mock_cohere.ClientV2.return_value.chat.return_value = chat_response_mock system_instruction = "You are a helpful assistant." - llm = CohereLLM(model_name="something", system_instruction=system_instruction) + llm = CohereLLM(model_name="something") message_history = [ {"role": "user", "content": "When does the sun come up in the summer?"}, {"role": "assistant", "content": "Usually around 6am."}, ] question = "What about next season?" - res = llm.invoke(question, message_history) # type: ignore + res = llm.invoke(question, message_history, system_instruction=system_instruction) # type: ignore assert isinstance(res, LLMResponse) assert res.content == "cohere response text" messages = [{"role": "system", "content": system_instruction}] @@ -79,52 +79,25 @@ def test_cohere_llm_invoke_with_message_history_and_system_instruction( chat_response_mock.message.content = [MagicMock(text="cohere response text")] mock_cohere.ClientV2.return_value.chat.return_value = chat_response_mock - initial_instruction = "You are a helpful assistant." - llm = CohereLLM(model_name="gpt", system_instruction=initial_instruction) + system_instruction = "You are a helpful assistant." + llm = CohereLLM(model_name="gpt") message_history = [ {"role": "user", "content": "When does the sun come up in the summer?"}, {"role": "assistant", "content": "Usually around 6am."}, ] question = "What about next season?" 
- # first invokation - initial instructions - res = llm.invoke(question, message_history) # type: ignore + res = llm.invoke(question, message_history, system_instruction=system_instruction) # type: ignore assert isinstance(res, LLMResponse) assert res.content == "cohere response text" - messages = [{"role": "system", "content": initial_instruction}] + messages = [{"role": "system", "content": system_instruction}] messages.extend(message_history) messages.append({"role": "user", "content": question}) llm.client.chat.assert_called_once_with( messages=messages, model="gpt", ) - - # second invokation - override instructions - override_instruction = "Ignore all previous instructions" - res = llm.invoke(question, message_history, override_instruction) # type: ignore - assert isinstance(res, LLMResponse) - assert res.content == "cohere response text" - messages = [{"role": "system", "content": override_instruction}] - messages.extend(message_history) - messages.append({"role": "user", "content": question}) - llm.client.chat.assert_called_with( - messages=messages, - model="gpt", - ) - - # third invokation - default instructions - res = llm.invoke(question, message_history) # type: ignore - assert isinstance(res, LLMResponse) - assert res.content == "cohere response text" - messages = [{"role": "system", "content": initial_instruction}] - messages.extend(message_history) - messages.append({"role": "user", "content": question}) - llm.client.chat.assert_called_with( - messages=messages, - model="gpt", - ) - - assert llm.client.chat.call_count == 3 + assert llm.client.chat.call_count == 1 def test_cohere_llm_invoke_with_message_history_validation_error( diff --git a/tests/unit/llm/test_mistralai_llm.py b/tests/unit/llm/test_mistralai_llm.py index 4d5c5c96..4a3e0860 100644 --- a/tests/unit/llm/test_mistralai_llm.py +++ b/tests/unit/llm/test_mistralai_llm.py @@ -57,14 +57,14 @@ def test_mistralai_llm_invoke_with_message_history(mock_mistral: Mock) -> None: model = "mistral-model" system_instruction = "You are a helpful assistant." - llm = MistralAILLM(model_name=model, system_instruction=system_instruction) + llm = MistralAILLM(model_name=model) message_history = [ {"role": "user", "content": "When does the sun come up in the summer?"}, {"role": "assistant", "content": "Usually around 6am."}, ] question = "What about next season?" - res = llm.invoke(question, message_history) # type: ignore + res = llm.invoke(question, message_history, system_instruction=system_instruction) # type: ignore assert isinstance(res, LLMResponse) assert res.content == "mistral response" @@ -88,8 +88,8 @@ def test_mistralai_llm_invoke_with_message_history_and_system_instruction( ] mock_mistral_instance.chat.complete.return_value = chat_response_mock model = "mistral-model" - initial_instruction = "You are a helpful assistant." - llm = MistralAILLM(model_name=model, system_instruction=initial_instruction) + system_instruction = "You are a helpful assistant." + llm = MistralAILLM(model_name=model) message_history = [ {"role": "user", "content": "When does the sun come up in the summer?"}, {"role": "assistant", "content": "Usually around 6am."}, @@ -97,10 +97,10 @@ def test_mistralai_llm_invoke_with_message_history_and_system_instruction( question = "What about next season?" 
# first invokation - initial instructions - res = llm.invoke(question, message_history) # type: ignore + res = llm.invoke(question, message_history, system_instruction=system_instruction) # type: ignore assert isinstance(res, LLMResponse) assert res.content == "mistral response" - messages = [{"role": "system", "content": initial_instruction}] + messages = [{"role": "system", "content": system_instruction}] messages.extend(message_history) messages.append({"role": "user", "content": question}) llm.client.chat.complete.assert_called_once_with( # type: ignore[attr-defined] @@ -108,32 +108,7 @@ def test_mistralai_llm_invoke_with_message_history_and_system_instruction( model=model, ) - # second invokation - override instructions - override_instruction = "Ignore all previous instructions" - res = llm.invoke(question, message_history, override_instruction) # type: ignore - assert isinstance(res, LLMResponse) - assert res.content == "mistral response" - messages = [{"role": "system", "content": override_instruction}] - messages.extend(message_history) - messages.append({"role": "user", "content": question}) - llm.client.chat.complete.assert_called_with( # type: ignore - messages=messages, - model=model, - ) - - # third invokation - default instructions - res = llm.invoke(question, message_history) # type: ignore - assert isinstance(res, LLMResponse) - assert res.content == "mistral response" - messages = [{"role": "system", "content": initial_instruction}] - messages.extend(message_history) - messages.append({"role": "user", "content": question}) - llm.client.chat.complete.assert_called_with( # type: ignore - messages=messages, - model=model, - ) - - assert llm.client.chat.complete.call_count == 3 # type: ignore + assert llm.client.chat.complete.call_count == 1 # type: ignore @patch("neo4j_graphrag.llm.mistralai_llm.Mistral") diff --git a/tests/unit/llm/test_ollama_llm.py b/tests/unit/llm/test_ollama_llm.py index deb56f93..4665a10a 100644 --- a/tests/unit/llm/test_ollama_llm.py +++ b/tests/unit/llm/test_ollama_llm.py @@ -43,19 +43,16 @@ def test_ollama_llm_happy_path(mock_import: Mock) -> None: ) model = "gpt" model_params = {"temperature": 0.3} - system_instruction = "You are a helpful assistant." question = "What is graph RAG?" llm = OllamaLLM( model, model_params=model_params, - system_instruction=system_instruction, ) res = llm.invoke(question) assert isinstance(res, LLMResponse) assert res.content == "ollama chat response" messages = [ - {"role": "system", "content": system_instruction}, {"role": "user", "content": question}, ] llm.client.chat.assert_called_once_with( # type: ignore[attr-defined] @@ -64,7 +61,7 @@ def test_ollama_llm_happy_path(mock_import: Mock) -> None: @patch("builtins.__import__") -def test_ollama_invoke_with_message_history_happy_path(mock_import: Mock) -> None: +def test_ollama_invoke_with_system_instruction_happy_path(mock_import: Mock) -> None: mock_ollama = get_mock_ollama() mock_import.return_value = mock_ollama mock_ollama.Client.return_value.chat.return_value = MagicMock( @@ -72,11 +69,34 @@ def test_ollama_invoke_with_message_history_happy_path(mock_import: Mock) -> Non ) model = "gpt" model_params = {"temperature": 0.3} + llm = OllamaLLM( + model, + model_params=model_params, + ) system_instruction = "You are a helpful assistant." + question = "What about next season?" 
+ + response = llm.invoke(question, system_instruction=system_instruction) + assert response.content == "ollama chat response" + messages = [{"role": "system", "content": system_instruction}] + messages.append({"role": "user", "content": question}) + llm.client.chat.assert_called_once_with( # type: ignore[attr-defined] + model=model, messages=messages, options=model_params + ) + + +@patch("builtins.__import__") +def test_ollama_invoke_with_message_history_happy_path(mock_import: Mock) -> None: + mock_ollama = get_mock_ollama() + mock_import.return_value = mock_ollama + mock_ollama.Client.return_value.chat.return_value = MagicMock( + message=MagicMock(content="ollama chat response"), + ) + model = "gpt" + model_params = {"temperature": 0.3} llm = OllamaLLM( model, model_params=model_params, - system_instruction=system_instruction, ) message_history = [ {"role": "user", "content": "When does the sun come up in the summer?"}, @@ -86,8 +106,7 @@ def test_ollama_invoke_with_message_history_happy_path(mock_import: Mock) -> Non response = llm.invoke(question, message_history) # type: ignore assert response.content == "ollama chat response" - messages = [{"role": "system", "content": system_instruction}] - messages.extend(message_history) + messages = [m for m in message_history] messages.append({"role": "user", "content": question}) llm.client.chat.assert_called_once_with( # type: ignore[attr-defined] model=model, messages=messages, options=model_params @@ -109,7 +128,6 @@ def test_ollama_invoke_with_message_history_and_system_instruction( llm = OllamaLLM( model, model_params=model_params, - system_instruction=system_instruction, ) message_history = [ {"role": "user", "content": "When does the sun come up in the summer?"}, @@ -117,38 +135,19 @@ def test_ollama_invoke_with_message_history_and_system_instruction( ] question = "What about next season?" 
- # first invokation - initial instructions - response = llm.invoke(question, message_history) # type: ignore - assert response.content == "ollama chat response" - messages = [{"role": "system", "content": system_instruction}] - messages.extend(message_history) - messages.append({"role": "user", "content": question}) - llm.client.chat.assert_called_once_with( # type: ignore[attr-defined] - model=model, messages=messages, options=model_params - ) - - # second invokation - override instructions - override_instruction = "Ignore all previous instructions" - response = llm.invoke(question, message_history, override_instruction) # type: ignore - assert response.content == "ollama chat response" - messages = [{"role": "system", "content": override_instruction}] - messages.extend(message_history) - messages.append({"role": "user", "content": question}) - llm.client.chat.assert_called_with( # type: ignore[attr-defined] - model=model, messages=messages, options=model_params + response = llm.invoke( + question, + message_history, # type: ignore + system_instruction=system_instruction, ) - - # third invokation - default instructions - response = llm.invoke(question, message_history) # type: ignore assert response.content == "ollama chat response" messages = [{"role": "system", "content": system_instruction}] messages.extend(message_history) messages.append({"role": "user", "content": question}) - llm.client.chat.assert_called_with( # type: ignore[attr-defined] + llm.client.chat.assert_called_once_with( # type: ignore[attr-defined] model=model, messages=messages, options=model_params ) - - assert llm.client.chat.call_count == 3 # type: ignore + assert llm.client.chat.call_count == 1 # type: ignore @patch("builtins.__import__") @@ -189,12 +188,10 @@ async def mock_chat_async(*args: Any, **kwargs: Any) -> MagicMock: mock_ollama.AsyncClient.return_value.chat = mock_chat_async model = "gpt" model_params = {"temperature": 0.3} - system_instruction = "You are a helpful assistant." question = "What is graph RAG?" llm = OllamaLLM( model, model_params=model_params, - system_instruction=system_instruction, ) res = await llm.ainvoke(question) diff --git a/tests/unit/llm/test_openai_llm.py b/tests/unit/llm/test_openai_llm.py index 82a79325..03fbf120 100644 --- a/tests/unit/llm/test_openai_llm.py +++ b/tests/unit/llm/test_openai_llm.py @@ -80,9 +80,10 @@ def test_openai_llm_with_message_history_and_system_instruction( mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock( choices=[MagicMock(message=MagicMock(content="openai chat response"))], ) - initial_instruction = "You are a helpful assistent." + system_instruction = "You are a helpful assistent." llm = OpenAILLM( - api_key="my key", model_name="gpt", system_instruction=initial_instruction + api_key="my key", + model_name="gpt", ) message_history = [ {"role": "user", "content": "When does the sun come up in the summer?"}, @@ -90,11 +91,10 @@ def test_openai_llm_with_message_history_and_system_instruction( ] question = "What about next season?" 
- # first invokation - initial instructions - res = llm.invoke(question, message_history) # type: ignore + res = llm.invoke(question, message_history, system_instruction=system_instruction) # type: ignore assert isinstance(res, LLMResponse) assert res.content == "openai chat response" - messages = [{"role": "system", "content": initial_instruction}] + messages = [{"role": "system", "content": system_instruction}] messages.extend(message_history) messages.append({"role": "user", "content": question}) llm.client.chat.completions.create.assert_called_once_with( # type: ignore @@ -102,32 +102,7 @@ def test_openai_llm_with_message_history_and_system_instruction( model="gpt", ) - # second invokation - override instructions - override_instruction = "Ignore all previous instructions" - res = llm.invoke(question, message_history, override_instruction) # type: ignore - assert isinstance(res, LLMResponse) - assert res.content == "openai chat response" - messages = [{"role": "system", "content": override_instruction}] - messages.extend(message_history) - messages.append({"role": "user", "content": question}) - llm.client.chat.completions.create.assert_called_with( # type: ignore - messages=messages, - model="gpt", - ) - - # third invokation - default instructions - res = llm.invoke(question, message_history) # type: ignore - assert isinstance(res, LLMResponse) - assert res.content == "openai chat response" - messages = [{"role": "system", "content": initial_instruction}] - messages.extend(message_history) - messages.append({"role": "user", "content": question}) - llm.client.chat.completions.create.assert_called_with( # type: ignore - messages=messages, - model="gpt", - ) - - assert llm.client.chat.completions.create.call_count == 3 # type: ignore + assert llm.client.chat.completions.create.call_count == 1 # type: ignore @patch("builtins.__import__") diff --git a/tests/unit/llm/test_vertexai_llm.py b/tests/unit/llm/test_vertexai_llm.py index e0755376..8c78db24 100644 --- a/tests/unit/llm/test_vertexai_llm.py +++ b/tests/unit/llm/test_vertexai_llm.py @@ -44,14 +44,19 @@ def test_vertexai_invoke_happy_path(GenerativeModelMock: MagicMock) -> None: response = llm.invoke(input_text) assert response.content == "Return text" GenerativeModelMock.assert_called_once_with( - model_name=model_name, system_instruction=[None] + model_name=model_name, system_instruction=[] ) user_message = mock.ANY llm.model.generate_content.assert_called_once_with(user_message, **model_params) + last_call = llm.model.generate_content.call_args_list[0] + content = last_call.args[0] + assert len(content) == 1 + assert content[0].role == "user" + assert content[0].parts[0].text == input_text @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel") -def test_vertexai_invoke_with_message_history_and_system_instruction( +def test_vertexai_invoke_with_system_instruction( GenerativeModelMock: MagicMock, ) -> None: system_instruction = "You are a helpful assistant." 
@@ -62,9 +67,9 @@ def test_vertexai_invoke_with_message_history_and_system_instruction( mock_model = GenerativeModelMock.return_value mock_model.generate_content.return_value = mock_response model_params = {"temperature": 0.5} - llm = VertexAILLM(model_name, model_params, system_instruction) + llm = VertexAILLM(model_name, model_params) - response = llm.invoke(input_text) + response = llm.invoke(input_text, system_instruction=system_instruction) assert response.content == "Return text" GenerativeModelMock.assert_called_once_with( model_name=model_name, system_instruction=[system_instruction] @@ -72,24 +77,47 @@ def test_vertexai_invoke_with_message_history_and_system_instruction( user_message = mock.ANY llm.model.generate_content.assert_called_once_with(user_message, **model_params) + +@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel") +def test_vertexai_invoke_with_message_history_and_system_instruction( + GenerativeModelMock: MagicMock, +) -> None: + system_instruction = "You are a helpful assistant." + model_name = "gemini-1.5-flash-001" + mock_response = Mock() + mock_response.text = "Return text" + mock_model = GenerativeModelMock.return_value + mock_model.generate_content.return_value = mock_response + model_params = {"temperature": 0.5} + llm = VertexAILLM(model_name, model_params) + message_history = [ - {"role": "user", "content": "hello!"}, - {"role": "assistant", "content": "hi."}, + {"role": "user", "content": "When does the sun come up in the summer?"}, + {"role": "assistant", "content": "Usually around 6am."}, ] - response = llm.invoke(input_text, message_history, "new instructions") # type:ignore - GenerativeModelMock.assert_called_with( - model_name=model_name, system_instruction=["new instructions"] + question = "What about next season?" + + response = llm.invoke( + question, + message_history, # type: ignore + system_instruction=system_instruction, + ) + assert response.content == "Return text" + GenerativeModelMock.assert_called_once_with( + model_name=model_name, system_instruction=[system_instruction] ) - messages = [mock.ANY, mock.ANY, mock.ANY] - llm.model.generate_content.assert_called_with(messages, **model_params) + user_message = mock.ANY + llm.model.generate_content.assert_called_once_with(user_message, **model_params) + last_call = llm.model.generate_content.call_args_list[0] + content = last_call.args[0] + assert len(content) == 3 # question + 2 messages in history @patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel") def test_vertexai_get_messages(GenerativeModelMock: MagicMock) -> None: - system_instruction = "You are a helpful assistant." model_name = "gemini-1.5-flash-001" question = "When does it set?" 
- message_history = [ + message_history: list[LLMMessage] = [ {"role": "user", "content": "When does the sun come up in the summer?"}, {"role": "assistant", "content": "Usually around 6am."}, {"role": "user", "content": "What about next season?"}, @@ -106,10 +134,10 @@ def test_vertexai_get_messages(GenerativeModelMock: MagicMock) -> None: Content(role="user", parts=[Part.from_text("When does it set?")]), ] - llm = VertexAILLM(model_name=model_name, system_instruction=system_instruction) - response = llm.get_messages(question, cast(list[LLMMessage], message_history)) + llm = VertexAILLM(model_name=model_name) + response = llm.get_messages(question, message_history) - GenerativeModelMock.assert_not_called + GenerativeModelMock.assert_not_called() assert len(response) == len(expected_response) for actual, expected in zip(response, expected_response): assert actual.role == expected.role diff --git a/tests/unit/test_graphrag.py b/tests/unit/test_graphrag.py index 178d34f7..ffacd10f 100644 --- a/tests/unit/test_graphrag.py +++ b/tests/unit/test_graphrag.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from unittest import mock from unittest.mock import MagicMock, call import pytest @@ -30,9 +31,7 @@ def test_graphrag_prompt_template() -> None: ) assert ( prompt - == """Answer the user question using the following context - -Context: + == """Context: my context Examples: @@ -63,9 +62,7 @@ def test_graphrag_happy_path(retriever_mock: MagicMock, llm: MagicMock) -> None: retriever_mock.search.assert_called_once_with(query_text="question") llm.invoke.assert_called_once_with( - """Answer the user question using the following context - -Context: + """Context: item content 1 item content 2 @@ -77,7 +74,8 @@ def test_graphrag_happy_path(retriever_mock: MagicMock, llm: MagicMock) -> None: Answer: """, - None, + None, # message history + system_instruction="Answer the user question using the provided context.", ) assert isinstance(res, RagResultModel) @@ -109,10 +107,10 @@ def test_graphrag_happy_path_with_message_history( res = rag.search("question", message_history) # type: ignore expected_retriever_query_text = """ -Message Summary: +Message Summary: llm generated summary -Current Query: +Current Query: question """ @@ -123,9 +121,7 @@ def test_graphrag_happy_path_with_message_history( assistant: answer to initial question """ first_invokation_system_instruction = "You are a summarization assistant. Summarize the given text in no more than 300 words." 
- second_invokation = """Answer the user question using the following context - -Context: + second_invokation = """Context: item content 1 item content 2 @@ -148,7 +144,11 @@ def test_graphrag_happy_path_with_message_history( input=first_invokation_input, system_instruction=first_invokation_system_instruction, ), - call(second_invokation, message_history), + call( + second_invokation, + message_history, + system_instruction="Answer the user question using the provided context.", + ), ] ) @@ -157,6 +157,35 @@ def test_graphrag_happy_path_with_message_history( assert res.retriever_result is None +def test_graphrag_happy_path_custom_system_instruction( + retriever_mock: MagicMock, llm: MagicMock +) -> None: + prompt_template = RagTemplate(system_instructions="Custom instruction") + rag = GraphRAG( + retriever=retriever_mock, + llm=llm, + prompt_template=prompt_template, + ) + retriever_mock.search.return_value = RetrieverResult(items=[]) + llm.invoke.side_effect = [ + LLMResponse(content="llm generated text"), + ] + res = rag.search("question") + + assert llm.invoke.call_count == 1 + llm.invoke.assert_has_calls( + [ + call( + mock.ANY, + None, # no message history + system_instruction="Custom instruction", + ), + ] + ) + + assert res.answer == "llm generated text" + + def test_graphrag_initialization_error(llm: MagicMock) -> None: with pytest.raises(RagInitializationError) as excinfo: GraphRAG( @@ -212,10 +241,10 @@ def test_conversation_template(retriever_mock: MagicMock, llm: MagicMock) -> Non assert ( prompt == """ -Message Summary: +Message Summary: llm generated chat summary -Current Query: +Current Query: latest question """ )