Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for conversations with message history #234

Merged
merged 36 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
ec65a48
Add system_instruction parameter
leila-messallem Dec 10, 2024
0537565
Add chat_history parameter
leila-messallem Dec 10, 2024
2e7e8bf
Add missing doc strings
leila-messallem Dec 10, 2024
3c19041
Open AI
leila-messallem Dec 10, 2024
226c08b
Add a summary of the chat history to the query embedding
leila-messallem Dec 11, 2024
6d101dd
Anthropic
leila-messallem Dec 11, 2024
d8f3948
Change return type of Anthropic get_messages()
leila-messallem Dec 11, 2024
b8910df
Cohere
leila-messallem Dec 12, 2024
72f4de5
Mistral
leila-messallem Dec 12, 2024
5720a4b
VertexAI
leila-messallem Dec 12, 2024
615cea6
Merge branch 'main' into chat-history
leila-messallem Dec 13, 2024
597eff1
Formatting
leila-messallem Dec 13, 2024
f2792ff
Merge branch 'chat-history' of github.com:leila-messallem/neo4j-graph…
leila-messallem Dec 13, 2024
6288907
Fix mypy errors
leila-messallem Dec 13, 2024
a362fd3
Ollama
leila-messallem Dec 13, 2024
5bb56f6
Override of the system message
leila-messallem Dec 16, 2024
6aea7fa
Use TYPE_CHECKING for dev dependencies
leila-messallem Dec 16, 2024
07038dd
Formatting
leila-messallem Dec 16, 2024
37225fd
Rename `chat_history` to `message_history`
leila-messallem Dec 16, 2024
abef33c
Use BaseMessage class type
leila-messallem Dec 16, 2024
d7df9e8
System instruction override
leila-messallem Dec 16, 2024
a749a9e
Merge branch 'main' into chat-history
leila-messallem Dec 16, 2024
819179e
Revert BaseMessage class type
leila-messallem Dec 17, 2024
2143973
Fix mypy errors
leila-messallem Dec 17, 2024
775447f
Update tests
leila-messallem Dec 17, 2024
17db6b1
Fix ollama NameError
leila-messallem Dec 17, 2024
3c55d3f
Fix NameError in unit tests
leila-messallem Dec 18, 2024
d5a287b
Add TypeDict `LLMMessage`
leila-messallem Dec 18, 2024
bd34e1a
Simplify the retriever prompt
leila-messallem Dec 18, 2024
23a8001
Fix E2E tests
leila-messallem Dec 18, 2024
fa12a9f
Unit tests for the system instruction override
leila-messallem Dec 19, 2024
f5a9833
Move and rename the prompts
leila-messallem Dec 20, 2024
81f7ff4
Update changelog
leila-messallem Dec 20, 2024
a15a514
Add missing parameter in example
leila-messallem Dec 20, 2024
7557b07
Add LLMMessage to the docs
leila-messallem Dec 20, 2024
717be1c
Update docs README
leila-messallem Dec 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@

## Next

### Added
- Support for conversations with message history, including a new `message_history` parameter for LLM interactions.
- Ability to include system instructions and override them for specific invocations.
- Summarization of chat history to enhance query embedding and context handling.

### Changed
- Updated LLM implementations to handle message history consistently across providers.

## 1.3.0

### Added
Expand Down
17 changes: 14 additions & 3 deletions examples/customize/llms/custom_llm.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,32 @@
import random
import string
from typing import Any
from typing import Any, Optional

from neo4j_graphrag.llm import LLMInterface, LLMResponse
from neo4j_graphrag.llm.types import LLMMessage


class CustomLLM(LLMInterface):
def __init__(self, model_name: str, **kwargs: Any):
super().__init__(model_name, **kwargs)

def invoke(self, input: str) -> LLMResponse:
def invoke(
stellasia marked this conversation as resolved.
Show resolved Hide resolved
self,
input: str,
message_history: Optional[list[LLMMessage]] = None,
system_instruction: Optional[str] = None,
) -> LLMResponse:
content: str = (
self.model_name + ": " + "".join(random.choices(string.ascii_letters, k=30))
)
return LLMResponse(content=content)

async def ainvoke(self, input: str) -> LLMResponse:
async def ainvoke(
self,
input: str,
message_history: Optional[list[LLMMessage]] = None,
system_instruction: Optional[str] = None,
) -> LLMResponse:
raise NotImplementedError()


Expand Down
53 changes: 46 additions & 7 deletions src/neo4j_graphrag/generation/graphrag.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from neo4j_graphrag.generation.prompts import RagTemplate
from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
from neo4j_graphrag.llm import LLMInterface
from neo4j_graphrag.llm.types import LLMMessage
from neo4j_graphrag.retrievers.base import Retriever
from neo4j_graphrag.types import RetrieverResult

