From ae400c5028b0a9e773475a9e00fcf0328bdaef24 Mon Sep 17 00:00:00 2001
From: jacoblee93
Date: Thu, 30 Nov 2023 18:14:22 -0800
Subject: [PATCH] Replace with conversational retrieval chain

---
 package.json                                  | 13 +--
 .../conversational_retrieval_chain/index.ts   | 97 +++++++++++++++++++
 .../execute_custom_llm_chain/index.ts         | 62 ++++--------
 3 files changed, 117 insertions(+), 55 deletions(-)
 create mode 100644 x-pack/plugins/elastic_assistant/server/lib/langchain/conversational_retrieval_chain/index.ts

diff --git a/package.json b/package.json
index ec394f7aeab6b..db7953c097f58 100644
--- a/package.json
+++ b/package.json
@@ -78,15 +78,6 @@
     "yarn": "^1.22.19"
   },
   "resolutions": {
-    "**/@hello-pangea/dnd": "16.2.0",
-    "**/@types/node": "18.18.5",
-    "**/@typescript-eslint/utils": "5.62.0",
-    "**/chokidar": "^3.5.3",
-    "**/globule/minimatch": "^3.1.2",
-    "**/hoist-non-react-statics": "^3.3.2",
-    "**/isomorphic-fetch/node-fetch": "^2.6.7",
-    "**/remark-parse/trim": "1.0.1",
-    "**/typescript": "4.7.4",
     "globby/fast-glob": "^3.2.11"
   },
   "dependencies": {
@@ -959,7 +950,7 @@
     "jsonwebtoken": "^9.0.0",
     "jsts": "^1.6.2",
     "kea": "^2.6.0",
-    "langchain": "^0.0.197-rc.1",
+    "langchain": "^0.0.199",
     "langsmith": "^0.0.48",
     "launchdarkly-js-client-sdk": "^3.1.4",
     "launchdarkly-node-server-sdk": "^7.0.3",
@@ -1638,4 +1629,4 @@
     "yargs": "^15.4.1",
     "yarn-deduplicate": "^6.0.2"
   }
-}
\ No newline at end of file
+}
diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/conversational_retrieval_chain/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/conversational_retrieval_chain/index.ts
new file mode 100644
index 0000000000000..05c09bbcb6ae3
--- /dev/null
+++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/conversational_retrieval_chain/index.ts
@@ -0,0 +1,97 @@
+import { ChatPromptTemplate, MessagesPlaceholder } from 'langchain/prompts';
+import {
+  RunnableBranch,
+  RunnableSequence,
+} from 'langchain/runnables';
+import type { VectorStoreRetriever } from 'langchain/vectorstores/base';
+import type { BaseLanguageModel } from 'langchain/base_language';
+import type { BaseMessage } from 'langchain/schema';
+import { Document } from 'langchain/document';
+import { StringOutputParser } from 'langchain/schema/output_parser';
+
+const CONDENSE_QUESTION_SYSTEM_TEMPLATE = `You are an experienced researcher, expert at interpreting and answering questions based on provided sources.
+Your job is to remove references to chat history from incoming questions, rephrasing them as standalone questions.`;
+
+const CONDENSE_QUESTION_HUMAN_TEMPLATE = `Using only the previous conversation as context, rephrase the following question to be a standalone question.
+
+Do not respond with anything other than a rephrased standalone question. Be concise, but complete, and resolve all references to the chat history.
+
+<question>
+  {question}
+</question>`;
+const condenseQuestionPrompt = ChatPromptTemplate.fromMessages([
+  ['system', CONDENSE_QUESTION_SYSTEM_TEMPLATE],
+  new MessagesPlaceholder('chat_history'),
+  ['human', CONDENSE_QUESTION_HUMAN_TEMPLATE],
+]);
+
+const ANSWER_SYSTEM_TEMPLATE = `You are an experienced researcher, expert at interpreting and answering questions based on provided sources.
+Using the provided context, answer the user's question to the best of your ability using only the resources provided.
+You must only use information from the provided search results.
+If there is no information in the context relevant to the question at hand, just say "Hmm, I'm not sure."
+Anything between the following \`context\` html blocks is retrieved from a knowledge bank, not part of the conversation with the user.
+
+<context>
+    {context}
+</context>`;
+
+const ANSWER_HUMAN_TEMPLATE = `Answer the following question to the best of your ability:
+
+{standalone_question}`;
+
+const answerPrompt = ChatPromptTemplate.fromMessages([
+  ['system', ANSWER_SYSTEM_TEMPLATE],
+  new MessagesPlaceholder('chat_history'),
+  ['human', ANSWER_HUMAN_TEMPLATE],
+]);
+
+const formatDocuments = (docs: Document[]) => {
+  return docs
+    .map((doc, i) => {
+      return `<doc id="${i}">\n${doc.pageContent}\n</doc>`;
+    })
+    .join('\n');
+};
+
+export function createConversationalRetrievalChain({
+  model,
+  retriever,
+}: {
+  model: BaseLanguageModel;
+  retriever: VectorStoreRetriever;
+}) {
+  const retrievalChain = RunnableSequence.from([
+    (input) => input.standalone_question,
+    retriever,
+    formatDocuments,
+  ]).withConfig({ runName: 'RetrievalChain' });
+
+  const standaloneQuestionChain = RunnableSequence.from([
+    condenseQuestionPrompt,
+    model,
+    new StringOutputParser(),
+  ]).withConfig({ runName: 'RephraseQuestionChain' });
+
+  const answerChain = RunnableSequence.from([
+    {
+      standalone_question: (input) => input.standalone_question,
+      chat_history: (input) => input.chat_history,
+      context: retrievalChain,
+    },
+    answerPrompt,
+    model,
+  ]).withConfig({ runName: 'AnswerGenerationChain' });
+
+  const conversationalRetrievalChain = RunnableSequence.from<{ question: string; chat_history: BaseMessage[] }>([
+    {
+      // Small optimization - only rephrase if the question is a follow-up
+      standalone_question: RunnableBranch.from([
+        [(input) => input.chat_history.length > 0, standaloneQuestionChain],
+        (input) => input.question,
+      ]),
+      chat_history: (input) => input.chat_history,
+    },
+    answerChain,
+  ]);
+  return conversationalRetrievalChain;
+}
\ No newline at end of file
diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts
index 932b81b93f61f..d8f3d3db2dacd 100644
--- a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts
+++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts
@@ -5,21 +5,19 @@
  * 2.0.
  */
 
-import { initializeAgentExecutorWithOptions } from 'langchain/agents';
-import { RetrievalQAChain } from 'langchain/chains';
-import { BufferMemory, ChatMessageHistory } from 'langchain/memory';
-import { ChainTool, Tool } from 'langchain/tools';
 import { PassThrough, Readable } from 'stream';
 
 import { ActionsClientLlm } from '../llm/actions_client_llm';
 import { ElasticsearchStore } from '../elasticsearch_store/elasticsearch_store';
 import { KNOWLEDGE_BASE_INDEX_PATTERN } from '../../../routes/knowledge_base/constants';
 import type { AgentExecutorParams, AgentExecutorResponse } from '../executors/types';
+import { createConversationalRetrievalChain } from '../conversational_retrieval_chain/index';
+import { HttpResponseOutputParser } from 'langchain/output_parsers';
 
 export const DEFAULT_AGENT_EXECUTOR_ID = 'Elastic AI Assistant Agent Executor';
 
 /**
- * The default agent executor used by the Elastic AI Assistant. Main agent/chain that wraps the ActionsClientLlm,
- * sets up a conversation BufferMemory from chat history, and registers tools like the ESQLKnowledgeBaseTool.
+ * Uses a conversational retrieval chain: condenses the conversation into a standalone
+ * question, retrieves relevant Knowledge Base documents, and generates a response from them.
  */
 export const callAgentExecutor = async ({
@@ -44,15 +42,7 @@ export const callAgentExecutor = async ({
   });
 
   const pastMessages = langChainMessages.slice(0, -1); // all but the last message
-  const latestMessage = langChainMessages.slice(-1); // the last message
-
-  const memory = new BufferMemory({
-    chatHistory: new ChatMessageHistory(pastMessages),
-    memoryKey: 'chat_history', // this is the key expected by https://github.com/langchain-ai/langchainjs/blob/a13a8969345b0f149c1ca4a120d63508b06c52a5/langchain/src/agents/initialize.ts#L166
-    inputKey: 'input',
-    outputKey: 'output',
-    returnMessages: true,
-  });
+  const latestMessage = langChainMessages.slice(-1)[0]; // the last message
 
   // ELSER backed ElasticsearchStore for Knowledge Base
   const esStore = new ElasticsearchStore(
@@ -69,20 +59,15 @@ export const callAgentExecutor = async ({
       'Please ensure ELSER is configured to use the Knowledge Base, otherwise disable the Knowledge Base in Advanced Settings to continue.'
     );
   }
-
-  // Create a chain that uses the ELSER backed ElasticsearchStore, override k=10 for esql query generation for now
-  const chain = RetrievalQAChain.fromLLM(llm, esStore.asRetriever(10));
-
-  // TODO: Dependency inject these tools
-  const tools: Tool[] = [
-    new ChainTool({
-      name: 'ESQLKnowledgeBaseTool',
-      description:
-        'Call this for knowledge on how to build an ESQL query, or answer questions about the ES|QL query language.',
-      chain,
-      tags: ['esql', 'query-generation', 'knowledge-base'],
-    }),
-  ];
+
+  // Create a retriever that uses the ELSER backed ElasticsearchStore, override k=10 for esql query generation for now
+  const retriever = esStore.asRetriever(10);
+
+  const chain = createConversationalRetrievalChain({
+    model: llm,
+    retriever,
+  });
+  const chainWithOutputParser = chain.pipe(new HttpResponseOutputParser({ contentType: 'text/plain' }));
 
   // // Sets up tracer for tracing executions to APM. See x-pack/plugins/elastic_assistant/server/lib/langchain/tracers/README.mdx
   // // If LangSmith env vars are set, executions will be traced there as well. See https://docs.smith.langchain.com/tracing
@@ -111,24 +96,13 @@ export const callAgentExecutor = async ({
   //   );
   // });
 
-  const executor = await initializeAgentExecutorWithOptions(tools, llm, {
-    agentType: 'chat-conversational-react-description',
-    // agentType: 'zero-shot-react-description',
-    returnIntermediateSteps: true,
-    memory,
-    verbose: true,
-  });
   console.log('WE ARE HERE before stream call');
-  const resp = await executor.stream({ input: latestMessage[0].content, chat_history: [] });
-  const textEncoder = new TextEncoder();
-  async function* generate() {
-    for await (const chunk of resp) {
-      console.log('WE ARE HERE CHUNK', chunk);
-      yield textEncoder.encode(JSON.stringify(chunk));
-    }
+  if (typeof latestMessage.content !== 'string') {
+    throw new Error('Multimodal messages not supported.');
   }
+  const stream = await chainWithOutputParser.stream({ question: latestMessage.content, chat_history: pastMessages });
 
-  const readable = Readable.from(generate());
+  const readable = Readable.from(stream);
 
   return readable.pipe(new PassThrough());
 };
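
Reviewer note: below is a minimal, hypothetical standalone sketch of how the new
createConversationalRetrievalChain behaves, to make the expected input/output shape
easier to review. ChatOpenAI, OpenAIEmbeddings, and MemoryVectorStore are stand-ins
for the plugin's ActionsClientLlm and the ELSER-backed ElasticsearchStore (they
assume OPENAI_API_KEY is set), and the relative import path is illustrative; only
createConversationalRetrievalChain itself comes from this patch.

import { ChatOpenAI } from 'langchain/chat_models/openai';
import { OpenAIEmbeddings } from 'langchain/embeddings/openai';
import { MemoryVectorStore } from 'langchain/vectorstores/memory';
import { AIMessage, HumanMessage } from 'langchain/schema';

import { createConversationalRetrievalChain } from './conversational_retrieval_chain';

async function main() {
  // Stand-in for the ELSER-backed ElasticsearchStore: any vector store that
  // exposes asRetriever() satisfies the chain's retriever parameter.
  const vectorStore = await MemoryVectorStore.fromTexts(
    ['ES|QL queries begin with a FROM command that selects a data source.'],
    [{ id: 1 }],
    new OpenAIEmbeddings()
  );

  const chain = createConversationalRetrievalChain({
    model: new ChatOpenAI({ temperature: 0 }), // stand-in for ActionsClientLlm
    retriever: vectorStore.asRetriever(10),
  });

  // Because chat_history is non-empty, the RunnableBranch first condenses the
  // follow-up into a standalone question, then retrieves documents and answers.
  const stream = await chain.stream({
    question: 'How do I pick a data source?',
    chat_history: [
      new HumanMessage('What is ES|QL?'),
      new AIMessage('ES|QL is the Elasticsearch Query Language.'),
    ],
  });

  // The chain's final step is the chat model, so the stream yields
  // AIMessageChunk objects; log each chunk's content as it arrives.
  for await (const chunk of stream) {
    console.log(chunk.content);
  }
}

main();

On a first turn (empty chat_history), the branch skips the rephrase step and passes
the question through unchanged, saving one model round trip. callAgentExecutor gets
streaming bytes for the HTTP response by piping this same chain through
HttpResponseOutputParser instead of reading message chunks directly.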