Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(community): Neo4j chat message history #7331

Merged
merged 7 commits into from
Dec 14, 2024
172 changes: 172 additions & 0 deletions libs/langchain-community/src/stores/message/neo4j.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import neo4j, { Driver, Record, Neo4jError, auth } from "neo4j-driver";
import { v4 as uuidv4 } from "uuid";
import { BaseListChatMessageHistory } from "@langchain/core/chat_history";
import {
BaseMessage,
mapStoredMessagesToChatMessages,
} from "@langchain/core/messages";

export type Neo4jChatMessageHistoryConfigInput = {
sessionId?: string | number;
sessionNodeLabel?: string;
messageNodeLabel?: string;
url: string;
username: string;
password: string;
windowSize?: number;
};

const defaultConfig = {
sessionNodeLabel: "ChatSession",
messageNodeLabel: "ChatMessage",
windowSize: 3,
};

export class Neo4jChatMessageHistory extends BaseListChatMessageHistory {
lc_namespace: string[] = ["langchain", "stores", "message", "neo4j"];
BernardFaucher marked this conversation as resolved.
Show resolved Hide resolved

sessionId: string | number;

sessionNodeLabel: string;

messageNodeLabel: string;

windowSize: number;

private driver: Driver;

constructor({
sessionId = uuidv4(),
sessionNodeLabel = defaultConfig.sessionNodeLabel,
messageNodeLabel = defaultConfig.messageNodeLabel,
url,
username,
password,
windowSize = defaultConfig.windowSize,
}: Neo4jChatMessageHistoryConfigInput) {
super();

this.sessionId = sessionId;
this.sessionNodeLabel = sessionNodeLabel;
this.messageNodeLabel = messageNodeLabel;
this.windowSize = windowSize;

if (url && username && password) {
try {
this.driver = neo4j.driver(url, auth.basic(username, password));
} catch (e) {
throw new Neo4jError({
message:
"Could not create a Neo4j driver instance. Please check the connection details.",
cause: e,
});
}
} else {
throw new Neo4jError({
message: "Neo4j connection details not provided.",
});
}
}

static async initialize(
props: Neo4jChatMessageHistoryConfigInput
): Promise<Neo4jChatMessageHistory> {
const instance = new Neo4jChatMessageHistory(props);

try {
await instance.verifyConnectivity();
} catch (e) {
throw new Neo4jError({
message: `Could not verify connection to the Neo4j database. Cause: ${e}`,
cause: e,
});
}

return instance;
}

async verifyConnectivity() {
const connectivity = await this.driver.getServerInfo();
return connectivity;
}

async getMessages(): Promise<BaseMessage[]> {
const getMessagesCypherQuery = `
MERGE (chatSession:${this.sessionNodeLabel} {id: $sessionId})
WITH chatSession
MATCH (chatSession)-[:LAST_MESSAGE]->(lastMessage)
MATCH p=(lastMessage)<-[:NEXT*0..${this.windowSize * 2 - 1}]-()
WITH p, length(p) AS length
ORDER BY length DESC LIMIT 1
UNWIND reverse(nodes(p)) AS node
RETURN {data:{content: node.content}, type:node.type} AS result
`;

try {
const { records } = await this.driver.executeQuery(
getMessagesCypherQuery,
{
sessionId: this.sessionId,
}
);
const results = records.map((record: Record) => record.get("result"));

return mapStoredMessagesToChatMessages(results);
} catch (e) {
throw new Neo4jError({
message: `Ohno! Couldn't get messages. Cause: ${e}`,
cause: e,
});
}
}

async addMessage(message: BaseMessage): Promise<void> {
const addMessageCypherQuery = `
MERGE (chatSession:${this.sessionNodeLabel} {id: $sessionId})
WITH chatSession
OPTIONAL MATCH (chatSession)-[lastMessageRel:LAST_MESSAGE]->(lastMessage)
CREATE (chatSession)-[:LAST_MESSAGE]->(newLastMessage:${this.messageNodeLabel})
SET newLastMessage += {type:$type, content:$content}
WITH newLastMessage, lastMessageRel, lastMessage
WHERE lastMessage IS NOT NULL
CREATE (lastMessage)-[:NEXT]->(newLastMessage)
DELETE lastMessageRel
`;

try {
await this.driver.executeQuery(addMessageCypherQuery, {
sessionId: this.sessionId,
type: message.getType(),
content: message.content,
});
} catch (e) {
throw new Neo4jError({
message: `Ohno! Couldn't add message. Cause: ${e}`,
cause: e,
});
}
}

async clear() {
const clearMessagesCypherQuery = `
MATCH p=(chatSession:${this.sessionNodeLabel} {id: $sessionId})-[:LAST_MESSAGE]->(lastMessage)<-[:NEXT*0..]-()
UNWIND nodes(p) as node
DETACH DELETE node
`;

try {
await this.driver.executeQuery(clearMessagesCypherQuery, {
sessionId: this.sessionId,
});
} catch (e) {
throw new Neo4jError({
message: `Ohno! Couldn't clear chat history. Cause: ${e}`,
cause: e,
});
}
}

async close() {
await this.driver.close();
}
}
138 changes: 138 additions & 0 deletions libs/langchain-community/src/stores/tests/neo4j.int.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import { describe, it, expect, beforeEach, afterEach } from "@jest/globals";
import { HumanMessage, AIMessage } from "@langchain/core/messages";
import neo4j from "neo4j-driver";
import { Neo4jChatMessageHistory } from "../message/neo4j.js";

