
Commit 81d1f31

More on LLM message history and system instruction (#240)

* Add examples for LLM message history and system instructions
* Fix for when system_message is None
* Ruff
* Better (?) system prompt for RAG
* Update system_instruction
* Mypy
* Mypy
* Fix ollama test
* Fix anthropic test
* Fix cohere test
* Fix vertexai test
* Fix mistralai test
* Fix graphrag test
* Ruff
* Mypy
* Variable is not used
* Ruff...
* Mypy + e2e tests
* Ruffffffff
* CHANGELOG
* Fix examples
* Remove useless commented api_key from examples

1 parent: feeddbb

30 files changed: +381 −299 lines

CHANGELOG.md (+2 −2)

@@ -4,8 +4,8 @@
 
 ### Added
 - Support for conversations with message history, including a new `message_history` parameter for LLM interactions.
-- Ability to include system instructions and override them for specific invocations.
-- Summarization of chat history to enhance query embedding and context handling.
+- Ability to include system instructions in LLM invoke method.
+- Summarization of chat history to enhance query embedding and context handling in GraphRAG.
 
 ### Changed
 - Updated LLM implementations to handle message history consistently across providers.
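Taken together, these entries mean `invoke` now accepts both an optional message history and an optional per-call system instruction. A minimal sketch of the combined call, assuming `OpenAILLM` as in the examples touched by this commit (the history format mirrors the `role`/`content` dicts used in those examples):

```python
from neo4j_graphrag.llm import OpenAILLM

llm = OpenAILLM(model_name="gpt-4o")  # reads OPENAI_API_KEY from the env

# Previous turns, oldest first.
history = [
    {"role": "user", "content": "Who directed Forrest Gump?"},
    {"role": "assistant", "content": "Robert Zemeckis."},
]

res = llm.invoke(
    "What else did he direct?",
    message_history=history,  # type: ignore
    system_instruction="Answer in one short sentence.",
)
print(res.content)
```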

examples/README.md (+4)

@@ -51,6 +51,7 @@ are listed in [the last section of this file](#customize).
 ## Answer: GraphRAG
 
 - [End to end GraphRAG](./answer/graphrag.py)
+- [GraphRAG with message history](./question_answering/graphrag_with_message_history.py)
 
 
 ## Customize
@@ -73,6 +74,9 @@ are listed in [the last section of this file](#customize).
 - [Ollama](./customize/llms/ollama_llm.py)
 - [Custom LLM](./customize/llms/custom_llm.py)
 
+- [Message history](./customize/llms/llm_with_message_history.py)
+- [System Instruction](./customize/llms/llm_with_system_instructions.py)
+
 
 ### Prompts

examples/customize/embeddings/cohere_embeddings.py (−1)

@@ -2,7 +2,6 @@
 
 # set api key here or in the CO_API_KEY env var
 api_key = None
-# api_key = "sk-..."
 
 embeder = CohereEmbeddings(
     model="embed-english-v3.0",

examples/customize/embeddings/mistalai_embeddings.py (−1)

@@ -6,7 +6,6 @@
 
 # set api key here or in the MISTRAL_API_KEY env var
 api_key = None
-# api_key = "sk-..."
 
 embeder = MistralAIEmbeddings(model="mistral-embed", api_key=api_key)
 res = embeder.embed_query("my question")

examples/customize/embeddings/openai_embeddings.py (−1)

@@ -6,7 +6,6 @@
 
 # set api key here or in the OPENAI_API_KEY env var
 api_key = None
-# api_key = "sk-..."
 
 embeder = OpenAIEmbeddings(model="text-embedding-ada-002", api_key=api_key)
 res = embeder.embed_query("my question")

examples/customize/llms/anthropic_llm.py (−2)

@@ -2,8 +2,6 @@
 
 # set api key here or in the ANTHROPIC_API_KEY env var
 api_key = None
-# api_key = "sk-..."
-
 
 llm = AnthropicLLM(
     model_name="claude-3-opus-20240229",

examples/customize/llms/cohere_llm.py (−1)

@@ -2,7 +2,6 @@
 
 # set api key here or in the CO_API_KEY env var
 api_key = None
-# api_key = "sk-..."
 
 llm = CohereLLM(
     model_name="command-r",
examples/customize/llms/llm_with_message_history.py (new file, +42)

@@ -0,0 +1,42 @@
+"""This example illustrates the message_history feature
+of the LLMInterface by mocking a conversation between a user
+and an LLM about Tom Hanks.
+
+OpenAILLM can be replaced by any supported LLM from this package.
+"""
+
+from neo4j_graphrag.llm import LLMResponse, OpenAILLM
+
+# set api key here or in the OPENAI_API_KEY env var
+api_key = None
+
+llm = OpenAILLM(model_name="gpt-4o", api_key=api_key)
+
+questions = [
+    "What are some movies Tom Hanks starred in?",
+    "Is he also a director?",
+    "Wow, that's impressive. And what about his personal life, does he have children?",
+]
+
+history: list[dict[str, str]] = []
+for question in questions:
+    res: LLMResponse = llm.invoke(
+        question,
+        message_history=history,  # type: ignore
+    )
+    history.append(
+        {
+            "role": "user",
+            "content": question,
+        }
+    )
+    history.append(
+        {
+            "role": "assistant",
+            "content": res.content,
+        }
+    )
+
+    print("#" * 50, question)
+    print(res.content)
+    print("#" * 50)
examples/customize/llms/llm_with_system_instructions.py (new file, +22)

@@ -0,0 +1,22 @@
+"""This example illustrates how to set system instructions for an LLM.
+
+OpenAILLM can be replaced by any supported LLM from this package.
+"""
+
+from neo4j_graphrag.llm import LLMResponse, OpenAILLM
+
+# set api key here or in the OPENAI_API_KEY env var
+api_key = None
+
+llm = OpenAILLM(
+    model_name="gpt-4o",
+    api_key=api_key,
+)
+
+question = "How fast is Santa Claus during the Christmas eve?"
+
+res: LLMResponse = llm.invoke(
+    question,
+    system_instruction="Answer with a serious tone",
+)
+print(res.content)
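Per-call system instructions compose with message history. A short sketch combining the two features, reusing the pattern from the two examples above (the question strings are illustrative only):

```python
from neo4j_graphrag.llm import LLMResponse, OpenAILLM

llm = OpenAILLM(model_name="gpt-4o")  # reads OPENAI_API_KEY from the env

# One earlier exchange, as role/content dicts.
history = [
    {"role": "user", "content": "How fast is Santa Claus during the Christmas eve?"},
    {"role": "assistant", "content": "Fast enough to visit every home in one night."},
]

res: LLMResponse = llm.invoke(
    "How many homes per second is that?",
    message_history=history,  # type: ignore
    system_instruction="Answer with a serious tone",
)
print(res.content)
```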

examples/customize/llms/mistalai_llm.py (−2)

@@ -2,8 +2,6 @@
 
 # set api key here or in the MISTRAL_API_KEY env var
 api_key = None
-# api_key = "sk-..."
-
 
 llm = MistralAILLM(
     model_name="mistral-small-latest",

examples/customize/llms/openai_llm.py (−1)

@@ -2,7 +2,6 @@
 
 # set api key here or in the OPENAI_API_KEY env var
 api_key = None
-# api_key = "sk-..."
 
 llm = OpenAILLM(model_name="gpt-4o", api_key=api_key)
 res: LLMResponse = llm.invoke("say something")

examples/question_answering/graphrag.py (+4 −3)

@@ -16,9 +16,10 @@
 from neo4j_graphrag.retrievers import VectorCypherRetriever
 from neo4j_graphrag.types import RetrieverResultItem
 
-URI = "neo4j://localhost:7687"
-AUTH = ("neo4j", "password")
-DATABASE = "neo4j"
+# Define database credentials
+URI = "neo4j+s://demo.neo4jlabs.com"
+AUTH = ("recommendations", "recommendations")
+DATABASE = "recommendations"
 INDEX = "moviePlotsEmbedding"
 

examples/question_answering/graphrag_with_message_history.py (new file, +85)

@@ -0,0 +1,85 @@
+"""End to end example of building a RAG pipeline backed by a Neo4j database,
+simulating a chat with message history feature.
+
+Requires OPENAI_API_KEY to be in the env var.
+"""
+
+import neo4j
+from neo4j_graphrag.embeddings.openai import OpenAIEmbeddings
+from neo4j_graphrag.generation import GraphRAG
+from neo4j_graphrag.llm import OpenAILLM
+from neo4j_graphrag.retrievers import VectorCypherRetriever
+
+# Define database credentials
+URI = "neo4j+s://demo.neo4jlabs.com"
+AUTH = ("recommendations", "recommendations")
+DATABASE = "recommendations"
+INDEX = "moviePlotsEmbedding"
+
+
+driver = neo4j.GraphDatabase.driver(
+    URI,
+    auth=AUTH,
+)
+
+embedder = OpenAIEmbeddings()
+
+retriever = VectorCypherRetriever(
+    driver,
+    index_name=INDEX,
+    retrieval_query="""
+    WITH node as movie, score
+    CALL(movie) {
+        MATCH (movie)<-[:ACTED_IN]-(p:Person)
+        RETURN collect(p.name) as actors
+    }
+    CALL(movie) {
+        MATCH (movie)<-[:DIRECTED]-(p:Person)
+        RETURN collect(p.name) as directors
+    }
+    RETURN movie.title as title, movie.plot as plot, movie.year as year, actors, directors
+    """,
+    embedder=embedder,
+    neo4j_database=DATABASE,
+)
+
+llm = OpenAILLM(model_name="gpt-4o", model_params={"temperature": 0})
+
+rag = GraphRAG(
+    retriever=retriever,
+    llm=llm,
+)
+
+questions = [
+    "Who starred in the Apollo 13 movies?",
+    "Who was its director?",
+    "In which year was this movie released?",
+]
+
+history: list[dict[str, str]] = []
+for question in questions:
+    result = rag.search(
+        question,
+        return_context=False,
+        message_history=history,  # type: ignore
+    )
+
+    answer = result.answer
+    print("#" * 50, question)
+    print(answer)
+    print("#" * 50)
+
+    history.append(
+        {
+            "role": "user",
+            "content": question,
+        }
+    )
+    history.append(
+        {
+            "role": "assistant",
+            "content": answer,
+        }
+    )
+
+driver.close()
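Because `GraphRAG.search` now forwards `self.prompt_template.system_instructions` to the LLM (see the `graphrag.py` diff below), the RAG system prompt can also be customized per pipeline. A sketch, assuming `GraphRAG` accepts a `prompt_template` argument and reusing `retriever` and `llm` from the example above:

```python
from neo4j_graphrag.generation import GraphRAG
from neo4j_graphrag.generation.prompts import RagTemplate

template = RagTemplate(
    system_instructions="Answer the user question using the provided context. Be concise.",
)
rag = GraphRAG(retriever=retriever, llm=llm, prompt_template=template)
```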

src/neo4j_graphrag/generation/graphrag.py (+7 −3)

@@ -137,7 +137,11 @@ def search(
         )
         logger.debug(f"RAG: retriever_result={retriever_result}")
         logger.debug(f"RAG: prompt={prompt}")
-        answer = self.llm.invoke(prompt, message_history)
+        answer = self.llm.invoke(
+            prompt,
+            message_history,
+            system_instruction=self.prompt_template.system_instructions,
+        )
         result: dict[str, Any] = {"answer": answer.content}
         if return_context:
             result["retriever_result"] = retriever_result
@@ -172,9 +176,9 @@ def chat_summary_prompt(self, message_history: list[LLMMessage]) -> str:
 
     def conversation_prompt(self, summary: str, current_query: str) -> str:
         return f"""
-Message Summary:
+Message Summary:
 {summary}
 
-Current Query:
+Current Query:
 {current_query}
 """

(The second hunk is whitespace-only: it strips trailing whitespace from the two label lines.)

src/neo4j_graphrag/generation/prompts.py (+7 −3)

@@ -32,16 +32,21 @@ class PromptTemplate:
     missing, a `PromptMissingInputError` is raised.
     """
 
+    DEFAULT_SYSTEM_INSTRUCTIONS: str = ""
     DEFAULT_TEMPLATE: str = ""
     EXPECTED_INPUTS: list[str] = list()
 
     def __init__(
         self,
         template: Optional[str] = None,
         expected_inputs: Optional[list[str]] = None,
+        system_instructions: Optional[str] = None,
     ) -> None:
         self.template = template or self.DEFAULT_TEMPLATE
         self.expected_inputs = expected_inputs or self.EXPECTED_INPUTS
+        self.system_instructions = (
+            system_instructions or self.DEFAULT_SYSTEM_INSTRUCTIONS
+        )
 
         for e in self.expected_inputs:
             if f"{{{e}}}" not in self.template:
@@ -88,9 +93,8 @@ def format(self, *args: Any, **kwargs: Any) -> str:
 
 
 class RagTemplate(PromptTemplate):
-    DEFAULT_TEMPLATE = """Answer the user question using the following context
-
-    Context:
+    DEFAULT_SYSTEM_INSTRUCTIONS = "Answer the user question using the provided context."
+    DEFAULT_TEMPLATE = """Context:
 {context}
 
 Examples:
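Resolution order after this change: an explicit `system_instructions` constructor argument wins, otherwise the class-level `DEFAULT_SYSTEM_INSTRUCTIONS` applies. A quick sketch of both cases:

```python
from neo4j_graphrag.generation.prompts import RagTemplate

default = RagTemplate()
print(default.system_instructions)
# "Answer the user question using the provided context."

custom = RagTemplate(
    system_instructions="Answer only from the context; otherwise say you don't know.",
)
print(custom.system_instructions)
```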

src/neo4j_graphrag/llm/anthropic_llm.py (+5 −16)

@@ -61,7 +61,6 @@ def __init__(
         self,
         model_name: str,
         model_params: Optional[dict[str, Any]] = None,
-        system_instruction: Optional[str] = None,
         **kwargs: Any,
     ):
         try:
@@ -71,7 +70,7 @@ def __init__(
                 """Could not import Anthropic Python client.
                 Please install it with `pip install "neo4j-graphrag[anthropic]"`."""
             )
-        super().__init__(model_name, model_params, system_instruction)
+        super().__init__(model_name, model_params)
         self.anthropic = anthropic
         self.client = anthropic.Anthropic(**kwargs)
         self.async_client = anthropic.AsyncAnthropic(**kwargs)
@@ -107,18 +106,13 @@ def invoke(
         """
         try:
             messages = self.get_messages(input, message_history)
-            system_message = (
-                system_instruction
-                if system_instruction is not None
-                else self.system_instruction
-            )
             response = self.client.messages.create(
                 model=self.model_name,
-                system=system_message,  # type: ignore
+                system=system_instruction or self.anthropic.NOT_GIVEN,
                 messages=messages,
                 **self.model_params,
             )
-            return LLMResponse(content=response.content)  # type: ignore
+            return LLMResponse(content=response.content)
         except self.anthropic.APIError as e:
             raise LLMGenerationError(e)
@@ -140,17 +134,12 @@ async def ainvoke(
         """
         try:
             messages = self.get_messages(input, message_history)
-            system_message = (
-                system_instruction
-                if system_instruction is not None
-                else self.system_instruction
-            )
             response = await self.async_client.messages.create(
                 model=self.model_name,
-                system=system_message,  # type: ignore
+                system=system_instruction or self.anthropic.NOT_GIVEN,
                 messages=messages,
                 **self.model_params,
             )
-            return LLMResponse(content=response.content)  # type: ignore
+            return LLMResponse(content=response.content)
         except self.anthropic.APIError as e:
             raise LLMGenerationError(e)
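The refactor replaces the constructor-level `system_instruction` with a per-call argument; when none is supplied, `anthropic.NOT_GIVEN` omits the `system` field from the request instead of sending `None`. A sketch of the new per-call usage (assuming `ANTHROPIC_API_KEY` is set; `max_tokens` is required by the Anthropic API):

```python
from neo4j_graphrag.llm import AnthropicLLM

llm = AnthropicLLM(
    model_name="claude-3-opus-20240229",
    model_params={"max_tokens": 1000},
)

# System behaviour is now chosen at call time, not at construction time.
res = llm.invoke(
    "say something",
    system_instruction="Reply in French.",
)
print(res.content)
```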
