More on LLM message history and system instruction (#240)
* Add examples for LLM message history and system instructions

* Fix for when system_message is None

* Ruff

* Better (?) system prompt for RAG

* Update system_instruction

* Mypy

* Mypy

* Fix ollama test

* Fix anthropic test

* Fix cohere test

* Fix vertexai test

* Fix mistralai test

* Fix graphrag test

* Ruff

* Mypy

* Variable is not used

* Ruff...

* Mypy + e2e tests

* Ruffffffff

* CHANGELOG

* Fix examples

* Remove useless commented api_key from examples
stellasia authored Jan 13, 2025
1 parent feeddbb commit 81d1f31
Showing 30 changed files with 381 additions and 299 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
@@ -4,8 +4,8 @@

### Added
- Support for conversations with message history, including a new `message_history` parameter for LLM interactions.
-- Ability to include system instructions and override them for specific invocations.
-- Summarization of chat history to enhance query embedding and context handling.
+- Ability to include system instructions in LLM invoke method.
+- Summarization of chat history to enhance query embedding and context handling in GraphRAG.

### Changed
- Updated LLM implementations to handle message history consistently across providers.
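Taken together, the entries above describe one new call surface. A minimal sketch of an invocation using both new parameters (a sketch only, assuming the OpenAI backend and an API key in the env; the strings are illustrative and not part of this diff):

from neo4j_graphrag.llm import OpenAILLM

llm = OpenAILLM(model_name="gpt-4o")
res = llm.invoke(
    "Is he also a director?",
    # prior turns, oldest first, as role/content dicts
    message_history=[
        {"role": "user", "content": "What are some movies Tom Hanks starred in?"},
        {"role": "assistant", "content": "Forrest Gump, Apollo 13, Cast Away..."},
    ],
    # per-invocation system instruction
    system_instruction="Answer in one short sentence.",
)
print(res.content)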
4 changes: 4 additions & 0 deletions examples/README.md
@@ -51,6 +51,7 @@ are listed in [the last section of this file](#customize).
## Answer: GraphRAG

- [End to end GraphRAG](./answer/graphrag.py)
+- [GraphRAG with message history](./question_answering/graphrag_with_message_history.py)


## Customize
@@ -73,6 +74,9 @@ are listed in [the last section of this file](#customize).
- [Ollama](./customize/llms/ollama_llm.py)
- [Custom LLM](./customize/llms/custom_llm.py)

+- [Message history](./customize/llms/llm_with_message_history.py)
+- [System Instruction](./customize/llms/llm_with_system_instructions.py)


### Prompts

1 change: 0 additions & 1 deletion examples/customize/embeddings/cohere_embeddings.py
@@ -2,7 +2,6 @@

# set api key here or in the CO_API_KEY env var
api_key = None
-# api_key = "sk-..."

embedder = CohereEmbeddings(
    model="embed-english-v3.0",
1 change: 0 additions & 1 deletion examples/customize/embeddings/mistalai_embeddings.py
@@ -6,7 +6,6 @@

# set api key here or in the MISTRAL_API_KEY env var
api_key = None
-# api_key = "sk-..."

embedder = MistralAIEmbeddings(model="mistral-embed", api_key=api_key)
res = embedder.embed_query("my question")
1 change: 0 additions & 1 deletion examples/customize/embeddings/openai_embeddings.py
@@ -6,7 +6,6 @@

# set api key here or in the OPENAI_API_KEY env var
api_key = None
-# api_key = "sk-..."

embedder = OpenAIEmbeddings(model="text-embedding-ada-002", api_key=api_key)
res = embedder.embed_query("my question")
2 changes: 0 additions & 2 deletions examples/customize/llms/anthropic_llm.py
@@ -2,8 +2,6 @@

# set api key here or in the ANTHROPIC_API_KEY env var
api_key = None
-# api_key = "sk-..."
-

llm = AnthropicLLM(
    model_name="claude-3-opus-20240229",
1 change: 0 additions & 1 deletion examples/customize/llms/cohere_llm.py
@@ -2,7 +2,6 @@

# set api key here or in the CO_API_KEY env var
api_key = None
-# api_key = "sk-..."

llm = CohereLLM(
    model_name="command-r",
42 changes: 42 additions & 0 deletions examples/customize/llms/llm_with_message_history.py
@@ -0,0 +1,42 @@
"""This example illustrates the message_history feature
of the LLMInterface by mocking a conversation between a user
and an LLM about Tom Hanks.
OpenAILLM can be replaced by any supported LLM from this package.
"""

from neo4j_graphrag.llm import LLMResponse, OpenAILLM

# set api key here on in the OPENAI_API_KEY env var
api_key = None

llm = OpenAILLM(model_name="gpt-4o", api_key=api_key)

questions = [
"What are some movies Tom Hanks starred in?",
"Is he also a director?",
"Wow, that's impressive. And what about his personal life, does he have children?",
]

history: list[dict[str, str]] = []
for question in questions:
res: LLMResponse = llm.invoke(
question,
message_history=history, # type: ignore
)
history.append(
{
"role": "user",
"content": question,
}
)
history.append(
{
"role": "assistant",
"content": res.content,
}
)

print("#" * 50, question)
print(res.content)
print("#" * 50)
22 changes: 22 additions & 0 deletions examples/customize/llms/llm_with_system_instructions.py
@@ -0,0 +1,22 @@
"""This example illustrates how to set system instructions for LLM.
OpenAILLM can be replaced by any supported LLM from this package.
"""

from neo4j_graphrag.llm import LLMResponse, OpenAILLM

# set api key here on in the OPENAI_API_KEY env var
api_key = None

llm = OpenAILLM(
model_name="gpt-4o",
api_key=api_key,
)

question = "How fast is Santa Claus during the Christmas eve?"

res: LLMResponse = llm.invoke(
question,
system_instruction="Answer with a serious tone",
)
print(res.content)
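Because the instruction is supplied at invocation time rather than stored on the client (see the anthropic_llm.py diff below, which removes the constructor argument), the same LLM instance can be reused with different instructions. A small sketch, with illustrative instruction strings:

res_terse: LLMResponse = llm.invoke(question, system_instruction="Answer in one sentence")
res_verbose: LLMResponse = llm.invoke(question, system_instruction="Answer with full physical reasoning")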
2 changes: 0 additions & 2 deletions examples/customize/llms/mistalai_llm.py
@@ -2,8 +2,6 @@

# set api key here or in the MISTRAL_API_KEY env var
api_key = None
-# api_key = "sk-..."
-

llm = MistralAILLM(
    model_name="mistral-small-latest",
1 change: 0 additions & 1 deletion examples/customize/llms/openai_llm.py
@@ -2,7 +2,6 @@

# set api key here or in the OPENAI_API_KEY env var
api_key = None
-# api_key = "sk-..."

llm = OpenAILLM(model_name="gpt-4o", api_key=api_key)
res: LLMResponse = llm.invoke("say something")
7 changes: 4 additions & 3 deletions examples/question_answering/graphrag.py
@@ -16,9 +16,10 @@
from neo4j_graphrag.retrievers import VectorCypherRetriever
from neo4j_graphrag.types import RetrieverResultItem

URI = "neo4j://localhost:7687"
AUTH = ("neo4j", "password")
DATABASE = "neo4j"
# Define database credentials
URI = "neo4j+s://demo.neo4jlabs.com"
AUTH = ("recommendations", "recommendations")
DATABASE = "recommendations"
INDEX = "moviePlotsEmbedding"


85 changes: 85 additions & 0 deletions examples/question_answering/graphrag_with_message_history.py
@@ -0,0 +1,85 @@
"""End to end example of building a RAG pipeline backed by a Neo4j database,
simulating a chat with message history feature.
Requires OPENAI_API_KEY to be in the env var.
"""

import neo4j
from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
from neo4j_graphrag.generation import GraphRAG
from neo4j_graphrag.llm import OpenAILLM
from neo4j_graphrag.retrievers import VectorCypherRetriever

# Define database credentials
URI = "neo4j+s://demo.neo4jlabs.com"
AUTH = ("recommendations", "recommendations")
DATABASE = "recommendations"
INDEX = "moviePlotsEmbedding"


driver = neo4j.GraphDatabase.driver(
URI,
auth=AUTH,
)

embedder = OpenAIEmbeddings()

retriever = VectorCypherRetriever(
driver,
index_name=INDEX,
retrieval_query="""
WITH node as movie, score
CALL(movie) {
MATCH (movie)<-[:ACTED_IN]-(p:Person)
RETURN collect(p.name) as actors
}
CALL(movie) {
MATCH (movie)<-[:DIRECTED]-(p:Person)
RETURN collect(p.name) as directors
}
RETURN movie.title as title, movie.plot as plot, movie.year as year, actors, directors
""",
embedder=embedder,
neo4j_database=DATABASE,
)

llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})

rag = GraphRAG(
retriever=retriever,
llm=llm,
)

questions = [
"Who starred in the Apollo 13 movies?",
"Who was its director?",
"In which year was this movie released?",
]

history: list[dict[str, str]] = []
for question in questions:
result = rag.search(
question,
return_context=False,
message_history=history, # type: ignore
)

answer = result.answer
print("#" * 50, question)
print(answer)
print("#" * 50)

history.append(
{
"role": "user",
"content": question,
}
)
history.append(
{
"role": "assistant",
"content": answer,
}
)

driver.close()
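When `message_history` is non-empty, `GraphRAG.search` first has the LLM summarize the earlier turns and folds that summary into the query used for retrieval; the `chat_summary_prompt` and `conversation_prompt` helpers in the graphrag.py diff below are the pieces of that mechanism.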
10 changes: 7 additions & 3 deletions src/neo4j_graphrag/generation/graphrag.py
@@ -137,7 +137,11 @@ def search(
        )
        logger.debug(f"RAG: retriever_result={retriever_result}")
        logger.debug(f"RAG: prompt={prompt}")
-        answer = self.llm.invoke(prompt, message_history)
+        answer = self.llm.invoke(
+            prompt,
+            message_history,
+            system_instruction=self.prompt_template.system_instructions,
+        )
        result: dict[str, Any] = {"answer": answer.content}
        if return_context:
            result["retriever_result"] = retriever_result
@@ -172,9 +176,9 @@ def chat_summary_prompt(self, message_history: list[LLMMessage]) -> str:

    def conversation_prompt(self, summary: str, current_query: str) -> str:
        return f"""
        Message Summary:
        {summary}

        Current Query:
        {current_query}
        """
10 changes: 7 additions & 3 deletions src/neo4j_graphrag/generation/prompts.py
@@ -32,16 +32,21 @@ class PromptTemplate:
    missing, a `PromptMissingInputError` is raised.
    """

+    DEFAULT_SYSTEM_INSTRUCTIONS: str = ""
    DEFAULT_TEMPLATE: str = ""
    EXPECTED_INPUTS: list[str] = list()

    def __init__(
        self,
        template: Optional[str] = None,
        expected_inputs: Optional[list[str]] = None,
+        system_instructions: Optional[str] = None,
    ) -> None:
        self.template = template or self.DEFAULT_TEMPLATE
        self.expected_inputs = expected_inputs or self.EXPECTED_INPUTS
+        self.system_instructions = (
+            system_instructions or self.DEFAULT_SYSTEM_INSTRUCTIONS
+        )

        for e in self.expected_inputs:
            if f"{{{e}}}" not in self.template:
@@ -88,9 +93,8 @@ def format(self, *args: Any, **kwargs: Any) -> str:


class RagTemplate(PromptTemplate):
-    DEFAULT_TEMPLATE = """Answer the user question using the following context
-
-Context:
+    DEFAULT_SYSTEM_INSTRUCTIONS = "Answer the user question using the provided context."
+    DEFAULT_TEMPLATE = """Context:
{context}

Examples:
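Downstream, a caller can now override the RAG system prompt without touching the template body. A minimal sketch (reusing the `retriever` and `llm` objects from the examples above; the instruction string is illustrative):

from neo4j_graphrag.generation import GraphRAG
from neo4j_graphrag.generation.prompts import RagTemplate

template = RagTemplate(system_instructions="Answer using only the provided context, in one sentence.")
rag = GraphRAG(retriever=retriever, llm=llm, prompt_template=template)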
21 changes: 5 additions & 16 deletions src/neo4j_graphrag/llm/anthropic_llm.py
@@ -61,7 +61,6 @@ def __init__(
        self,
        model_name: str,
        model_params: Optional[dict[str, Any]] = None,
-        system_instruction: Optional[str] = None,
        **kwargs: Any,
    ):
        try:
@@ -71,7 +70,7 @@ def __init__(
"""Could not import Anthropic Python client.
Please install it with `pip install "neo4j-graphrag[anthropic]"`."""
)
super().__init__(model_name, model_params, system_instruction)
super().__init__(model_name, model_params)
self.anthropic = anthropic
self.client = anthropic.Anthropic(**kwargs)
self.async_client = anthropic.AsyncAnthropic(**kwargs)
@@ -107,18 +106,13 @@ def invoke(
"""
try:
messages = self.get_messages(input, message_history)
system_message = (
system_instruction
if system_instruction is not None
else self.system_instruction
)
response = self.client.messages.create(
model=self.model_name,
system=system_message, # type: ignore
system=system_instruction or self.anthropic.NOT_GIVEN,
messages=messages,
**self.model_params,
)
return LLMResponse(content=response.content) # type: ignore
return LLMResponse(content=response.content)
except self.anthropic.APIError as e:
raise LLMGenerationError(e)

@@ -140,17 +134,12 @@ async def ainvoke(
"""
try:
messages = self.get_messages(input, message_history)
system_message = (
system_instruction
if system_instruction is not None
else self.system_instruction
)
response = await self.async_client.messages.create(
model=self.model_name,
system=system_message, # type: ignore
system=system_instruction or self.anthropic.NOT_GIVEN,
messages=messages,
**self.model_params,
)
return LLMResponse(content=response.content) # type: ignore
return LLMResponse(content=response.content)
except self.anthropic.APIError as e:
raise LLMGenerationError(e)
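With the stored instruction gone, the Anthropic wrapper now defaults the `system` field to `anthropic.NOT_GIVEN` unless a `system_instruction` is passed to the call. A minimal usage sketch (assumes ANTHROPIC_API_KEY in the env; the strings are illustrative):

from neo4j_graphrag.llm import AnthropicLLM

llm = AnthropicLLM(model_name="claude-3-opus-20240229")
res = llm.invoke(
    "Summarize the plot of Apollo 13.",
    system_instruction="Answer like a film critic.",  # forwarded as `system`
)
print(res.content)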