Chat with message history (draft) #225

Closed
14 changes: 0 additions & 14 deletions .pre-commit-config.yaml
@@ -18,17 +18,3 @@ repos:
language: system
types: [ python ]
stages: [ commit, push ]
- id: mypy
@leila-messallem (Contributor, Author) commented on Dec 6, 2024:

Sorry about this, but I couldn't be bothered with mypy just for opening a draft.

Contributor replied:

It's a good point, actually; maybe we should disable this check for draft PRs.

name: Mypy Type Check
entry: mypy .
language: system
types: [ python ]
stages: [ commit, push ]
pass_filenames: false
args: [
--strict,
--ignore-missing-imports,
--allow-untyped-calls,
--allow-subclassing-any,
--exclude='./docs/'
]
7 changes: 6 additions & 1 deletion src/neo4j_graphrag/generation/graphrag.py
@@ -83,6 +83,7 @@ def __init__(
def search(
self,
query_text: str = "",
chat_history: Optional[list[str]] = None,
examples: str = "",
retriever_config: Optional[dict[str, Any]] = None,
return_context: bool | None = None,
@@ -100,6 +101,7 @@ def search(
Args:
query_text (str): The user question
chat_history (Optional[list[str]]): A list of previous messages in the conversation
examples (str): Examples added to the LLM prompt.
retriever_config (Optional[dict]): Parameters passed to the retriever
search method; e.g.: top_k
@@ -134,7 +136,10 @@ def search(
)
logger.debug(f"RAG: retriever_result={retriever_result}")
logger.debug(f"RAG: prompt={prompt}")
answer = self.llm.invoke(prompt)
if chat_history is not None:
@leila-messallem (Contributor, Author) commented:

I know this part isn't great. The idea I want to demonstrate is to have two different flows: chat with message history, and single invocation.

answer = self.llm.chat(prompt, chat_history)
else:
answer = self.llm.invoke(prompt)
result: dict[str, Any] = {"answer": answer.content}
if return_context:
result["retriever_result"] = retriever_result
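A minimal usage sketch of the draft flow above (not part of the diff): my_retriever and my_llm are placeholders, and the flat list of strings follows the history format this draft currently assumes.

# Hypothetical usage of the new chat_history parameter; my_retriever and
# my_llm are placeholders that are not defined in this PR.
from neo4j_graphrag.generation import GraphRAG

rag = GraphRAG(retriever=my_retriever, llm=my_llm)

chat_history = [
    "When does the sun come up in the summer?",  # user turn
    "Usually around 6am.",                       # assistant turn
]

result = rag.search(
    query_text="And when does it set?",
    chat_history=chat_history,  # routes the call through llm.chat(...) instead of llm.invoke(...)
    retriever_config={"top_k": 3},
)
print(result.answer)  # assuming search keeps returning a result object exposing the answer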
3 changes: 3 additions & 0 deletions src/neo4j_graphrag/llm/anthropic_llm.py
@@ -110,3 +110,6 @@ async def ainvoke(self, input: str) -> LLMResponse:
return LLMResponse(content=response.content)
except self.anthropic.APIError as e:
raise LLMGenerationError(e)

def chat(self, input: str, chat_history: list[str]) -> LLMResponse:
pass
17 changes: 17 additions & 0 deletions src/neo4j_graphrag/llm/base.py
@@ -33,10 +33,12 @@ def __init__(
self,
model_name: str,
model_params: Optional[dict[str, Any]] = None,
system_instruction: Optional[str] = None,
**kwargs: Any,
):
self.model_name = model_name
self.model_params = model_params or {}
self.system_instruction = system_instruction

@abstractmethod
def invoke(self, input: str) -> LLMResponse:
@@ -52,6 +54,21 @@ def invoke(self, input: str) -> LLMResponse:
LLMGenerationError: If anything goes wrong.
"""

@abstractmethod
def chat(self, input: str, chat_history: list[str]) -> LLMResponse:
Contributor commented:

I feel like it could be done in one method, with an optional chat_history. The difference would be dealt with in a get_messages method that would do something like this:

def get_messages(self, input, chat_history) -> list:
    messages = []
    if self.system_instruction:
        messages.append({"role": "system", "content": self.system_instruction})
    if chat_history:
        # handle chat history messages format
        messages.extend(chat_history)
    messages.append({"role": "user", "content": input})
    return messages

This is an advantage if most of the LLM providers share the same format for messages, in which case we can implement this method only once in the interface.

@leila-messallem (Contributor, Author) replied:

Ok, that works.

"""Sends a text input and a converstion history to the LLM and retrieves a response.

Args:
input (str): Text sent to the LLM
chat_history (list[str]): A list of previous messages in the conversation

Returns:
LLMResponse: The response from the LLM.

Raises:
LLMGenerationError: If anything goes wrong.
"""

@abstractmethod
async def ainvoke(self, input: str) -> LLMResponse:
"""Asynchronously sends a text input to the LLM and retrieves a response.
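To make the suggestion in the thread above concrete, here is a self-contained sketch of a shared get_messages helper. The role/content dict format is an assumption for illustration; it is not what this draft implements.

# Sketch of the helper proposed in the review thread; the dict-based message
# format and the standalone function form are assumptions, not part of this PR.
from typing import Optional


def get_messages(
    input: str,
    chat_history: Optional[list[dict[str, str]]] = None,
    system_instruction: Optional[str] = None,
) -> list[dict[str, str]]:
    messages: list[dict[str, str]] = []
    if system_instruction:
        # System prompt goes first when the provider accepts it in the message list.
        messages.append({"role": "system", "content": system_instruction})
    if chat_history:
        # Assumes the history already uses the same role/content format.
        messages.extend(chat_history)
    messages.append({"role": "user", "content": input})
    return messages


# Example: each history entry carries its own role, so there is no index-parity guessing.
history = [
    {"role": "user", "content": "When does the sun come up in the summer?"},
    {"role": "assistant", "content": "Usually around 6am."},
]
print(get_messages("When does it set?", history, "You are a helpful assistant."))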
3 changes: 3 additions & 0 deletions src/neo4j_graphrag/llm/cohere_llm.py
@@ -102,3 +102,6 @@ async def ainvoke(self, input: str) -> LLMResponse:
return LLMResponse(
content=res.text,
)

def chat(self, input: str, chat_history: list[str]) -> LLMResponse:
pass
3 changes: 3 additions & 0 deletions src/neo4j_graphrag/llm/mistralai_llm.py
@@ -118,3 +118,6 @@ async def ainvoke(self, input: str) -> LLMResponse:
return LLMResponse(content=content)
except SDKError as e:
raise LLMGenerationError(e)

def chat(self, input: str, chat_history: list[str]) -> LLMResponse:
pass
35 changes: 32 additions & 3 deletions src/neo4j_graphrag/llm/openai_llm.py
@@ -36,6 +36,7 @@ def __init__(
self,
model_name: str,
model_params: Optional[dict[str, Any]] = None,
system_instruction: Optional[str] = None,
):
"""
Base class for OpenAI LLM.
@@ -54,7 +55,7 @@ def __init__(
"Please install it with `pip install openai`."
)
self.openai = openai
super().__init__(model_name, model_params)
super().__init__(model_name, model_params, system_instruction)

def get_messages(
self,
@@ -64,6 +65,32 @@ def get_messages(
{"role": "system", "content": input},
]

def get_conversation_history(
self,
input: str,
chat_history: list[str],
) -> Iterable[ChatCompletionMessageParam]:
messages = [{"role": "system", "content": self.system_instruction}]
for i, message in enumerate(chat_history):
if i % 2 == 0:
@leila-messallem (Contributor, Author) commented:

This assumes that the chat history is a list of strings, alternating between user and assistant messages. It's obviously fragile and should probably be changed to some kind of UserMessage and AssistantMessage objects, or something like that.

Contributor replied:

Maybe we can use dicts with "role" and "content" keys (it would work for OpenAI; I don't know about the others).

@leila-messallem (Contributor, Author) replied:

I can look into it.

messages.append({"role": "user", "content": message})
else:
messages.append({"role": "assistant", "content": message})
messages.append({"role": "user", "content": input})
return messages

def chat(self, input: str, chat_history: list[str]) -> LLMResponse:
try:
response = self.client.chat.completions.create(
messages=self.get_conversation_history(input, chat_history),
model=self.model_name,
**self.model_params,
)
content = response.choices[0].message.content or ""
return LLMResponse(content=content)
except self.openai.OpenAIError as e:
raise LLMGenerationError(e)

def invoke(self, input: str) -> LLMResponse:
"""Sends a text input to the OpenAI chat completion model
and returns the response's content.
@@ -118,6 +145,7 @@ def __init__(
self,
model_name: str,
model_params: Optional[dict[str, Any]] = None,
system_instruction: Optional[str] = None,
**kwargs: Any,
):
"""OpenAI LLM
@@ -129,7 +157,7 @@ def __init__(
model_params (str): Parameters like temperature that will be passed to the model when text is sent to it
kwargs: All other parameters will be passed to the openai.OpenAI init.
"""
super().__init__(model_name, model_params)
super().__init__(model_name, model_params, system_instruction)
self.client = self.openai.OpenAI(**kwargs)
self.async_client = self.openai.AsyncOpenAI(**kwargs)

@@ -139,6 +167,7 @@ def __init__(
self,
model_name: str,
model_params: Optional[dict[str, Any]] = None,
system_instruction: Optional[str] = None,
**kwargs: Any,
):
"""Azure OpenAI LLM. Use this class when using an OpenAI model
@@ -149,6 +178,6 @@ def __init__(
model_params (str): Parameters like temperature that will be passed to the model when text is sent to it
kwargs: All other parameters will be passed to the openai.OpenAI init.
"""
super().__init__(model_name, model_params)
super().__init__(model_name, model_params, system_instruction)
self.client = self.openai.AzureOpenAI(**kwargs)
self.async_client = self.openai.AsyncAzureOpenAI(**kwargs)
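For reference, a minimal usage sketch of the draft OpenAILLM.chat flow above. The API key and model name are placeholders, and the alternating list of strings is the history format this draft currently assumes.

# Usage sketch only; "sk-..." and "gpt-4o" are placeholders. Extra kwargs such
# as api_key are forwarded to openai.OpenAI(**kwargs), as in the diff above.
from neo4j_graphrag.llm import OpenAILLM

llm = OpenAILLM(
    api_key="sk-...",
    model_name="gpt-4o",
    system_instruction="You are a helpful assistant.",
)

chat_history = [
    "When does the sun come up in the summer?",  # even index -> user turn
    "Usually around 6am.",                       # odd index -> assistant turn
]

response = llm.chat("When does it set?", chat_history)
print(response.content)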
48 changes: 44 additions & 4 deletions src/neo4j_graphrag/llm/vertexai_llm.py
@@ -20,7 +20,12 @@
from neo4j_graphrag.llm.types import LLMResponse

try:
from vertexai.generative_models import GenerativeModel, ResponseValidationError
from vertexai.generative_models import (
GenerativeModel,
ResponseValidationError,
Part,
Content,
)
except ImportError:
GenerativeModel = None
ResponseValidationError = None
@@ -55,15 +60,18 @@ def __init__(
self,
model_name: str = "gemini-1.5-flash-001",
model_params: Optional[dict[str, Any]] = None,
system_instruction: Optional[str] = None,
**kwargs: Any,
):
if GenerativeModel is None or ResponseValidationError is None:
raise ImportError(
"Could not import Vertex AI Python client. "
"Please install it with `pip install google-cloud-aiplatform`."
)
super().__init__(model_name, model_params)
self.model = GenerativeModel(model_name=model_name, **kwargs)
super().__init__(model_name, model_params, system_instruction)
self.model = GenerativeModel(
@leila-messallem (Contributor, Author) commented:

On Vertex AI, system_instruction is set on the GenerativeModel object, so it is separate from the question prompt.

Contributor replied:

This is quite bad news, actually: does that mean an LLM instance can only be used for one specific use case? For instance, with RAG and a Text2Cypher retriever, we couldn't use the same LLM for the retrieval and the generation? (I mean, if we want to use the system prompt, which is not what we do at the moment.)

@leila-messallem (Contributor, Author) replied:

I guess we don't always have to use the system_instruction, although it's the recommended way.

model_name=model_name, system_instruction=[system_instruction], **kwargs
)

def invoke(self, input: str) -> LLMResponse:
"""Sends text to the LLM and returns a response.
@@ -80,7 +88,25 @@ def invoke(self, input: str) -> LLMResponse:
except ResponseValidationError as e:
raise LLMGenerationError(e)

async def ainvoke(self, input: str) -> LLMResponse:
def chat(self, input: str, chat_history: list[str] = []) -> LLMResponse:
"""Sends text to the LLM and returns a response.
Args:
input (str): The text to send to the LLM.
Returns:
LLMResponse: The response from the LLM.
"""
try:
messages = self.get_conversation_history(input, chat_history)
response = self.model.generate_content(messages, **self.model_params)
return LLMResponse(content=response.text)
except ResponseValidationError as e:
raise LLMGenerationError(e)

async def ainvoke(
self, input: str, chat_history: Optional[list[str]] = []
) -> LLMResponse:
"""Asynchronously sends text to the LLM and returns a response.
Args:
@@ -96,3 +122,17 @@ async def ainvoke(self, input: str) -> LLMResponse:
return LLMResponse(content=response.text)
except ResponseValidationError as e:
raise LLMGenerationError(e)

def get_conversation_history(
self,
input: str,
chat_history: list[str],
) -> list[Content]:
messages = []
for i, message in enumerate(chat_history):
if i % 2 == 0:
messages.append(Content(role="user", parts=[Part.from_text(message)]))
else:
messages.append(Content(role="model", parts=[Part.from_text(message)]))
messages.append(Content(role="user", parts=[Part.from_text(input)]))
return messages
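And the equivalent sketch for the draft VertexAI chat flow, assuming the same alternating user/model string history and a configured Google Cloud project:

# Usage sketch only; requires google-cloud-aiplatform and Google Cloud credentials.
from neo4j_graphrag.llm.vertexai_llm import VertexAILLM

llm = VertexAILLM(
    model_name="gemini-1.5-flash-001",
    system_instruction="You are a helpful assistant.",
)

chat_history = [
    "When does the sun come up in the summer?",  # user turn
    "Usually around 6am.",                       # model turn
]

# get_conversation_history maps these strings to Content(role="user"/"model", ...)
# objects before calling generate_content.
response = llm.chat("When does it set?", chat_history)
print(response.content)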
41 changes: 40 additions & 1 deletion tests/unit/llm/test_openai_llm.py
@@ -33,7 +33,7 @@ def test_openai_llm_missing_dependency(mock_import: Mock) -> None:


@patch("builtins.__import__")
def test_openai_llm_happy_path(mock_import: Mock) -> None:
def test_openai_llm_invoke_happy_path(mock_import: Mock) -> None:
mock_openai = get_mock_openai()
mock_import.return_value = mock_openai
mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock(
Expand All @@ -46,6 +46,20 @@ def test_openai_llm_happy_path(mock_import: Mock) -> None:
assert res.content == "openai chat response"


@patch("builtins.__import__")
def test_openai_llm_chat_happy_path(mock_import: Mock) -> None:
mock_openai = get_mock_openai()
mock_import.return_value = mock_openai
mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock(
choices=[MagicMock(message=MagicMock(content="openai chat response"))],
)
llm = OpenAILLM(api_key="my key", model_name="gpt")

res = llm.chat("my question", ["user message", "assistant message"])
assert isinstance(res, LLMResponse)
assert res.content == "openai chat response"


@patch("builtins.__import__", side_effect=ImportError)
def test_azure_openai_llm_missing_dependency(mock_import: Mock) -> None:
with pytest.raises(ImportError):
@@ -71,3 +85,28 @@ def test_azure_openai_llm_happy_path(mock_import: Mock) -> None:
res = llm.invoke("my text")
assert isinstance(res, LLMResponse)
assert res.content == "openai chat response"


def test_openai_llm_get_conversation_history() -> None:
system_instruction = "You are a helpful assistant."
question = "When does it set?"
chat_history = [
"When does the sun come up in the summer?",
"Usually around 6am.",
"What about next season?",
"Around 8am.",
]
expected_response = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "When does the sun come up in the summer?"},
{"role": "assistant", "content": "Usually around 6am."},
{"role": "user", "content": "What about next season?"},
{"role": "assistant", "content": "Around 8am."},
{"role": "user", "content": "When does it set?"},
]

llm = OpenAILLM(
api_key="my key", model_name="gpt", system_instruction=system_instruction
)
response = llm.get_conversation_history(question, chat_history)
assert response == expected_response
35 changes: 34 additions & 1 deletion tests/unit/llm/test_vertexai_llm.py
@@ -16,7 +16,7 @@
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest
from neo4j_graphrag.llm.vertexai_llm import VertexAILLM
from neo4j_graphrag.llm.vertexai_llm import VertexAILLM, Part, Content


@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel", None)
@@ -52,3 +52,36 @@ async def test_vertexai_ainvoke_happy_path(GenerativeModelMock: MagicMock) -> None:
response = await llm.ainvoke(input_text)
assert response.content == "Return text"
llm.model.generate_content_async.assert_called_once_with(input_text, **model_params)


@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
def test_vertexai_get_conversation_history(GenerativeModelMock: MagicMock) -> None:
system_instruction = "You are a helpful assistant."
question = "When does it set?"
chat_history = [
"When does the sun come up in the summer?",
"Usually around 6am.",
"What about next season?",
"Around 8am.",
]
expected_response = [
Content(
role="user",
parts=[Part.from_text("When does the sun come up in the summer?")],
),
Content(role="model", parts=[Part.from_text("Usually around 6am.")]),
Content(role="user", parts=[Part.from_text("What about next season?")]),
Content(role="model", parts=[Part.from_text("Around 8am.")]),
Content(role="user", parts=[Part.from_text("When does it set?")]),
]

llm = VertexAILLM(
model_name="gemini-1.5-flash-001", system_instruction=system_instruction
)
response = llm.get_conversation_history(question, chat_history)

assert llm.system_instruction == system_instruction
assert len(response) == len(expected_response)
for actual, expected in zip(response, expected_response):
assert actual.role == expected.role
assert actual.parts[0].text == expected.parts[0].text