Chat with message history (draft) #225
@@ -83,6 +83,7 @@ def __init__(
     def search(
         self,
         query_text: str = "",
+        chat_history: Optional[list[str]] = None,
         examples: str = "",
         retriever_config: Optional[dict[str, Any]] = None,
         return_context: bool | None = None,
@@ -100,6 +101,7 @@ def search(
         Args:
             query_text (str): The user question
+            chat_history (Optional[list]): A list of previous messages in the conversation
             examples (str): Examples added to the LLM prompt.
             retriever_config (Optional[dict]): Parameters passed to the retriever
                 search method; e.g.: top_k
@@ -134,7 +136,10 @@ def search(
         )
         logger.debug(f"RAG: retriever_result={retriever_result}")
         logger.debug(f"RAG: prompt={prompt}")
-        answer = self.llm.invoke(prompt)
+        if chat_history is not None:
Comment: I know this part isn't great. The idea I want to demonstrate is to have two different flows: chat with message history, and a single invocation.
+            answer = self.llm.chat(prompt, chat_history)
+        else:
+            answer = self.llm.invoke(prompt)
         result: dict[str, Any] = {"answer": answer.content}
         if return_context:
             result["retriever_result"] = retriever_result
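For reference, a rough usage sketch of the flow this hunk enables. It is not part of the diff; `my_retriever`, the model name, the prompts, and the manual history bookkeeping are illustrative assumptions.

```python
# Illustrative only: how a caller might use the new chat_history parameter.
# my_retriever stands in for any configured retriever instance.
from neo4j_graphrag.generation import GraphRAG
from neo4j_graphrag.llm import OpenAILLM

llm = OpenAILLM(
    model_name="gpt-4o",
    system_instruction="Answer using the retrieved context.",
)
rag = GraphRAG(retriever=my_retriever, llm=llm)

history: list[str] = []
question = "Who directed The Matrix?"
# result is the library's usual search result; result.answer holds the generated text.
result = rag.search(query_text=question, chat_history=history)

# The draft assumes an alternating user/assistant history, so append both turns.
history.extend([question, result.answer])
result = rag.search(query_text="When was it released?", chat_history=history)
```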
@@ -33,10 +33,12 @@ def __init__(
         self,
         model_name: str,
         model_params: Optional[dict[str, Any]] = None,
+        system_instruction: Optional[str] = None,
         **kwargs: Any,
     ):
         self.model_name = model_name
         self.model_params = model_params or {}
+        self.system_instruction = system_instruction

     @abstractmethod
     def invoke(self, input: str) -> LLMResponse:
@@ -52,6 +54,21 @@ def invoke(self, input: str) -> LLMResponse:
             LLMGenerationError: If anything goes wrong.
         """

+    @abstractmethod
+    def chat(self, input: str, chat_history: list[str]) -> LLMResponse:
Comment: I feel like it could be done in one method, with an optional chat_history argument. This is an advantage if most of the LLM providers share the same format for messages, in which case we can implement this method only once in the interface (a sketch of that idea follows this hunk).

Comment: Ok, that works.
+        """Sends a text input and a conversation history to the LLM and retrieves a response.
+
+        Args:
+            input (str): Text sent to the LLM
+            chat_history (list[str]): A list of previous messages in the conversation
+
+        Returns:
+            LLMResponse: The response from the LLM.
+
+        Raises:
+            LLMGenerationError: If anything goes wrong.
+        """
+
     @abstractmethod
     async def ainvoke(self, input: str) -> LLMResponse:
         """Asynchronously sends a text input to the LLM and retrieves a response.
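A minimal sketch of the single-method idea discussed in the thread above, assuming providers can share an OpenAI-style message format. The class name, the `build_messages` helper, and the dict-based history shape are assumptions for illustration; this is not what the diff implements.

```python
# Sketch only: one entry point with an optional history, and the message
# building implemented once in the interface. Names and the dict-based
# history format are illustrative assumptions, not part of this PR.
from abc import ABC, abstractmethod
from typing import Any, Optional

from neo4j_graphrag.llm.types import LLMResponse


class SingleMethodLLMInterface(ABC):
    def __init__(
        self,
        model_name: str,
        model_params: Optional[dict[str, Any]] = None,
        system_instruction: Optional[str] = None,
    ) -> None:
        self.model_name = model_name
        self.model_params = model_params or {}
        self.system_instruction = system_instruction

    def build_messages(
        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
    ) -> list[dict[str, str]]:
        # Shared across providers, assuming they all accept
        # {"role": ..., "content": ...} style messages.
        messages: list[dict[str, str]] = []
        if self.system_instruction:
            messages.append({"role": "system", "content": self.system_instruction})
        messages.extend(chat_history or [])
        messages.append({"role": "user", "content": input})
        return messages

    @abstractmethod
    def invoke(
        self, input: str, chat_history: Optional[list[dict[str, str]]] = None
    ) -> LLMResponse:
        """Single entry point: with chat_history=None this is a plain
        invocation; with a history it behaves like chat()."""
        ...
```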
@@ -36,6 +36,7 @@ def __init__(
         self,
         model_name: str,
         model_params: Optional[dict[str, Any]] = None,
+        system_instruction: Optional[str] = None,
     ):
         """
         Base class for OpenAI LLM.

@@ -54,7 +55,7 @@ def __init__(
                 "Please install it with `pip install openai`."
             )
         self.openai = openai
-        super().__init__(model_name, model_params)
+        super().__init__(model_name, model_params, system_instruction)

     def get_messages(
         self,
@@ -64,6 +65,32 @@ def get_messages(
             {"role": "system", "content": input},
         ]

+    def get_conversation_history(
+        self,
+        input: str,
+        chat_history: list[str],
+    ) -> Iterable[ChatCompletionMessageParam]:
+        messages = [{"role": "system", "content": self.system_instruction}]
+        for i, message in enumerate(chat_history):
+            if i % 2 == 0:
Comment: This assumes that the chat history is a list of strings, alternating between user and assistant messages.

Comment: Maybe we can use dicts with "role" and "content" keys (it'd work for OpenAI, I don't know about the others). A dict-based sketch follows this hunk.

Comment: I can look into it.
+                messages.append({"role": "user", "content": message})
+            else:
+                messages.append({"role": "assistant", "content": message})
+        messages.append({"role": "user", "content": input})
+        return messages
+
+    def chat(self, input: str, chat_history: list[str]) -> LLMResponse:
+        try:
+            response = self.client.chat.completions.create(
+                messages=self.get_conversation_history(input, chat_history),
+                model=self.model_name,
+                **self.model_params,
+            )
+            content = response.choices[0].message.content or ""
+            return LLMResponse(content=content)
+        except self.openai.OpenAIError as e:
+            raise LLMGenerationError(e)
+
     def invoke(self, input: str) -> LLMResponse:
         """Sends a text input to the OpenAI chat completion model
         and returns the response's content.
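To make the dict suggestion above concrete, here is a hedged sketch of a role-based history builder. The `build_openai_messages` helper and the history shape are assumptions for illustration, not code from this PR.

```python
# Sketch of the dict-based history suggested above (not what this diff does).
# Each history item carries its own role, so no even/odd convention is needed.
from typing import Optional


def build_openai_messages(
    input: str,
    chat_history: Optional[list[dict[str, str]]] = None,
    system_instruction: Optional[str] = None,
) -> list[dict[str, str]]:
    messages: list[dict[str, str]] = []
    if system_instruction:
        messages.append({"role": "system", "content": system_instruction})
    # History items look like {"role": "user" | "assistant", "content": "..."},
    # which matches the OpenAI chat.completions message format directly.
    messages.extend(chat_history or [])
    messages.append({"role": "user", "content": input})
    return messages


# Example input/output, purely illustrative:
history = [
    {"role": "user", "content": "Who directed The Matrix?"},
    {"role": "assistant", "content": "The Wachowskis."},
]
messages = build_openai_messages(
    "When was it released?", history, "You are a movie expert."
)
```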
@@ -118,6 +145,7 @@ def __init__(
         self,
         model_name: str,
         model_params: Optional[dict[str, Any]] = None,
+        system_instruction: Optional[str] = None,
         **kwargs: Any,
     ):
         """OpenAI LLM

@@ -129,7 +157,7 @@ def __init__(
             model_params (str): Parameters like temperature that will be passed to the model when text is sent to it
             kwargs: All other parameters will be passed to the openai.OpenAI init.
         """
-        super().__init__(model_name, model_params)
+        super().__init__(model_name, model_params, system_instruction)
         self.client = self.openai.OpenAI(**kwargs)
         self.async_client = self.openai.AsyncOpenAI(**kwargs)

@@ -139,6 +167,7 @@ def __init__(
         self,
         model_name: str,
         model_params: Optional[dict[str, Any]] = None,
+        system_instruction: Optional[str] = None,
         **kwargs: Any,
     ):
         """Azure OpenAI LLM. Use this class when using an OpenAI model

@@ -149,6 +178,6 @@ def __init__(
             model_params (str): Parameters like temperature that will be passed to the model when text is sent to it
             kwargs: All other parameters will be passed to the openai.OpenAI init.
         """
-        super().__init__(model_name, model_params)
+        super().__init__(model_name, model_params, system_instruction)
         self.client = self.openai.AzureOpenAI(**kwargs)
         self.async_client = self.openai.AsyncAzureOpenAI(**kwargs)
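For completeness, a hedged sketch of calling the draft chat() method directly on the OpenAI wrapper, using the alternating user/assistant string history this draft assumes. The model name and prompts are placeholders.

```python
# Illustrative only: calling the draft chat() API directly on the LLM wrapper.
from neo4j_graphrag.llm import OpenAILLM

llm = OpenAILLM(
    model_name="gpt-4o",
    system_instruction="You answer questions about movies.",
)

# The draft expects a flat list of strings alternating user/assistant.
history = [
    "Who directed The Matrix?",  # user
    "The Wachowskis.",           # assistant
]
response = llm.chat("When was it released?", history)
print(response.content)
```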
@@ -20,7 +20,12 @@
 from neo4j_graphrag.llm.types import LLMResponse

 try:
-    from vertexai.generative_models import GenerativeModel, ResponseValidationError
+    from vertexai.generative_models import (
+        GenerativeModel,
+        ResponseValidationError,
+        Part,
+        Content,
+    )
 except ImportError:
     GenerativeModel = None
     ResponseValidationError = None
@@ -55,15 +60,18 @@ def __init__(
         self,
         model_name: str = "gemini-1.5-flash-001",
         model_params: Optional[dict[str, Any]] = None,
+        system_instruction: Optional[str] = None,
         **kwargs: Any,
     ):
         if GenerativeModel is None or ResponseValidationError is None:
             raise ImportError(
                 "Could not import Vertex AI Python client. "
                 "Please install it with `pip install google-cloud-aiplatform`."
             )
-        super().__init__(model_name, model_params)
-        self.model = GenerativeModel(model_name=model_name, **kwargs)
+        super().__init__(model_name, model_params, system_instruction)
+        self.model = GenerativeModel(
Comment: This is quite bad news actually; that means we can use an LLM instance only in one specific use case? For instance, with RAG with a text2Cypher retriever, we can't use the same LLM for the retrieval and the generation? (I mean, if we want to use the system prompt, which is not what we do at the moment.)

Comment: I guess we don't have to always use the system_instruction. A per-call sketch follows this hunk.
+            model_name=model_name, system_instruction=[system_instruction], **kwargs
+        )

     def invoke(self, input: str) -> LLMResponse:
         """Sends text to the LLM and returns a response.
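Regarding the system-instruction concern in the thread above, one possible workaround is sketched below: build the GenerativeModel lazily, per call, with whatever system instruction (if any) that call needs, so the wrapper is not pinned to a single system prompt. This is an illustration under that assumption, not what the diff does; `generate` is a made-up helper.

```python
# Sketch only: per-call GenerativeModel construction so one wrapper can serve
# calls with different (or no) system instructions.
from typing import Any, Optional

from vertexai.generative_models import GenerativeModel


def generate(
    model_name: str,
    prompt: str,
    system_instruction: Optional[str] = None,
    **model_params: Any,
) -> str:
    # Only pass system_instruction when it is actually set, so the same
    # helper works with or without a system prompt.
    kwargs: dict[str, Any] = {"model_name": model_name}
    if system_instruction is not None:
        kwargs["system_instruction"] = system_instruction
    model = GenerativeModel(**kwargs)
    response = model.generate_content(prompt, **model_params)
    return response.text
```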
@@ -80,7 +88,25 @@ def invoke(self, input: str) -> LLMResponse:
         except ResponseValidationError as e:
             raise LLMGenerationError(e)

-    async def ainvoke(self, input: str) -> LLMResponse:
+    def chat(self, input: str, chat_history: list[str] = []) -> LLMResponse:
+        """Sends text to the LLM and returns a response.
+        Args:
+            input (str): The text to send to the LLM.
+        Returns:
+            LLMResponse: The response from the LLM.
+        """
+        try:
+            messages = self.get_conversation_history(input, chat_history)
+            response = self.model.generate_content(messages, **self.model_params)
+            return LLMResponse(content=response.text)
+        except ResponseValidationError as e:
+            raise LLMGenerationError(e)
+
+    async def ainvoke(
+        self, input: str, chat_history: Optional[list[str]] = []
+    ) -> LLMResponse:
         """Asynchronously sends text to the LLM and returns a response.
         Args:
@@ -96,3 +122,17 @@ async def ainvoke(self, input: str) -> LLMResponse:
             return LLMResponse(content=response.text)
         except ResponseValidationError as e:
             raise LLMGenerationError(e)
+
+    def get_conversation_history(
+        self,
+        input: str,
+        chat_history: list[str],
+    ) -> list[Content]:
+        messages = []
+        for i, message in enumerate(chat_history):
+            if i % 2 == 0:
+                messages.append(Content(role="user", parts=[Part.from_text(message)]))
+            else:
+                messages.append(Content(role="model", parts=[Part.from_text(message)]))
+        messages.append(Content(role="user", parts=[Part.from_text(input)]))
+        return messages
Comment: Sorry about this, but I couldn't be bothered with mypy just for opening a draft.

Comment: It's a good point actually, maybe we should disable it for draft PRs.