Chat with message history (draft)
leila-messallem committed Dec 6, 2024
1 parent bc6dd9c commit 7c10077
Showing 10 changed files with 182 additions and 24 deletions.
14 changes: 0 additions & 14 deletions .pre-commit-config.yaml
@@ -18,17 +18,3 @@ repos:
language: system
types: [ python ]
stages: [ commit, push ]
- id: mypy
name: Mypy Type Check
entry: mypy .
language: system
types: [ python ]
stages: [ commit, push ]
pass_filenames: false
args: [
--strict,
--ignore-missing-imports,
--allow-untyped-calls,
--allow-subclassing-any,
--exclude='./docs/'
]
7 changes: 6 additions & 1 deletion src/neo4j_graphrag/generation/graphrag.py
@@ -83,6 +83,7 @@ def __init__(
def search(
self,
query_text: str = "",
chat_history: Optional[list[str]] = None,
examples: str = "",
retriever_config: Optional[dict[str, Any]] = None,
return_context: bool | None = None,
@@ -100,6 +101,7 @@ def search(
Args:
query_text (str): The user question
chat_history (Optional[list[str]]): A list of previous messages in the conversation
examples (str): Examples added to the LLM prompt.
retriever_config (Optional[dict]): Parameters passed to the retriever
search method; e.g.: top_k
@@ -134,7 +136,10 @@ def search(
)
logger.debug(f"RAG: retriever_result={retriever_result}")
logger.debug(f"RAG: prompt={prompt}")
answer = self.llm.invoke(prompt)
if chat_history is not None:
answer = self.llm.chat(prompt, chat_history)
else:
answer = self.llm.invoke(prompt)
result: dict[str, Any] = {"answer": answer.content}
if return_context:
result["retriever_result"] = retriever_result
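Based on the change above, a caller could thread a conversation into `search` roughly as sketched below. This is illustrative only: `my_retriever` and `my_llm` are placeholder objects assumed to be configured elsewhere, and the result is assumed to expose an `answer` field as in released versions of the library.

# Hypothetical usage sketch of the draft chat_history parameter.
from neo4j_graphrag.generation import GraphRAG

rag = GraphRAG(retriever=my_retriever, llm=my_llm)

# The draft expects a flat list of alternating user/assistant messages, oldest first.
history = [
    "When does the sun come up in the summer?",  # user
    "Usually around 6am.",                       # assistant
]

# With a history the pipeline calls llm.chat(); without one, llm.invoke().
result = rag.search(
    query_text="When does it set?",
    chat_history=history,
    retriever_config={"top_k": 5},
)
print(result.answer)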
3 changes: 3 additions & 0 deletions src/neo4j_graphrag/llm/anthropic_llm.py
@@ -110,3 +110,6 @@ async def ainvoke(self, input: str) -> LLMResponse:
return LLMResponse(content=response.content)
except self.anthropic.APIError as e:
raise LLMGenerationError(e)

def chat(self, input: str, chat_history: list[str]) -> LLMResponse:
pass
17 changes: 17 additions & 0 deletions src/neo4j_graphrag/llm/base.py
@@ -33,10 +33,12 @@ def __init__(
self,
model_name: str,
model_params: Optional[dict[str, Any]] = None,
system_instruction: Optional[str] = None,
**kwargs: Any,
):
self.model_name = model_name
self.model_params = model_params or {}
self.system_instruction = system_instruction

@abstractmethod
def invoke(self, input: str) -> LLMResponse:
@@ -52,6 +54,21 @@ def invoke(self, input: str) -> LLMResponse:
LLMGenerationError: If anything goes wrong.
"""

@abstractmethod
def chat(self, input: str, chat_history: list[str]) -> LLMResponse:
"""Sends a text input and a converstion history to the LLM and retrieves a response.
Args:
input (str): Text sent to the LLM
chat_history (list[str]): A list of previous messages in the conversation
Returns:
LLMResponse: The response from the LLM.
Raises:
LLMGenerationError: If anything goes wrong.
"""

@abstractmethod
async def ainvoke(self, input: str) -> LLMResponse:
"""Asynchronously sends a text input to the LLM and retrieves a response.
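The new abstract `chat` method expects the history as a flat list of strings that alternate user and assistant turns, oldest first. A minimal stand-in subclass can show the shape of the extended interface; the class below is invented for illustration and is not part of the commit.

# Hypothetical toy implementation of the extended LLMInterface, e.g. for local tests.
from neo4j_graphrag.llm.base import LLMInterface
from neo4j_graphrag.llm.types import LLMResponse


class EchoLLM(LLMInterface):
    """Toy LLM that echoes its input; shows the shape of the new interface."""

    def invoke(self, input: str) -> LLMResponse:
        return LLMResponse(content=f"echo: {input}")

    def chat(self, input: str, chat_history: list[str]) -> LLMResponse:
        # chat_history alternates user/assistant messages, oldest first.
        return LLMResponse(content=f"echo: {input} (after {len(chat_history)} prior messages)")

    async def ainvoke(self, input: str) -> LLMResponse:
        return self.invoke(input)


llm = EchoLLM(model_name="echo")
print(llm.chat("hello", ["hi", "hello there"]).content)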
3 changes: 3 additions & 0 deletions src/neo4j_graphrag/llm/cohere_llm.py
@@ -102,3 +102,6 @@ async def ainvoke(self, input: str) -> LLMResponse:
return LLMResponse(
content=res.text,
)

def chat(self, input: str, chat_history: list[str]) -> LLMResponse:
pass
3 changes: 3 additions & 0 deletions src/neo4j_graphrag/llm/mistralai_llm.py
@@ -118,3 +118,6 @@ async def ainvoke(self, input: str) -> LLMResponse:
return LLMResponse(content=content)
except SDKError as e:
raise LLMGenerationError(e)

def chat(self, input: str, chat_history: list[str]) -> LLMResponse:
pass
35 changes: 32 additions & 3 deletions src/neo4j_graphrag/llm/openai_llm.py
@@ -36,6 +36,7 @@ def __init__(
self,
model_name: str,
model_params: Optional[dict[str, Any]] = None,
system_instruction: Optional[str] = None,
):
"""
Base class for OpenAI LLM.
@@ -54,7 +55,7 @@ def __init__(
"Please install it with `pip install openai`."
)
self.openai = openai
super().__init__(model_name, model_params)
super().__init__(model_name, model_params, system_instruction)

def get_messages(
self,
@@ -64,6 +65,32 @@ def get_messages(
{"role": "system", "content": input},
]

def get_conversation_history(
self,
input: str,
chat_history: list[str],
) -> Iterable[ChatCompletionMessageParam]:
messages = [{"role": "system", "content": self.system_instruction}]
for i, message in enumerate(chat_history):
if i % 2 == 0:
messages.append({"role": "user", "content": message})
else:
messages.append({"role": "assistant", "content": message})
messages.append({"role": "user", "content": input})
return messages

def chat(self, input: str, chat_history: list[str]) -> LLMResponse:
try:
response = self.client.chat.completions.create(
messages=self.get_conversation_history(input, chat_history),
model=self.model_name,
**self.model_params,
)
content = response.choices[0].message.content or ""
return LLMResponse(content=content)
except self.openai.OpenAIError as e:
raise LLMGenerationError(e)

def invoke(self, input: str) -> LLMResponse:
"""Sends a text input to the OpenAI chat completion model
and returns the response's content.
@@ -118,6 +145,7 @@ def __init__(
self,
model_name: str,
model_params: Optional[dict[str, Any]] = None,
system_instruction: Optional[str] = None,
**kwargs: Any,
):
"""OpenAI LLM
@@ -129,7 +157,7 @@ def __init__(
model_params (str): Parameters like temperature that will be passed to the model when text is sent to it
kwargs: All other parameters will be passed to the openai.OpenAI init.
"""
super().__init__(model_name, model_params)
super().__init__(model_name, model_params, system_instruction)
self.client = self.openai.OpenAI(**kwargs)
self.async_client = self.openai.AsyncOpenAI(**kwargs)

@@ -139,6 +167,7 @@ def __init__(
self,
model_name: str,
model_params: Optional[dict[str, Any]] = None,
system_instruction: Optional[str] = None,
**kwargs: Any,
):
"""Azure OpenAI LLM. Use this class when using an OpenAI model
@@ -149,6 +178,6 @@ def __init__(
model_params (str): Parameters like temperature that will be passed to the model when text is sent to it
kwargs: All other parameters will be passed to the openai.OpenAI init.
"""
super().__init__(model_name, model_params)
super().__init__(model_name, model_params, system_instruction)
self.client = self.openai.AzureOpenAI(**kwargs)
self.async_client = self.openai.AsyncAzureOpenAI(**kwargs)
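The conversion performed by `get_conversation_history` above can be illustrated in isolation: even indices in the flat history become user turns, odd indices assistant turns, with the system instruction first and the new input last. A self-contained sketch of that mapping (not library code; the helper name is invented):

# Standalone illustration of the flat-history-to-messages mapping used above.
def build_messages(
    system_instruction: str, chat_history: list[str], new_input: str
) -> list[dict[str, str]]:
    messages = [{"role": "system", "content": system_instruction}]
    for i, message in enumerate(chat_history):
        # Even indices are user turns, odd indices assistant turns.
        role = "user" if i % 2 == 0 else "assistant"
        messages.append({"role": role, "content": message})
    messages.append({"role": "user", "content": new_input})
    return messages


print(build_messages(
    "You are a helpful assistant.",
    ["When does the sun come up in the summer?", "Usually around 6am."],
    "When does it set?",
))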
48 changes: 44 additions & 4 deletions src/neo4j_graphrag/llm/vertexai_llm.py
@@ -20,7 +20,12 @@
from neo4j_graphrag.llm.types import LLMResponse

try:
from vertexai.generative_models import GenerativeModel, ResponseValidationError
from vertexai.generative_models import (
GenerativeModel,
ResponseValidationError,
Part,
Content,
)
except ImportError:
GenerativeModel = None
ResponseValidationError = None
@@ -55,15 +60,18 @@ def __init__(
self,
model_name: str = "gemini-1.5-flash-001",
model_params: Optional[dict[str, Any]] = None,
system_instruction: Optional[str] = None,
**kwargs: Any,
):
if GenerativeModel is None or ResponseValidationError is None:
raise ImportError(
"Could not import Vertex AI Python client. "
"Please install it with `pip install google-cloud-aiplatform`."
)
super().__init__(model_name, model_params)
self.model = GenerativeModel(model_name=model_name, **kwargs)
super().__init__(model_name, model_params, system_instruction)
self.model = GenerativeModel(
model_name=model_name, system_instruction=[system_instruction], **kwargs
)

def invoke(self, input: str) -> LLMResponse:
"""Sends text to the LLM and returns a response.
@@ -80,7 +88,25 @@ def invoke(self, input: str) -> LLMResponse:
except ResponseValidationError as e:
raise LLMGenerationError(e)

async def ainvoke(self, input: str) -> LLMResponse:
def chat(self, input: str, chat_history: list[str] = []) -> LLMResponse:
"""Sends text to the LLM and returns a response.
Args:
input (str): The text to send to the LLM.
Returns:
LLMResponse: The response from the LLM.
"""
try:
messages = self.get_conversation_history(input, chat_history)
response = self.model.generate_content(messages, **self.model_params)
return LLMResponse(content=response.text)
except ResponseValidationError as e:
raise LLMGenerationError(e)

async def ainvoke(
self, input: str, chat_history: Optional[list[str]] = []
) -> LLMResponse:
"""Asynchronously sends text to the LLM and returns a response.
Args:
Expand All @@ -96,3 +122,17 @@ async def ainvoke(self, input: str) -> LLMResponse:
return LLMResponse(content=response.text)
except ResponseValidationError as e:
raise LLMGenerationError(e)

def get_conversation_history(
self,
input: str,
chat_history: list[str],
) -> list[Content]:
messages = []
for i, message in enumerate(chat_history):
if i % 2 == 0:
messages.append(Content(role="user", parts=[Part.from_text(message)]))
else:
messages.append(Content(role="model", parts=[Part.from_text(message)]))
messages.append(Content(role="user", parts=[Part.from_text(input)]))
return messages
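Putting the Vertex AI pieces together, usage might look like the sketch below. It assumes `google-cloud-aiplatform` is installed and Vertex AI credentials and project are configured in the environment; the model name mirrors the default above, and the history alternates user/model turns, oldest first.

# Illustrative only: requires google-cloud-aiplatform and configured Vertex AI credentials.
from neo4j_graphrag.llm.vertexai_llm import VertexAILLM

llm = VertexAILLM(
    model_name="gemini-1.5-flash-001",
    system_instruction="You are a helpful assistant.",
)

history = [
    "When does the sun come up in the summer?",  # user
    "Usually around 6am.",                       # model
]

response = llm.chat("When does it set?", history)
print(response.content)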
41 changes: 40 additions & 1 deletion tests/unit/llm/test_openai_llm.py
@@ -33,7 +33,7 @@ def test_openai_llm_missing_dependency(mock_import: Mock) -> None:


@patch("builtins.__import__")
def test_openai_llm_happy_path(mock_import: Mock) -> None:
def test_openai_llm_invoke_happy_path(mock_import: Mock) -> None:
mock_openai = get_mock_openai()
mock_import.return_value = mock_openai
mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock(
@@ -46,6 +46,20 @@ def test_openai_llm_happy_path(mock_import: Mock) -> None:
assert res.content == "openai chat response"


@patch("builtins.__import__")
def test_openai_llm_chat_happy_path(mock_import: Mock) -> None:
mock_openai = get_mock_openai()
mock_import.return_value = mock_openai
mock_openai.OpenAI.return_value.chat.completions.create.return_value = MagicMock(
choices=[MagicMock(message=MagicMock(content="openai chat response"))],
)
llm = OpenAILLM(api_key="my key", model_name="gpt")

res = llm.chat("my question", ["user message", "assistant message"])
assert isinstance(res, LLMResponse)
assert res.content == "openai chat response"


@patch("builtins.__import__", side_effect=ImportError)
def test_azure_openai_llm_missing_dependency(mock_import: Mock) -> None:
with pytest.raises(ImportError):
@@ -71,3 +85,28 @@ def test_azure_openai_llm_happy_path(mock_import: Mock) -> None:
res = llm.invoke("my text")
assert isinstance(res, LLMResponse)
assert res.content == "openai chat response"


def test_openai_llm_get_conversation_history() -> None:
system_instruction = "You are a helpful assistant."
question = "When does it set?"
chat_history = [
"When does the sun come up in the summer?",
"Usually around 6am.",
"What about next season?",
"Around 8am.",
]
expected_response = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "When does the sun come up in the summer?"},
{"role": "assistant", "content": "Usually around 6am."},
{"role": "user", "content": "What about next season?"},
{"role": "assistant", "content": "Around 8am."},
{"role": "user", "content": "When does it set?"},
]

llm = OpenAILLM(
api_key="my key", model_name="gpt", system_instruction=system_instruction
)
response = llm.get_conversation_history(question, chat_history)
assert response == expected_response
35 changes: 34 additions & 1 deletion tests/unit/llm/test_vertexai_llm.py
@@ -16,7 +16,7 @@
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import pytest
from neo4j_graphrag.llm.vertexai_llm import VertexAILLM
from neo4j_graphrag.llm.vertexai_llm import VertexAILLM, Part, Content


@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel", None)
@@ -52,3 +52,36 @@ async def test_vertexai_ainvoke_happy_path(GenerativeModelMock: MagicMock) -> None:
response = await llm.ainvoke(input_text)
assert response.content == "Return text"
llm.model.generate_content_async.assert_called_once_with(input_text, **model_params)


@patch("neo4j_graphrag.llm.vertexai_llm.GenerativeModel")
def test_vertexai_get_conversation_history(GenerativeModelMock: MagicMock) -> None:
system_instruction = "You are a helpful assistant."
question = "When does it set?"
chat_history = [
"When does the sun come up in the summer?",
"Usually around 6am.",
"What about next season?",
"Around 8am.",
]
expected_response = [
Content(
role="user",
parts=[Part.from_text("When does the sun come up in the summer?")],
),
Content(role="model", parts=[Part.from_text("Usually around 6am.")]),
Content(role="user", parts=[Part.from_text("What about next season?")]),
Content(role="model", parts=[Part.from_text("Around 8am.")]),
Content(role="user", parts=[Part.from_text("When does it set?")]),
]

llm = VertexAILLM(
model_name="gemini-1.5-flash-001", system_instruction=system_instruction
)
response = llm.get_conversation_history(question, chat_history)

assert llm.system_instruction == system_instruction
assert len(response) == len(expected_response)
for actual, expected in zip(response, expected_response):
assert actual.role == expected.role
assert actual.parts[0].text == expected.parts[0].text
