Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(community): Neo4j chat message history #7331

Merged
merged 7 commits into from
Dec 14, 2024
156 changes: 156 additions & 0 deletions libs/langchain-community/src/stores/message/neo4j.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import neo4j from "neo4j-driver";
import { Driver, Record, Neo4jError, auth } from "neo4j-driver";
import { v4 as uuidv4 } from "uuid";
import { BaseListChatMessageHistory } from "@langchain/core/chat_history";
import {
BaseMessage,
mapStoredMessagesToChatMessages,
} from "@langchain/core/messages";

export type Neo4jChatMessageHistoryConfigInput = {
sessionId?: string | number;
sessionNodeLabel?: string;
messageNodeLabel?: string;
url: string;
username: string;
password: string;
windowSize?: number;
}

const defaultConfig = {
sessionNodeLabel: "ChatSession",
messageNodeLabel: "ChatMessage",
windowSize: 3,
}

export class Neo4jChatMessageHistory extends BaseListChatMessageHistory {

lc_namespace: string[] = ["langchain", "stores", "message", "neo4j"];
BernardFaucher marked this conversation as resolved.
Show resolved Hide resolved
sessionId: string | number;
sessionNodeLabel: string;
messageNodeLabel: string;
windowSize: number;

#driver: Driver;
BernardFaucher marked this conversation as resolved.
Show resolved Hide resolved

constructor({
sessionId = uuidv4(),
sessionNodeLabel = defaultConfig.sessionNodeLabel,
messageNodeLabel = defaultConfig.messageNodeLabel,
url,
username,
password,
windowSize = defaultConfig.windowSize,
}: Neo4jChatMessageHistoryConfigInput) {
super()

this.sessionId = sessionId;
this.sessionNodeLabel = sessionNodeLabel;
this.messageNodeLabel = messageNodeLabel;
this.windowSize = windowSize;

if (url && username && password) {
try {
this.#driver = neo4j.driver(url, auth.basic(username, password));
BernardFaucher marked this conversation as resolved.
Show resolved Hide resolved
} catch (e) {
throw new Neo4jError({
message: "Could not create a Neo4j driver instance. Please check the connection details.",
cause: e
});
}
} else {
throw new Neo4jError({
message: "Neo4j connection details not provided."
});
}
}

async verifyConnectivity() {
const connectivity = await this.#driver.getServerInfo();
return connectivity
}

async getMessages(): Promise<BaseMessage[]> {
const getMessagesCypherQuery = `
MERGE (chatSession:${this.sessionNodeLabel} {id: $sessionId})
WITH chatSession
MATCH (chatSession)-[:LAST_MESSAGE]->(lastMessage)
MATCH p=(lastMessage)<-[:NEXT*0..${(this.windowSize * 2) - 1}]-()
WITH p, length(p) AS length
ORDER BY length DESC LIMIT 1
UNWIND reverse(nodes(p)) AS node
RETURN {data:{content: node.content}, type:node.type} AS result
`

try {
const messages = await this.#driver.session().run(
getMessagesCypherQuery,
{
sessionId: this.sessionId
}
)
const results = messages.records.map((record: Record) => record.get('result'))

return mapStoredMessagesToChatMessages(results)
} catch (e) {
throw new Neo4jError({
message: `Ohno! Couldn't get messages. Cause: ${e}`,
cause: e
});
}
}

async addMessage(message: BaseMessage): Promise<void> {
const addMessageCypherQuery = `
MERGE (chatSession:${this.sessionNodeLabel} {id: $sessionId})
WITH chatSession
OPTIONAL MATCH (chatSession)-[lastMessageRel:LAST_MESSAGE]->(lastMessage)
CREATE (chatSession)-[:LAST_MESSAGE]->(newLastMessage:${this.messageNodeLabel})
SET newLastMessage += {type:$type, content:$content}
WITH newLastMessage, lastMessageRel, lastMessage
WHERE lastMessage IS NOT NULL
CREATE (lastMessage)-[:NEXT]->(newLastMessage)
DELETE lastMessageRel
`

try {
await this.#driver.session().run(
addMessageCypherQuery,
{
sessionId: this.sessionId,
type: message.getType(),
content: message.content
})
} catch (e) {
throw new Neo4jError({
message: `Ohno! Couldn't add message. Cause: ${e}`,
cause: e
})
}
}

async clear() {
const clearMessagesCypherQuery = `
MATCH p=(chatSession:${this.sessionNodeLabel} {id: $sessionId})-[:LAST_MESSAGE]->(lastMessage)<-[:NEXT*0..]-()
UNWIND nodes(p) as node
DETACH DELETE node
`

try {
await this.#driver.session().run(
clearMessagesCypherQuery,
{
sessionId: this.sessionId,
})
} catch (e) {
throw new Neo4jError({
message: `Ohno! Couldn't clear chat history. Cause: ${e}`,
cause: e
})
}
}

async close() {
await this.#driver.close()
}
}
152 changes: 152 additions & 0 deletions libs/langchain-community/src/stores/tests/neo4j.int.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import { describe, it, expect, beforeEach, afterEach } from "@jest/globals";
import { Neo4jChatMessageHistory } from "../message/neo4j.js";
import { HumanMessage, AIMessage } from "@langchain/core/messages";
import neo4j from "neo4j-driver";

