diff --git a/src/neo4j_graphrag/generation/graphrag.py b/src/neo4j_graphrag/generation/graphrag.py index 6ce903f9..5889c867 100644 --- a/src/neo4j_graphrag/generation/graphrag.py +++ b/src/neo4j_graphrag/generation/graphrag.py @@ -31,7 +31,6 @@ ) from neo4j_graphrag.generation.types import RagInitModel, RagResultModel, RagSearchModel from neo4j_graphrag.llm import LLMInterface -from neo4j_graphrag.llm.types import BaseMessage from neo4j_graphrag.retrievers.base import Retriever from neo4j_graphrag.types import RetrieverResult @@ -88,7 +87,7 @@ def __init__( def search( self, query_text: str = "", - message_history: Optional[list[BaseMessage]] = None, + message_history: Optional[list[dict[str, str]]] = None, examples: str = "", retriever_config: Optional[dict[str, Any]] = None, return_context: bool | None = None, @@ -148,7 +147,7 @@ def search( return RagResultModel(**result) def build_query( - self, query_text: str, message_history: Optional[list[BaseMessage]] = None + self, query_text: str, message_history: Optional[list[dict[str, str]]] = None ) -> str: if message_history: summarization_prompt = ChatSummaryTemplate().format( diff --git a/src/neo4j_graphrag/generation/prompts.py b/src/neo4j_graphrag/generation/prompts.py index 6df4d069..862c5af0 100644 --- a/src/neo4j_graphrag/generation/prompts.py +++ b/src/neo4j_graphrag/generation/prompts.py @@ -17,7 +17,6 @@ import warnings from typing import Any, Optional -from neo4j_graphrag.llm.types import BaseMessage from neo4j_graphrag.exceptions import ( PromptMissingInputError, PromptMissingPlaceholderError, @@ -208,9 +207,10 @@ class ChatSummaryTemplate(PromptTemplate): EXPECTED_INPUTS = ["message_history"] SYSTEM_MESSAGE = "You are a summarization assistant. Summarize the given text in no more than 200 words" - def format(self, message_history: list[BaseMessage]) -> str: + def format(self, message_history: list[dict[str, str]]) -> str: message_list = [ - f"{message.role}: {message.content}" for message in message_history + ": ".join([f"{value}" for _, value in message.items()]) + for message in message_history ] history = "\n".join(message_list) return super().format(message_history=history) diff --git a/src/neo4j_graphrag/llm/anthropic_llm.py b/src/neo4j_graphrag/llm/anthropic_llm.py index 6e9207be..9699e562 100644 --- a/src/neo4j_graphrag/llm/anthropic_llm.py +++ b/src/neo4j_graphrag/llm/anthropic_llm.py @@ -13,13 +13,13 @@ # limitations under the License. from __future__ import annotations -from typing import Any, Iterable, Optional, TYPE_CHECKING +from typing import Any, Iterable, Optional, TYPE_CHECKING, cast from pydantic import ValidationError from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface -from neo4j_graphrag.llm.types import LLMResponse, MessageList, UserMessage, BaseMessage +from neo4j_graphrag.llm.types import LLMResponse, MessageList, UserMessage if TYPE_CHECKING: from anthropic.types.message_param import MessageParam @@ -71,22 +71,22 @@ def __init__( self.async_client = anthropic.AsyncAnthropic(**kwargs) def get_messages( - self, input: str, message_history: Optional[list[BaseMessage]] = None + self, input: str, message_history: Optional[list[dict[str, str]]] = None ) -> Iterable[MessageParam]: messages = [] if message_history: try: - MessageList(messages=message_history) + MessageList(messages=message_history) # type: ignore except ValidationError as e: raise LLMGenerationError(e.errors()) from e messages.extend(message_history) messages.append(UserMessage(content=input).model_dump()) - return messages + return cast(Iterable[MessageParam], messages) def invoke( self, input: str, - message_history: Optional[list[BaseMessage]] = None, + message_history: Optional[list[dict[str, str]]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Sends text to the LLM and returns a response. @@ -108,18 +108,18 @@ def invoke( ) response = self.client.messages.create( model=self.model_name, - system=system_message, + system=system_message, # type: ignore messages=messages, **self.model_params, ) - return LLMResponse(content=response.content) + return LLMResponse(content=response.content) # type: ignore except self.anthropic.APIError as e: raise LLMGenerationError(e) async def ainvoke( self, input: str, - message_history: Optional[list[BaseMessage]] = None, + message_history: Optional[list[dict[str, str]]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Asynchronously sends text to the LLM and returns a response. @@ -141,10 +141,10 @@ async def ainvoke( ) response = await self.async_client.messages.create( model=self.model_name, - system=system_message, + system=system_message, # type: ignore messages=messages, **self.model_params, ) - return LLMResponse(content=response.content) + return LLMResponse(content=response.content) # type: ignore except self.anthropic.APIError as e: raise LLMGenerationError(e) diff --git a/src/neo4j_graphrag/llm/base.py b/src/neo4j_graphrag/llm/base.py index 4c5616c3..488cf4c9 100644 --- a/src/neo4j_graphrag/llm/base.py +++ b/src/neo4j_graphrag/llm/base.py @@ -17,7 +17,7 @@ from abc import ABC, abstractmethod from typing import Any, Optional -from .types import LLMResponse, BaseMessage +from .types import LLMResponse class LLMInterface(ABC): @@ -45,7 +45,7 @@ def __init__( def invoke( self, input: str, - message_history: Optional[list[BaseMessage]] = None, + message_history: Optional[list[dict[str, str]]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Sends a text input to the LLM and retrieves a response. @@ -66,7 +66,7 @@ def invoke( async def ainvoke( self, input: str, - message_history: Optional[list[BaseMessage]] = None, + message_history: Optional[list[dict[str, str]]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Asynchronously sends a text input to the LLM and retrieves a response. diff --git a/src/neo4j_graphrag/llm/cohere_llm.py b/src/neo4j_graphrag/llm/cohere_llm.py index 782757be..acdf669e 100644 --- a/src/neo4j_graphrag/llm/cohere_llm.py +++ b/src/neo4j_graphrag/llm/cohere_llm.py @@ -14,7 +14,7 @@ # limitations under the License. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, cast from pydantic import ValidationError from neo4j_graphrag.exceptions import LLMGenerationError @@ -24,7 +24,6 @@ MessageList, SystemMessage, UserMessage, - BaseMessage, ) if TYPE_CHECKING: @@ -77,7 +76,7 @@ def __init__( def get_messages( self, input: str, - message_history: Optional[list[BaseMessage]] = None, + message_history: Optional[list[dict[str, str]]] = None, system_instruction: Optional[str] = None, ) -> ChatMessages: messages = [] @@ -90,17 +89,17 @@ def get_messages( messages.append(SystemMessage(content=system_message).model_dump()) if message_history: try: - MessageList(messages=message_history) + MessageList(messages=message_history) # type: ignore except ValidationError as e: raise LLMGenerationError(e.errors()) from e messages.extend(message_history) messages.append(UserMessage(content=input).model_dump()) - return messages + return cast(ChatMessages, messages) def invoke( self, input: str, - message_history: Optional[list[BaseMessage]] = None, + message_history: Optional[list[dict[str, str]]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Sends text to the LLM and returns a response. @@ -128,7 +127,7 @@ def invoke( async def ainvoke( self, input: str, - message_history: Optional[list[BaseMessage]] = None, + message_history: Optional[list[dict[str, str]]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Asynchronously sends text to the LLM and returns a response. diff --git a/src/neo4j_graphrag/llm/mistralai_llm.py b/src/neo4j_graphrag/llm/mistralai_llm.py index f04ac089..6b42cac5 100644 --- a/src/neo4j_graphrag/llm/mistralai_llm.py +++ b/src/neo4j_graphrag/llm/mistralai_llm.py @@ -25,7 +25,6 @@ MessageList, SystemMessage, UserMessage, - BaseMessage, ) try: @@ -68,7 +67,7 @@ def __init__( def get_messages( self, input: str, - message_history: Optional[list[BaseMessage]] = None, + message_history: Optional[list[dict[str, str]]] = None, system_instruction: Optional[str] = None, ) -> list[Messages]: messages = [] @@ -81,7 +80,7 @@ def get_messages( messages.append(SystemMessage(content=system_message).model_dump()) if message_history: try: - MessageList(messages=message_history) + MessageList(messages=message_history) # type: ignore except ValidationError as e: raise LLMGenerationError(e.errors()) from e messages.extend(message_history) @@ -91,7 +90,7 @@ def get_messages( def invoke( self, input: str, - message_history: Optional[list[BaseMessage]] = None, + message_history: Optional[list[dict[str, str]]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Sends a text input to the Mistral chat completion model @@ -127,7 +126,7 @@ def invoke( async def ainvoke( self, input: str, - message_history: Optional[list[BaseMessage]] = None, + message_history: Optional[list[dict[str, str]]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Asynchronously sends a text input to the MistralAI chat diff --git a/src/neo4j_graphrag/llm/ollama_llm.py b/src/neo4j_graphrag/llm/ollama_llm.py index 00f15cd1..83840206 100644 --- a/src/neo4j_graphrag/llm/ollama_llm.py +++ b/src/neo4j_graphrag/llm/ollama_llm.py @@ -12,14 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, TYPE_CHECKING +from typing import Any, Optional, Sequence, TYPE_CHECKING, cast from pydantic import ValidationError from neo4j_graphrag.exceptions import LLMGenerationError from .base import LLMInterface -from .types import LLMResponse, SystemMessage, UserMessage, MessageList, BaseMessage +from .types import LLMResponse, SystemMessage, UserMessage, MessageList if TYPE_CHECKING: from ollama import Message @@ -52,7 +52,7 @@ def __init__( def get_messages( self, input: str, - message_history: Optional[list[BaseMessage]] = None, + message_history: Optional[list[dict[str, str]]] = None, system_instruction: Optional[str] = None, ) -> Sequence[Message]: messages = [] @@ -65,17 +65,17 @@ def get_messages( messages.append(SystemMessage(content=system_message).model_dump()) if message_history: try: - MessageList(messages=message_history) + MessageList(messages=message_history) # type: ignore except ValidationError as e: raise LLMGenerationError(e.errors()) from e messages.extend(message_history) messages.append(UserMessage(content=input).model_dump()) - return messages + return cast(Sequence[Message], messages) def invoke( self, input: str, - message_history: Optional[list[BaseMessage]] = None, + message_history: Optional[list[dict[str, str]]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Sends text to the LLM and returns a response. @@ -102,7 +102,7 @@ def invoke( async def ainvoke( self, input: str, - message_history: Optional[list[BaseMessage]] = None, + message_history: Optional[list[dict[str, str]]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Asynchronously sends a text input to the OpenAI chat diff --git a/src/neo4j_graphrag/llm/openai_llm.py b/src/neo4j_graphrag/llm/openai_llm.py index 249413be..bc29469f 100644 --- a/src/neo4j_graphrag/llm/openai_llm.py +++ b/src/neo4j_graphrag/llm/openai_llm.py @@ -15,13 +15,13 @@ from __future__ import annotations import abc -from typing import TYPE_CHECKING, Any, Iterable, Optional +from typing import TYPE_CHECKING, Any, Iterable, Optional, cast from pydantic import ValidationError from ..exceptions import LLMGenerationError from .base import LLMInterface -from .types import LLMResponse, SystemMessage, UserMessage, MessageList, BaseMessage +from .types import LLMResponse, SystemMessage, UserMessage, MessageList if TYPE_CHECKING: import openai @@ -63,7 +63,7 @@ def __init__( def get_messages( self, input: str, - message_history: Optional[list[BaseMessage]] = None, + message_history: Optional[list[dict[str, str]]] = None, system_instruction: Optional[str] = None, ) -> Iterable[ChatCompletionMessageParam]: messages = [] @@ -76,17 +76,17 @@ def get_messages( messages.append(SystemMessage(content=system_message).model_dump()) if message_history: try: - MessageList(messages=message_history) + MessageList(messages=message_history) # type: ignore except ValidationError as e: raise LLMGenerationError(e.errors()) from e messages.extend(message_history) messages.append(UserMessage(content=input).model_dump()) - return messages + return cast(Iterable[ChatCompletionMessageParam], messages) def invoke( self, input: str, - message_history: Optional[list[BaseMessage]] = None, + message_history: Optional[list[dict[str, str]]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Sends a text input to the OpenAI chat completion model @@ -117,7 +117,7 @@ def invoke( async def ainvoke( self, input: str, - message_history: Optional[list[BaseMessage]] = None, + message_history: Optional[list[dict[str, str]]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Asynchronously sends a text input to the OpenAI chat diff --git a/src/neo4j_graphrag/llm/vertexai_llm.py b/src/neo4j_graphrag/llm/vertexai_llm.py index 77d1ec4b..2ca506e0 100644 --- a/src/neo4j_graphrag/llm/vertexai_llm.py +++ b/src/neo4j_graphrag/llm/vertexai_llm.py @@ -19,7 +19,7 @@ from neo4j_graphrag.exceptions import LLMGenerationError from neo4j_graphrag.llm.base import LLMInterface -from neo4j_graphrag.llm.types import LLMResponse, MessageList, BaseMessage +from neo4j_graphrag.llm.types import LLMResponse, MessageList try: from vertexai.generative_models import ( @@ -74,26 +74,30 @@ def __init__( super().__init__(model_name, model_params) self.model_name = model_name self.system_instruction = system_instruction - self.model_params = kwargs + self.options = kwargs def get_messages( - self, input: str, message_history: Optional[list[BaseMessage]] = None + self, input: str, message_history: Optional[list[dict[str, str]]] = None ) -> list[Content]: messages = [] if message_history: try: - MessageList(messages=message_history) + MessageList(messages=message_history) # type: ignore except ValidationError as e: raise LLMGenerationError(e.errors()) from e for message in message_history: - if message.role == "user": + if message.get("role") == "user": messages.append( - Content(role="user", parts=[Part.from_text(message.content)]) + Content( + role="user", parts=[Part.from_text(message.get("content"))] + ) ) - elif message.role == "assistant": + elif message.get("role") == "assistant": messages.append( - Content(role="model", parts=[Part.from_text(message.content)]) + Content( + role="model", parts=[Part.from_text(message.get("content"))] + ) ) messages.append(Content(role="user", parts=[Part.from_text(input)])) @@ -102,7 +106,7 @@ def get_messages( def invoke( self, input: str, - message_history: Optional[list[BaseMessage]] = None, + message_history: Optional[list[dict[str, str]]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Sends text to the LLM and returns a response. @@ -123,7 +127,7 @@ def invoke( self.model = GenerativeModel( model_name=self.model_name, system_instruction=[system_message], - **self.model_params, + **self.options, ) try: messages = self.get_messages(input, message_history) @@ -135,7 +139,7 @@ def invoke( async def ainvoke( self, input: str, - message_history: Optional[list[BaseMessage]] = None, + message_history: Optional[list[dict[str, str]]] = None, system_instruction: Optional[str] = None, ) -> LLMResponse: """Asynchronously sends text to the LLM and returns a response. @@ -157,7 +161,7 @@ async def ainvoke( self.model = GenerativeModel( model_name=self.model_name, system_instruction=[system_message], - **self.model_params, + **self.options, ) messages = self.get_messages(input, message_history) response = await self.model.generate_content_async(