const goodConfig = {
url: "bolt://host.docker.internal:7687",
username: "neo4j",
password: "langchain",
};

describe("The Neo4jChatMessageHistory class", () => {
describe("Test suite", () => {
it("Runs at all", () => {
expect(true).toEqual(true);
});
});

describe("Class instantiation", () => {
it("Requires a url, username and password, throwing an error if not provided", async () => {
const badConfig = {};
await expect(
// @ts-expect-error Bad config
Neo4jChatMessageHistory.initialize(badConfig)
).rejects.toThrow(neo4j.Neo4jError);
});

it("Creates a class instance from - at minimum - a url, username and password", async () => {
const instance = await Neo4jChatMessageHistory.initialize(goodConfig);
expect(instance).toBeInstanceOf(Neo4jChatMessageHistory);
await instance.close();
});

it("Class instances have expected, configurable fields, and sensible defaults", async () => {
const instance = await Neo4jChatMessageHistory.initialize(goodConfig);

expect(instance.sessionId).toBeDefined();
expect(instance.sessionNodeLabel).toEqual("ChatSession");
expect(instance.windowSize).toEqual(3);
expect(instance.messageNodeLabel).toEqual("ChatMessage");

const secondInstance = await Neo4jChatMessageHistory.initialize({
...goodConfig,
sessionId: "Shibboleet",
sessionNodeLabel: "Conversation",
messageNodeLabel: "Communication",
windowSize: 4,
});

expect(secondInstance.sessionId).toBeDefined();
expect(secondInstance.sessionId).toEqual("Shibboleet");
expect(instance.sessionId).not.toEqual(secondInstance.sessionId);
expect(secondInstance.sessionNodeLabel).toEqual("Conversation");
expect(secondInstance.messageNodeLabel).toEqual("Communication");
expect(secondInstance.windowSize).toEqual(4);

await instance.close();
await secondInstance.close();
});
});

describe("Core functionality", () => {
let instance: undefined | Neo4jChatMessageHistory;

beforeEach(async () => {
instance = await Neo4jChatMessageHistory.initialize(goodConfig);
});

afterEach(async () => {
await instance?.clear();
await instance?.close();
});

it("Connects verifiably to the underlying Neo4j database", async () => {
const connected = await instance?.verifyConnectivity();
expect(connected).toBeDefined();
});

it("getMessages()", async () => {
let results = await instance?.getMessages();
expect(results).toEqual([]);
const messages = [
new HumanMessage(
"My first name is a random set of numbers and letters"
),
new AIMessage("And other alphanumerics that changes hourly forever"),
new HumanMessage(
"My last name, a thousand vowels fading down a sinkhole to a susurrus"
),
new AIMessage("It couldn't just be John Doe or Bingo"),
new HumanMessage(
"My address, a made-up language written out in living glyphs"
),
new AIMessage("Lifted from demonic literature and religious text"),
new HumanMessage("Telephone: uncovered by purveyors of the ouija"),
new AIMessage("When checked against the CBGB women's room graffiti"),
new HumanMessage("My social: a sudoku"),
new AIMessage("My age is obscure"),
];
await instance?.addMessages(messages);
results = (await instance?.getMessages()) || [];
const windowSize = instance?.windowSize || 0;
expect(results.length).toEqual(windowSize * 2);
expect(results).toEqual(messages.slice(windowSize * -2));
});

it("addMessage()", async () => {
const messages = [
new HumanMessage("99 Bottles of beer on the wall, 99 bottles of beer!"),
new AIMessage(
"Take one down, pass it around, 98 bottles of beer on the wall."
),
new HumanMessage("How many bottles of beer are currently on the wall?"),
new AIMessage("There are currently 98 bottles of beer on the wall."),
];
for (const message of messages) {
await instance?.addMessage(message);
}
const results = await instance?.getMessages();
expect(results).toEqual(messages);
});

it("clear()", async () => {
const messages = [
new AIMessage("I'm not your enemy."),
new HumanMessage("That sounds like something that my enemy would say."),
new AIMessage("You're being difficult."),
new HumanMessage("I'm being guarded."),
];
await instance?.addMessages(messages);
let results = await instance?.getMessages();
expect(results).toEqual(messages);
await instance?.clear();
results = await instance?.getMessages();
expect(results).toEqual([]);
});
});
});
15 changes: 14 additions & 1 deletion test-int-deps-docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,17 @@ services:
qdrant:
image: qdrant/qdrant:v1.9.1
ports:
- 6333:6333
- 6333:6333
neo4j:
image: neo4j:latest
volumes:
- $HOME/neo4j/logs:/var/lib/neo4j/logs
- $HOME/neo4j/config:/var/lib/neo4j/config
- $HOME/neo4j/data:/var/lib/neo4j/data
- $HOME/neo4j/plugins:/var/lib/neo4j/plugins
environment:
- NEO4J_dbms_security_auth__enabled=false
ports:
- "7474:7474"
- "7687:7687"
restart: always
Loading