Revert BaseMessage class type
* bring back the `list[dict[str, str]]` type declaration for the `message_history` parameter
leila-messallem committed Dec 17, 2024
1 parent a749a9e commit 819179e
Showing 9 changed files with 59 additions and 58 deletions.
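With this revert, each entry in message_history is a plain role/content dict rather than a BaseMessage instance. A minimal sketch of the expected shape (the keys mirror what the library's MessageList validation checks; the conversation text is made up):

    # Hypothetical history; only the role/content dict shape matters here.
    message_history = [
        {"role": "user", "content": "Who directed The Matrix?"},
        {"role": "assistant", "content": "The Wachowskis."},
    ]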
5 changes: 2 additions & 3 deletions src/neo4j_graphrag/generation/graphrag.py
@@ -31,7 +31,6 @@
 )
 from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel
 from neo4j_graphrag.llm import LLMInterface
-from neo4j_graphrag.llm.types import BaseMessage
 from neo4j_graphrag.retrievers.base import Retriever
 from neo4j_graphrag.types import RetrieverResult
 
@@ -88,7 +87,7 @@ def __init__(
     def search(
         self,
         query_text: str = "",
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         examples: str = "",
         retriever_config: Optional[dict[str, Any]] = None,
         return_context: bool | None = None,
@@ -148,7 +147,7 @@ def search(
         return RagResultModel(**result)
 
     def build_query(
-        self, query_text: str, message_history: Optional[list[BaseMessage]] = None
+        self, query_text: str, message_history: Optional[list[dict[str, str]]] = None
     ) -> str:
         if message_history:
             summarization_prompt = ChatSummaryTemplate().format(
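A hedged sketch of a call site after this change; it assumes `retriever` and `llm` objects were constructed elsewhere, so it illustrates the reverted signature rather than a runnable end-to-end setup:

    from neo4j_graphrag.generation import GraphRAG

    rag = GraphRAG(retriever=retriever, llm=llm)  # retriever/llm assumed to exist
    result = rag.search(
        query_text="And what about the sequels?",
        message_history=[  # plain dicts again, not BaseMessage objects
            {"role": "user", "content": "Who directed The Matrix?"},
            {"role": "assistant", "content": "The Wachowskis."},
        ],
    )
    print(result.answer)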
6 changes: 3 additions & 3 deletions src/neo4j_graphrag/generation/prompts.py
@@ -17,7 +17,6 @@
 import warnings
 from typing import Any, Optional
 
-from neo4j_graphrag.llm.types import BaseMessage
 from neo4j_graphrag.exceptions import (
     PromptMissingInputError,
     PromptMissingPlaceholderError,
@@ -208,9 +207,10 @@ class ChatSummaryTemplate(PromptTemplate):
     EXPECTED_INPUTS = ["message_history"]
     SYSTEM_MESSAGE = "You are a summarization assistant. Summarize the given text in no more than 200 words"
 
-    def format(self, message_history: list[BaseMessage]) -> str:
+    def format(self, message_history: list[dict[str, str]]) -> str:
         message_list = [
-            f"{message.role}: {message.content}" for message in message_history
+            ": ".join([f"{value}" for _, value in message.items()])
+            for message in message_history
         ]
         history = "\n".join(message_list)
         return super().format(message_history=history)
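The rewritten comprehension joins each message dict's values with ": ", so {"role": "user", "content": "Hello"} still renders as "user: Hello", provided the role key is inserted before the content key (dict insertion order is preserved in Python 3.7+). A quick check, assuming the import path shown in this diff:

    from neo4j_graphrag.generation.prompts import ChatSummaryTemplate

    history = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi, how can I help?"},
    ]
    # Flattens to "user: Hello\nassistant: Hi, how can I help?" before
    # the text is substituted into the summarization prompt.
    print(ChatSummaryTemplate().format(message_history=history))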
22 changes: 11 additions & 11 deletions src/neo4j_graphrag/llm/anthropic_llm.py
@@ -13,13 +13,13 @@
 # limitations under the License.
 from __future__ import annotations
 
-from typing import Any, Iterable, Optional, TYPE_CHECKING
+from typing import Any, Iterable, Optional, TYPE_CHECKING, cast
 
 from pydantic import ValidationError
 
 from neo4j_graphrag.exceptions import LLMGenerationError
 from neo4j_graphrag.llm.base import LLMInterface
-from neo4j_graphrag.llm.types import LLMResponse, MessageList, UserMessage, BaseMessage
+from neo4j_graphrag.llm.types import LLMResponse, MessageList, UserMessage
 
 if TYPE_CHECKING:
     from anthropic.types.message_param import MessageParam
@@ -71,22 +71,22 @@ def __init__(
         self.async_client = anthropic.AsyncAnthropic(**kwargs)
 
     def get_messages(
-        self, input: str, message_history: Optional[list[BaseMessage]] = None
+        self, input: str, message_history: Optional[list[dict[str, str]]] = None
     ) -> Iterable[MessageParam]:
         messages = []
         if message_history:
             try:
-                MessageList(messages=message_history)
+                MessageList(messages=message_history)  # type: ignore
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
             messages.extend(message_history)
         messages.append(UserMessage(content=input).model_dump())
-        return messages
+        return cast(Iterable[MessageParam], messages)
 
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
@@ -108,18 +108,18 @@ def invoke(
             )
             response = self.client.messages.create(
                 model=self.model_name,
-                system=system_message,
+                system=system_message,  # type: ignore
                 messages=messages,
                 **self.model_params,
             )
-            return LLMResponse(content=response.content)
+            return LLMResponse(content=response.content)  # type: ignore
         except self.anthropic.APIError as e:
             raise LLMGenerationError(e)
 
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
@@ -141,10 +141,10 @@ async def ainvoke(
             )
             response = await self.async_client.messages.create(
                 model=self.model_name,
-                system=system_message,
+                system=system_message,  # type: ignore
                 messages=messages,
                 **self.model_params,
             )
-            return LLMResponse(content=response.content)
+            return LLMResponse(content=response.content)  # type: ignore
         except self.anthropic.APIError as e:
             raise LLMGenerationError(e)
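The same validate-then-cast pattern recurs in the Cohere, Ollama, and OpenAI clients below: the dicts are checked against the pydantic MessageList model at runtime, then cast() satisfies the type checker for the SDK-specific message type. A standalone sketch of the pattern, with a stand-in for the SDK type:

    from typing import Iterable, Optional, cast

    MessageParam = dict  # stand-in for anthropic.types.message_param.MessageParam

    def get_messages(
        input: str, message_history: Optional[list[dict[str, str]]] = None
    ) -> Iterable[MessageParam]:
        messages: list[dict[str, str]] = []
        if message_history:
            messages.extend(message_history)  # entries are already role/content dicts
        messages.append({"role": "user", "content": input})
        # cast() is a no-op at runtime; it only tells the type checker that
        # these dicts are acceptable as the SDK's MessageParam objects.
        return cast(Iterable[MessageParam], messages)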
6 changes: 3 additions & 3 deletions src/neo4j_graphrag/llm/base.py
@@ -17,7 +17,7 @@
 from abc import ABC, abstractmethod
 from typing import Any, Optional
 
-from .types import LLMResponse, BaseMessage
+from .types import LLMResponse
 
 
 class LLMInterface(ABC):
@@ -45,7 +45,7 @@ def __init__(
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends a text input to the LLM and retrieves a response.
@@ -66,7 +66,7 @@ def invoke(
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends a text input to the LLM and retrieves a response.
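Every concrete provider now overrides the dict-typed signature. A toy subclass as a sketch, assuming invoke and ainvoke are the only abstract methods (as this diff suggests); it is not library code:

    from typing import Optional

    from neo4j_graphrag.llm.base import LLMInterface
    from neo4j_graphrag.llm.types import LLMResponse


    class EchoLLM(LLMInterface):
        """Toy implementation for illustration and tests."""

        def invoke(
            self,
            input: str,
            message_history: Optional[list[dict[str, str]]] = None,
            system_instruction: Optional[str] = None,
        ) -> LLMResponse:
            # Echo the prompt, noting how many prior turns were supplied.
            turns = len(message_history or [])
            return LLMResponse(content=f"[{turns} prior turns] {input}")

        async def ainvoke(
            self,
            input: str,
            message_history: Optional[list[dict[str, str]]] = None,
            system_instruction: Optional[str] = None,
        ) -> LLMResponse:
            return self.invoke(input, message_history, system_instruction)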
13 changes: 6 additions & 7 deletions src/neo4j_graphrag/llm/cohere_llm.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Optional, cast
 from pydantic import ValidationError
 
 from neo4j_graphrag.exceptions import LLMGenerationError
@@ -24,7 +24,6 @@
     MessageList,
     SystemMessage,
     UserMessage,
-    BaseMessage,
 )
 
 if TYPE_CHECKING:
@@ -77,7 +76,7 @@ def __init__(
     def get_messages(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> ChatMessages:
         messages = []
@@ -90,17 +89,17 @@
             messages.append(SystemMessage(content=system_message).model_dump())
         if message_history:
             try:
-                MessageList(messages=message_history)
+                MessageList(messages=message_history)  # type: ignore
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
             messages.extend(message_history)
         messages.append(UserMessage(content=input).model_dump())
-        return messages
+        return cast(ChatMessages, messages)
 
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
@@ -128,7 +127,7 @@ def invoke(
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
9 changes: 4 additions & 5 deletions src/neo4j_graphrag/llm/mistralai_llm.py
@@ -25,7 +25,6 @@
     MessageList,
     SystemMessage,
     UserMessage,
-    BaseMessage,
 )
 
 try:
@@ -68,7 +67,7 @@ def __init__(
     def get_messages(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> list[Messages]:
         messages = []
@@ -81,7 +80,7 @@
             messages.append(SystemMessage(content=system_message).model_dump())
         if message_history:
             try:
-                MessageList(messages=message_history)
+                MessageList(messages=message_history)  # type: ignore
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
             messages.extend(message_history)
@@ -91,7 +90,7 @@
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends a text input to the Mistral chat completion model
@@ -127,7 +126,7 @@ def invoke(
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends a text input to the MistralAI chat
14 changes: 7 additions & 7 deletions src/neo4j_graphrag/llm/ollama_llm.py
@@ -12,14 +12,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional, Sequence, TYPE_CHECKING
+from typing import Any, Optional, Sequence, TYPE_CHECKING, cast
 
 from pydantic import ValidationError
 
 from neo4j_graphrag.exceptions import LLMGenerationError
 
 from .base import LLMInterface
-from .types import LLMResponse, SystemMessage, UserMessage, MessageList, BaseMessage
+from .types import LLMResponse, SystemMessage, UserMessage, MessageList
 
 if TYPE_CHECKING:
     from ollama import Message
@@ -52,7 +52,7 @@ def __init__(
     def get_messages(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> Sequence[Message]:
         messages = []
@@ -65,17 +65,17 @@
             messages.append(SystemMessage(content=system_message).model_dump())
         if message_history:
             try:
-                MessageList(messages=message_history)
+                MessageList(messages=message_history)  # type: ignore
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
             messages.extend(message_history)
         messages.append(UserMessage(content=input).model_dump())
-        return messages
+        return cast(Sequence[Message], messages)
 
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends text to the LLM and returns a response.
@@ -102,7 +102,7 @@ def invoke(
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends a text input to the OpenAI chat
14 changes: 7 additions & 7 deletions src/neo4j_graphrag/llm/openai_llm.py
@@ -15,13 +15,13 @@
 from __future__ import annotations
 
 import abc
-from typing import TYPE_CHECKING, Any, Iterable, Optional
+from typing import TYPE_CHECKING, Any, Iterable, Optional, cast
 
 from pydantic import ValidationError
 
 from ..exceptions import LLMGenerationError
 from .base import LLMInterface
-from .types import LLMResponse, SystemMessage, UserMessage, MessageList, BaseMessage
+from .types import LLMResponse, SystemMessage, UserMessage, MessageList
 
 if TYPE_CHECKING:
     import openai
@@ -63,7 +63,7 @@ def __init__(
     def get_messages(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> Iterable[ChatCompletionMessageParam]:
         messages = []
@@ -76,17 +76,17 @@
             messages.append(SystemMessage(content=system_message).model_dump())
         if message_history:
             try:
-                MessageList(messages=message_history)
+                MessageList(messages=message_history)  # type: ignore
             except ValidationError as e:
                 raise LLMGenerationError(e.errors()) from e
             messages.extend(message_history)
         messages.append(UserMessage(content=input).model_dump())
-        return messages
+        return cast(Iterable[ChatCompletionMessageParam], messages)
 
     def invoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Sends a text input to the OpenAI chat completion model
@@ -117,7 +117,7 @@ def invoke(
     async def ainvoke(
         self,
         input: str,
-        message_history: Optional[list[BaseMessage]] = None,
+        message_history: Optional[list[dict[str, str]]] = None,
         system_instruction: Optional[str] = None,
     ) -> LLMResponse:
         """Asynchronously sends a text input to the OpenAI chat
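Call sites only change in the history type they pass. A hedged example against the OpenAI client; the model name and environment-based API key handling are illustrative assumptions:

    from neo4j_graphrag.llm import OpenAILLM

    llm = OpenAILLM(model_name="gpt-4o")  # assumes OPENAI_API_KEY is set
    response = llm.invoke(
        "Summarize the conversation so far.",
        message_history=[
            {"role": "user", "content": "What is Neo4j?"},
            {"role": "assistant", "content": "A graph database."},
        ],
    )
    print(response.content)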