Expand Down Expand Up @@ -83,6 +84,7 @@ def __init__(
def search(
self,
query_text: str = "",
message_history: Optional[list[LLMMessage]] = None,
examples: str = "",
retriever_config: Optional[dict[str, Any]] = None,
return_context: bool | None = None,
Expand All @@ -99,14 +101,15 @@ def search(


Args:
query_text (str): The user question
query_text (str): The user question.
message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
examples (str): Examples added to the LLM prompt.
retriever_config (Optional[dict]): Parameters passed to the retriever
retriever_config (Optional[dict]): Parameters passed to the retriever
search method; e.g.: top_k.
return_context (bool): Whether to append the retriever result to the final result (default: False)
return_context (bool): Whether to append the retriever result to the final result (default: False).

Returns:
RagResultModel: The LLM-generated answer
RagResultModel: The LLM-generated answer.

"""
if return_context is None:
Expand All @@ -124,18 +127,54 @@ def search(
)
except ValidationError as e:
raise SearchValidationError(e.errors())
query_text = validated_data.query_text
query = self.build_query(validated_data.query_text, message_history)
retriever_result: RetrieverResult = self.retriever.search(
query_text=query_text, **validated_data.retriever_config
query_text=query, **validated_data.retriever_config
)
context = "\n".join(item.content for item in retriever_result.items)
prompt = self.prompt_template.format(
query_text=query_text, context=context, examples=validated_data.examples
)
logger.debug(f"RAG: retriever_result={retriever_result}")
logger.debug(f"RAG: prompt={prompt}")
answer = self.llm.invoke(prompt)
answer = self.llm.invoke(prompt, message_history)
result: dict[str, Any] = {"answer": answer.content}
if return_context:
result["retriever_result"] = retriever_result
return RagResultModel(**result)

def build_query(
self, query_text: str, message_history: Optional[list[LLMMessage]] = None
) -> str:
summary_system_message = "You are a summarization assistant. Summarize the given text in no more than 300 words."
if message_history:
summarization_prompt = self.chat_summary_prompt(
message_history=message_history
)
summary = self.llm.invoke(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if we should allow the user to use a different LLM for summarization. I'm thinking users might want to use a "small" LLM for this simple task, and use a "better" one for the Q&A part. But we can leave it for a later improvement.

input=summarization_prompt,
system_instruction=summary_system_message,
).content
return self.conversation_prompt(summary=summary, current_query=query_text)
return query_text

def chat_summary_prompt(self, message_history: list[LLMMessage]) -> str:
    """Build the prompt asking the LLM to summarize the message history.

    Each message is rendered on its own line as ``<role>: <content>``.
    The two fields are read by key instead of joining the dict's values in
    iteration order, so the rendering no longer depends on the order in
    which the caller inserted the keys and ignores any unrelated extra key.

    Args:
        message_history (list[LLMMessage]): Previous messages; each one must
            provide the "role" and "content" keys.

    Returns:
        str: The summarization prompt containing the rendered history.

    Raises:
        KeyError: If a message lacks the "role" or "content" key.
    """
    history = "\n".join(
        f"{message['role']}: {message['content']}" for message in message_history
    )
    return f"""
Summarize the message history:

{history}
"""

def conversation_prompt(self, summary: str, current_query: str) -> str:
    """Combine a message-history summary and the current query into one text.

    Args:
        summary (str): Summary of the previous messages.
        current_query (str): The question currently being asked.

    Returns:
        str: A text with a "Message Summary" section followed by a
        "Current Query" section.
    """
    prompt = f"""
Message Summary:
{summary}

Current Query:
{current_query}
"""
    return prompt
82 changes: 63 additions & 19 deletions src/neo4j_graphrag/llm/anthropic_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,22 @@
# limitations under the License.
from __future__ import annotations

from typing import Any, Optional
from typing import Any, Iterable, Optional, TYPE_CHECKING, cast

from pydantic import ValidationError

from neo4j_graphrag.exceptions import LLMGenerationError
from neo4j_graphrag.llm.base import LLMInterface
from neo4j_graphrag.llm.types import LLMResponse
from neo4j_graphrag.llm.types import (
BaseMessage,
LLMMessage,
LLMResponse,
MessageList,
UserMessage,
)

if TYPE_CHECKING:
from anthropic.types.message_param import MessageParam


class AnthropicLLM(LLMInterface):
Expand All @@ -26,6 +37,7 @@ class AnthropicLLM(LLMInterface):
Args:
model_name (str): Name of the LLM to use.
model_params (Optional[dict], optional): Additional parameters passed to the model when text is sent to it. Defaults to None.
system_instruction (Optional[str], optional): Additional instructions for setting the behavior and context for the model in a conversation. Defaults to None.
**kwargs (Any): Arguments passed to the model when the class is initialised. Defaults to None.

Raises:
Expand All @@ -49,6 +61,7 @@ def __init__(
self,
model_name: str,
model_params: Optional[dict[str, Any]] = None,
system_instruction: Optional[str] = None,
**kwargs: Any,
):
try:
Expand All @@ -58,55 +71,86 @@ def __init__(
"""Could not import Anthropic Python client.
Please install it with `pip install "neo4j-graphrag[anthropic]"`."""
)
super().__init__(model_name, model_params)
super().__init__(model_name, model_params, system_instruction)
self.anthropic = anthropic
self.client = anthropic.Anthropic(**kwargs)
self.async_client = anthropic.AsyncAnthropic(**kwargs)

def invoke(self, input: str) -> LLMResponse:
def get_messages(
self, input: str, message_history: Optional[list[LLMMessage]] = None
) -> Iterable[MessageParam]:
messages: list[dict[str, str]] = []
if message_history:
try:
MessageList(messages=cast(list[BaseMessage], message_history))
except ValidationError as e:
raise LLMGenerationError(e.errors()) from e
stellasia marked this conversation as resolved.
Show resolved Hide resolved
messages.extend(cast(Iterable[dict[str, Any]], message_history))
messages.append(UserMessage(content=input).model_dump())
return messages # type: ignore

def invoke(
self,
input: str,
message_history: Optional[list[LLMMessage]] = None,
system_instruction: Optional[str] = None,
) -> LLMResponse:
"""Sends text to the LLM and returns a response.

Args:
input (str): The text to send to the LLM.
message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
system_instruction (Optional[str]): An option to override the LLM system message for this invocation.

Returns:
LLMResponse: The response from the LLM.
"""
try:
messages = self.get_messages(input, message_history)
system_message = (
system_instruction
if system_instruction is not None
else self.system_instruction
)
response = self.client.messages.create(
model=self.model_name,
messages=[
{
"role": "user",
"content": input,
}
],
system=system_message, # type: ignore
messages=messages,
**self.model_params,
)
return LLMResponse(content=response.content)
return LLMResponse(content=response.content) # type: ignore
except self.anthropic.APIError as e:
raise LLMGenerationError(e)

async def ainvoke(self, input: str) -> LLMResponse:
async def ainvoke(
self,
input: str,
message_history: Optional[list[LLMMessage]] = None,
system_instruction: Optional[str] = None,
) -> LLMResponse:
"""Asynchronously sends text to the LLM and returns a response.

Args:
input (str): The text to send to the LLM.
message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
system_instruction (Optional[str]): An option to override the LLM system message for this invocation.

Returns:
LLMResponse: The response from the LLM.
"""
try:
messages = self.get_messages(input, message_history)
system_message = (
system_instruction
if system_instruction is not None
else self.system_instruction
)
response = await self.async_client.messages.create(
model=self.model_name,
messages=[
{
"role": "user",
"content": input,
}
],
system=system_message, # type: ignore
messages=messages,
**self.model_params,
)
return LLMResponse(content=response.content)
return LLMResponse(content=response.content) # type: ignore
except self.anthropic.APIError as e:
raise LLMGenerationError(e)
27 changes: 22 additions & 5 deletions src/neo4j_graphrag/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from abc import ABC, abstractmethod
from typing import Any, Optional

from .types import LLMResponse
from .types import LLMMessage, LLMResponse
stellasia marked this conversation as resolved.
Show resolved Hide resolved


class LLMInterface(ABC):
Expand All @@ -26,24 +26,34 @@ class LLMInterface(ABC):
Args:
model_name (str): The name of the language model.
model_params (Optional[dict], optional): Additional parameters passed to the model when text is sent to it. Defaults to None.
system_instruction (Optional[str], optional): Additional instructions for setting the behavior and context for the model in a conversation. Defaults to None.
**kwargs (Any): Arguments passed to the model when the class is initialised. Defaults to None.
"""

def __init__(
self,
model_name: str,
model_params: Optional[dict[str, Any]] = None,
system_instruction: Optional[str] = None,
**kwargs: Any,
):
self.model_name = model_name
self.model_params = model_params or {}
self.system_instruction = system_instruction

@abstractmethod
def invoke(self, input: str) -> LLMResponse:
def invoke(
self,
input: str,
message_history: Optional[list[LLMMessage]] = None,
system_instruction: Optional[str] = None,
) -> LLMResponse:
"""Sends a text input to the LLM and retrieves a response.

Args:
input (str): Text sent to the LLM
input (str): Text sent to the LLM.
message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
system_instruction (Optional[str]): An option to override the LLM system message for this invocation.

Returns:
LLMResponse: The response from the LLM.
Expand All @@ -53,11 +63,18 @@ def invoke(self, input: str) -> LLMResponse:
"""

@abstractmethod
async def ainvoke(self, input: str) -> LLMResponse:
async def ainvoke(
self,
input: str,
message_history: Optional[list[LLMMessage]] = None,
system_instruction: Optional[str] = None,
) -> LLMResponse:
"""Asynchronously sends a text input to the LLM and retrieves a response.

Args:
input (str): Text sent to the LLM
input (str): Text sent to the LLM.
message_history (Optional[list]): A collection of previous messages, with each message having a specific role assigned.
system_instruction (Optional[str]): An option to override the LLM system message for this invocation.

Returns:
LLMResponse: The response from the LLM.
Expand Down
Loading
Loading