Support for conversations with message history #234

Merged: 36 commits, Dec 20, 2024
Commits (36, all by leila-messallem):
ec65a48  Add system_instruction parameter (Dec 10, 2024)
0537565  Add chat_history parameter (Dec 10, 2024)
2e7e8bf  Add missing doc strings (Dec 10, 2024)
3c19041  Open AI (Dec 10, 2024)
226c08b  Add a summary of the chat history to the query embedding (Dec 11, 2024)
6d101dd  Anthropic (Dec 11, 2024)
d8f3948  Change return type of Anthropic get_messages() (Dec 11, 2024)
b8910df  Cohere (Dec 12, 2024)
72f4de5  Mistral (Dec 12, 2024)
5720a4b  VertexAI (Dec 12, 2024)
615cea6  Merge branch 'main' into chat-history (Dec 13, 2024)
597eff1  Formatting (Dec 13, 2024)
f2792ff  Merge branch 'chat-history' of github.com:leila-messallem/neo4j-graph… (Dec 13, 2024)
6288907  Fix mypy errors (Dec 13, 2024)
a362fd3  Ollama (Dec 13, 2024)
5bb56f6  Override of the system message (Dec 16, 2024)
6aea7fa  Use TYPE_CHECKING for dev dependencies (Dec 16, 2024)
07038dd  Formatting (Dec 16, 2024)
37225fd  Rename `chat_history` to `message_history` (Dec 16, 2024)
abef33c  Use BaseMessage class type (Dec 16, 2024)
d7df9e8  System instruction override (Dec 16, 2024)
a749a9e  Merge branch 'main' into chat-history (Dec 16, 2024)
819179e  Revert BaseMessage class type (Dec 17, 2024)
2143973  Fix mypy errors (Dec 17, 2024)
775447f  Update tests (Dec 17, 2024)
17db6b1  Fix ollama NameError (Dec 17, 2024)
3c55d3f  Fix NameError in unit tests (Dec 18, 2024)
d5a287b  Add TypeDict `LLMMessage` (Dec 18, 2024)
bd34e1a  Simplify the retriever prompt (Dec 18, 2024)
23a8001  Fix E2E tests (Dec 18, 2024)
fa12a9f  Unit tests for the system instruction override (Dec 19, 2024)
f5a9833  Move and rename the prompts (Dec 20, 2024)
81f7ff4  Update changelog (Dec 20, 2024)
a15a514  Add missing parameter in example (Dec 20, 2024)
7557b07  Add LLMMessage to the docs (Dec 20, 2024)
717be1c  Update docs README (Dec 20, 2024)
35 changes: 27 additions & 8 deletions src/neo4j_graphrag/generation/graphrag.py
@@ -24,7 +24,11 @@
     RagInitializationError,
     SearchValidationError,
 )
-from neo4j_graphrag.generation.prompts import RagTemplate
+from neo4j_graphrag.generation.prompts import (
+    RagTemplate,
+    ChatSummaryTemplate,
+    ConversationTemplate,
+)
 from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
 from neo4j_graphrag.llm import LLMInterface
 from neo4j_graphrag.retrievers.base import Retriever
@@ -83,6 +87,7 @@ def __init__(
     def search(
         self,
         query_text: str = "",
+        chat_history: Optional[list[dict[str, str]]] = None,
         examples: str = "",
         retriever_config: Optional[dict[str, Any]] = None,
         return_context: bool | None = None,
@@ -99,14 +104,15 @@ def search(


         Args:
-            query_text (str): The user question
+            query_text (str): The user question.
+            chat_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
             examples (str): Examples added to the LLM prompt.
-            retriever_config (Optional[dict]): Parameters passed to the retriever
+            retriever_config (Optional[dict]): Parameters passed to the retriever.
                 search method; e.g.: top_k
-            return_context (bool): Whether to append the retriever result to the final result (default: False)
+            return_context (bool): Whether to append the retriever result to the final result (default: False).
 
         Returns:
-            RagResultModel: The LLM-generated answer
+            RagResultModel: The LLM-generated answer.
 
         """
         if return_context is None:
@@ -124,18 +130,31 @@
             )
         except ValidationError as e:
             raise SearchValidationError(e.errors())
-        query_text = validated_data.query_text
+        query = self.build_query(validated_data.query_text, chat_history)
         retriever_result: RetrieverResult = self.retriever.search(
-            query_text=query_text, **validated_data.retriever_config
+            query_text=query, **validated_data.retriever_config
         )
         context = "\n".join(item.content for item in retriever_result.items)
         prompt = self.prompt_template.format(
             query_text=query_text, context=context, examples=validated_data.examples
         )
         logger.debug(f"RAG: retriever_result={retriever_result}")
         logger.debug(f"RAG: prompt={prompt}")
-        answer = self.llm.invoke(prompt)
+        answer = self.llm.invoke(prompt, chat_history)
         result: dict[str, Any] = {"answer": answer.content}
         if return_context:
             result["retriever_result"] = retriever_result
         return RagResultModel(**result)
 
+    def build_query(
+        self, query_text: str, chat_history: Optional[list[dict[str, str]]] = None
+    ) -> str:
+        if chat_history:
+            summarization_prompt = ChatSummaryTemplate().format(
+                chat_history=chat_history
+            )
+            summary = self.llm.invoke(summarization_prompt).content
+            return ConversationTemplate().format(
+                summary=summary, current_query=query_text
+            )
+        return query_text
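
For orientation, here is a minimal usage sketch of the new parameter. The retriever and LLM construction are elided (`retriever` and `llm` stand for instances of the library's existing retriever and LLM classes, built elsewhere), and the history format is the role/content dicts shown above:

# Usage sketch: `retriever` and `llm` are assumed to be pre-built
# instances of the library's retriever/LLM classes.
from neo4j_graphrag.generation import GraphRAG

rag = GraphRAG(retriever=retriever, llm=llm)

chat_history = [
    {"role": "user", "content": "Who directed the movie Alien?"},
    {"role": "assistant", "content": "Ridley Scott directed Alien (1979)."},
]

# search() first calls build_query(): the history is summarized by the
# LLM and combined with the current question to form the retrieval
# query; the history is also passed to the LLM for the final answer.
result = rag.search(
    query_text="What else did he direct?",
    chat_history=chat_history,
    return_context=True,
)
print(result.answer)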
31 changes: 31 additions & 0 deletions src/neo4j_graphrag/generation/prompts.py
@@ -196,3 +196,34 @@ def format(
         text: str = "",
     ) -> str:
         return super().format(text=text, schema=schema, examples=examples)
+
+
+class ChatSummaryTemplate(PromptTemplate):
+    DEFAULT_TEMPLATE = """
+Summarize the chat history:
+
+{chat_history}
+"""
+    EXPECTED_INPUTS = ["chat_history"]
+
+    def format(self, chat_history: list[dict[str, str]]) -> str:
+        message_list = [
+            ": ".join([f"{value}" for _, value in message.items()])
+            for message in chat_history
+        ]
+        history = "\n".join(message_list)
+        return super().format(chat_history=history)
+
+
+class ConversationTemplate(PromptTemplate):
+    DEFAULT_TEMPLATE = """
+Chat Summary:
+{summary}
+
+Current Query:
+{current_query}
+"""
+    EXPECTED_INPUTS = ["summary", "current_query"]
+
+    def format(self, summary: str, current_query: str) -> str:
+        return super().format(summary=summary, current_query=current_query)
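
To illustrate what these templates produce, a short example using only the classes defined above (expected output shown in comments):

from neo4j_graphrag.generation.prompts import (
    ChatSummaryTemplate,
    ConversationTemplate,
)

history = [
    {"role": "user", "content": "Who directed the movie Alien?"},
    {"role": "assistant", "content": "Ridley Scott."},
]

# format() flattens each message dict to "<role>: <content>" and joins
# the lines with newlines before substituting into the template:
print(ChatSummaryTemplate().format(chat_history=history))
# Summarize the chat history:
#
# user: Who directed the movie Alien?
# assistant: Ridley Scott.

print(
    ConversationTemplate().format(
        summary="The user asked who directed Alien; the answer was Ridley Scott.",
        current_query="What else did he direct?",
    )
)
# Chat Summary:
# The user asked who directed Alien; the answer was Ridley Scott.
#
# Current Query:
# What else did he direct?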
61 changes: 41 additions & 20 deletions src/neo4j_graphrag/llm/anthropic_llm.py
@@ -13,11 +13,19 @@
 # limitations under the License.
 from __future__ import annotations
 
-from typing import Any, Optional
+from typing import Any, Iterable, Optional
 
+from pydantic import ValidationError
+
 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.base import LLMInterface
-from neo4j_graphrag.llm.types import LLMResponse
+from neo4j_graphrag.llm.types import LLMResponse, MessageList, UserMessage
+
+try:
+    import anthropic
+    from anthropic.types.message_param import MessageParam
+except ImportError:
+    anthropic = None
 
 
 class AnthropicLLM(LLMInterface):
@@ -26,6 +34,7 @@ class AnthropicLLM(LLMInterface):
     Args:
         model_name (str): Name of the LLM to use.
         model_params (Optional[dict], optional): Additional parameters passed to the model when text is sent to it. Defaults to None.
+        system_instruction (Optional[str], optional): Additional instructions for setting the behavior and context for the model in a conversation. Defaults to None.
         **kwargs (Any): Arguments passed to the model when the class is initialised. Defaults to None.
 
     Raises:
@@ -49,62 +58,74 @@
         self,
         model_name: str,
         model_params: Optional[dict[str, Any]] = None,
+        system_instruction: Optional[str] = None,
         **kwargs: Any,
     ):
-        try:
-            import anthropic
-        except ImportError:
+        if anthropic is None:
             raise ImportError(
                 """Could not import Anthropic Python client.
                 Please install it with `pip install "neo4j-graphrag[anthropic]"`."""
             )
-        super().__init__(model_name, model_params)
+        super().__init__(model_name, model_params, system_instruction)
         self.anthropic = anthropic
         self.client = anthropic.Anthropic(**kwargs)
         self.async_client = anthropic.AsyncAnthropic(**kwargs)
 
-    def invoke(self, input: str) -> LLMResponse:
+    def get_messages(
+        self, input: str, chat_history: Optional[list[Any]] = None
+    ) -> Iterable[MessageParam]:
+        messages = []
+        if chat_history:
+            try:
+                MessageList(messages=chat_history)
+            except ValidationError as e:
+                raise LLMGenerationError(e.errors()) from e
+            messages.extend(chat_history)
+        messages.append(UserMessage(content=input).model_dump())
+        return messages
+
+    def invoke(
+        self, input: str, chat_history: Optional[list[Any]] = None
+    ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
 
         Args:
             input (str): The text to send to the LLM.
+            chat_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
 
         Returns:
             LLMResponse: The response from the LLM.
         """
         try:
+            messages = self.get_messages(input, chat_history)
             response = self.client.messages.create(
                 model=self.model_name,
-                messages=[
-                    {
-                        "role": "user",
-                        "content": input,
-                    }
-                ],
+                system=self.system_instruction,
+                messages=messages,
                 **self.model_params,
             )
             return LLMResponse(content=response.content)
         except self.anthropic.APIError as e:
             raise LLMGenerationError(e)
 
-    async def ainvoke(self, input: str) -> LLMResponse:
+    async def ainvoke(
+        self, input: str, chat_history: Optional[list[Any]] = None
+    ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
 
         Args:
             input (str): The text to send to the LLM.
+            chat_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
 
         Returns:
             LLMResponse: The response from the LLM.
         """
         try:
+            messages = self.get_messages(input, chat_history)
             response = await self.async_client.messages.create(
                 model=self.model_name,
-                messages=[
-                    {
-                        "role": "user",
-                        "content": input,
-                    }
-                ],
+                system=self.system_instruction,
+                messages=messages,
                 **self.model_params,
             )
             return LLMResponse(content=response.content)
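
A short sketch of the extended wrapper in use; the model name and token limit are illustrative, and an ANTHROPIC_API_KEY is assumed in the environment:

from neo4j_graphrag.llm import AnthropicLLM

llm = AnthropicLLM(
    model_name="claude-3-5-sonnet-20240620",  # illustrative model name
    model_params={"max_tokens": 1000},
    system_instruction="Answer concisely.",
)

history = [
    {"role": "user", "content": "Who directed the movie Alien?"},
    {"role": "assistant", "content": "Ridley Scott."},
]

# get_messages() validates the history against MessageList and appends
# the new input as a user message; the system instruction is sent via
# the separate `system` field of messages.create().
response = llm.invoke("What else did he direct?", history)
print(response.content)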
18 changes: 14 additions & 4 deletions src/neo4j_graphrag/llm/base.py
@@ -26,24 +26,30 @@ class LLMInterface(ABC):
     Args:
         model_name (str): The name of the language model.
         model_params (Optional[dict], optional): Additional parameters passed to the model when text is sent to it. Defaults to None.
+        system_instruction (Optional[str], optional): Additional instructions for setting the behavior and context for the model in a conversation. Defaults to None.
         **kwargs (Any): Arguments passed to the model when the class is initialised. Defaults to None.
     """
 
     def __init__(
         self,
         model_name: str,
         model_params: Optional[dict[str, Any]] = None,
+        system_instruction: Optional[str] = None,
         **kwargs: Any,
     ):
         self.model_name = model_name
         self.model_params = model_params or {}
+        self.system_instruction = system_instruction
 
     @abstractmethod
-    def invoke(self, input: str) -> LLMResponse:
+    def invoke(
+        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
+    ) -> LLMResponse:
         """Sends a text input to the LLM and retrieves a response.
 
         Args:
-            input (str): Text sent to the LLM
+            input (str): Text sent to the LLM.
+            chat_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
 
         Returns:
             LLMResponse: The response from the LLM.
@@ -53,11 +59,15 @@ def invoke(self, input: str) -> LLMResponse:
"""

@abstractmethod
async def ainvoke(self, input: str) -> LLMResponse:
async def ainvoke(
self, input: str, chat_history: Optional[list[dict[str, str]]] = None
) -> LLMResponse:
"""Asynchronously sends a text input to the LLM and retrieves a response.

Args:
input (str): Text sent to the LLM
input (str): Text sent to the LLM.
chat_history (Optional[list]): A collection previous messages, with each message having a specific role assigned.


Returns:
LLMResponse: The response from the LLM.
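
To make the widened contract concrete, here is a toy subclass (purely illustrative, with no real model behind it) that satisfies the new signatures:

from typing import Optional

from neo4j_graphrag.llm.base import LLMInterface
from neo4j_graphrag.llm.types import LLMResponse


class EchoLLM(LLMInterface):
    """Toy implementation: echoes the input and counts prior messages."""

    def invoke(
        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
    ) -> LLMResponse:
        turns = len(chat_history) if chat_history else 0
        return LLMResponse(content=f"[{turns} prior messages] {input}")

    async def ainvoke(
        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
    ) -> LLMResponse:
        return self.invoke(input, chat_history)


llm = EchoLLM(model_name="echo", system_instruction="unused by this toy")
print(llm.invoke("hello", [{"role": "user", "content": "hi"}]).content)
# [1 prior messages] hello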