const goodConfig = {
url: 'bolt://host.docker.internal:7687',
username: 'neo4j',
password: 'langchain'
}

describe("The Neo4jChatMessageHistory class", () => {
describe("Test suite", () => {
it("Runs at all", () => {
expect(true).toEqual(true)
})
})

describe("Class instantiation", () => {
it(
"Requires a url, username and password, throwing an error if not provided",
() => {
const badInstantiation = () => {
const badConfig = {}
// @ts-ignore
const instance = new Neo4jChatMessageHistory(badConfig)
}
expect(badInstantiation).toThrow(neo4j.Neo4jError)
}
)

it(
"Creates a class instance from - at minimum - a url, username and password",
() => {
const instance = new Neo4jChatMessageHistory(goodConfig)
expect(instance).toBeInstanceOf(Neo4jChatMessageHistory)
}
)

it(
"Class instances have expected, configurable fields, and sensible defaults",
() => {
const instance = new Neo4jChatMessageHistory(goodConfig)

expect(instance.sessionId).toBeDefined()
expect(instance.sessionNodeLabel).toEqual("ChatSession")
expect(instance.windowSize).toEqual(3)
expect(instance.messageNodeLabel).toEqual("ChatMessage")

const secondInstance = new Neo4jChatMessageHistory({
...goodConfig,
sessionId: "Shibboleet",
sessionNodeLabel: "Conversation",
messageNodeLabel: "Communication",
windowSize: 4
})

expect(secondInstance.sessionId).toBeDefined()
expect(secondInstance.sessionId).toEqual("Shibboleet")
expect(instance.sessionId).not.toEqual(secondInstance.sessionId)
expect(secondInstance.sessionNodeLabel).toEqual("Conversation")
expect(secondInstance.messageNodeLabel).toEqual("Communication")
expect(secondInstance.windowSize).toEqual(4)
}
)
})

describe(
"Core functionality",
() => {

let instance: undefined | Neo4jChatMessageHistory;

beforeEach(() => {
instance = new Neo4jChatMessageHistory(goodConfig)
})

afterEach(async () => {
await instance?.clear()
await instance?.close()
})

it(
"Connects verifiably to the underlying Neo4j database",
async () => {
const connected = await instance?.verifyConnectivity()
expect(connected).toBeDefined()
}
)

it(
"getMessages()",
async () => {
let results = await instance?.getMessages()
expect(results).toEqual([])
const messages = [
new HumanMessage("My first name is a random set of numbers and letters"),
new AIMessage("And other alphanumerics that changes hourly forever"),
new HumanMessage("My last name, a thousand vowels fading down a sinkhole to a susurrus"),
new AIMessage("It couldn't just be John Doe or Bingo"),
new HumanMessage("My address, a made-up language written out in living glyphs"),
new AIMessage("Lifted from demonic literature and religious text"),
new HumanMessage("Telephone: uncovered by purveyors of the ouija"),
new AIMessage("When checked against the CBGB women's room graffiti"),
new HumanMessage("My social: a sudoku"),
new AIMessage("My age is obscure")
]
await instance?.addMessages(messages)
results = await instance?.getMessages() || []
const windowSize = instance?.windowSize || 0
expect(results.length).toEqual(windowSize * 2)
expect(results).toEqual(messages.slice(windowSize * -2))
}
)

it(
"addMessage()",
async () => {
const messages = [
new HumanMessage("99 Bottles of beer on the wall, 99 bottles of beer!"),
new AIMessage("Take one down, pass it around, 98 bottles of beer on the wall."),
new HumanMessage("How many bottles of beer are currently on the wall?"),
new AIMessage("There are currently 98 bottles of beer on the wall.")
]
for (let message of messages) {
await instance?.addMessage(message)
}
const results = await instance?.getMessages()
expect(results).toEqual(messages)
}
)

it(
"clear()",
async () => {
const messages = [
new AIMessage("I'm not your enemy."),
new HumanMessage("That sounds like something that my enemy would say."),
new AIMessage("You're being difficult."),
new HumanMessage("I'm being guarded.")
]
await instance?.addMessages(messages)
let results = await instance?.getMessages()
expect(results).toEqual(messages)
await instance?.clear()
results = await instance?.getMessages()
expect(results).toEqual([])
}
)
}
)
})
15 changes: 14 additions & 1 deletion test-int-deps-docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,17 @@ services:
qdrant:
image: qdrant/qdrant:v1.9.1
ports:
- 6333:6333
- 6333:6333
neo4j:
image: neo4j:latest
volumes:
- $HOME/neo4j/logs:/var/lib/neo4j/logs
- $HOME/neo4j/config:/var/lib/neo4j/config
- $HOME/neo4j/data:/var/lib/neo4j/data
- $HOME/neo4j/plugins:/var/lib/neo4j/plugins
environment:
- NEO4J_dbms_security_auth__enabled=false
ports:
- "7474:7474"
- "7687:7687"
restart: always