Skip to content

Commit

Permalink
[Security Assistant] Abort signal fix (elastic#203041)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephmilovic authored and SoniaSanzV committed Dec 9, 2024
1 parent ddbf94d commit 02df3c9
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ export class ActionsClientChatVertexAI extends ChatVertexAI {
tools: data?.tools,
temperature: this.temperature,
...systemInstruction,
signal: options?.signal,
},
},
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ export class ActionsClientChatConnection<Auth> extends ChatConnection<Auth> {
tools: data?.tools,
temperature: this.temperature,
...systemInstruction,
signal: options?.signal,
},
},
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ export interface GetDefaultAssistantGraphParams {
dataClients?: AssistantDataClients;
createLlmInstance: () => BaseChatModel;
logger: Logger;
signal?: AbortSignal;
tools: StructuredTool[];
replacements: Replacements;
}
Expand All @@ -45,6 +46,8 @@ export const getDefaultAssistantGraph = ({
dataClients,
createLlmInstance,
logger,
// some chat models (bedrock) require a signal to be passed on agent invoke rather than the signal passed to the chat model
signal,
tools,
replacements,
}: GetDefaultAssistantGraphParams) => {
Expand Down Expand Up @@ -137,11 +140,19 @@ export const getDefaultAssistantGraph = ({
})
)
.addNode(NodeType.AGENT, (state: AgentState) =>
runAgent({ ...nodeParams, state, agentRunnable, kbDataClient: dataClients?.kbDataClient })
runAgent({
...nodeParams,
config: { signal },
state,
agentRunnable,
kbDataClient: dataClients?.kbDataClient,
})
)
.addNode(NodeType.TOOLS, (state: AgentState) =>
executeTools({ ...nodeParams, config: { signal }, state, tools })
)
.addNode(NodeType.TOOLS, (state: AgentState) => executeTools({ ...nodeParams, state, tools }))
.addNode(NodeType.RESPOND, (state: AgentState) =>
respond({ ...nodeParams, state, model: createLlmInstance() })
respond({ ...nodeParams, config: { signal }, state, model: createLlmInstance() })
)
.addNode(NodeType.MODEL_INPUT, (state: AgentState) => modelInput({ ...nodeParams, state }))
.addEdge(START, NodeType.MODEL_INPUT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,7 @@ export const streamGraph = async ({
finalMessage += msg.content;
}
} else if (event.event === 'on_llm_end' && !didEnd) {
const generations = event.data.output?.generations[0];
if (generations && generations[0]?.generationInfo.finish_reason === 'stop') {
handleStreamEnd(generations[0]?.text ?? finalMessage);
}
handleStreamEnd(event.data.output?.generations[0][0]?.text ?? finalMessage);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
// we need to pass it like this or streaming does not work for bedrock
createLlmInstance,
logger,
signal: abortSignal,
tools,
replacements,
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,19 @@

import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { StringWithAutocomplete } from '@langchain/core/dist/utils/types';
import { RunnableConfig } from '@langchain/core/runnables';
import { AGENT_NODE_TAG } from './run_agent';
import { AgentState, NodeParamsBase } from '../types';
import { NodeType } from '../constants';

export interface RespondParams extends NodeParamsBase {
state: AgentState;
model: BaseChatModel;
config?: RunnableConfig;
}

export async function respond({
config,
logger,
state,
model,
Expand All @@ -34,7 +37,7 @@ export async function respond({

const responseMessage = await model
// use AGENT_NODE_TAG to identify as agent node for stream parsing
.withConfig({ runName: 'Summarizer', tags: [AGENT_NODE_TAG] })
.withConfig({ runName: 'Summarizer', tags: [AGENT_NODE_TAG], signal: config?.signal })
.invoke([userMessage]);

return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,22 @@ export async function runAgent({
logger.debug(() => `${NodeType.AGENT}: Node state:\n${JSON.stringify(state, null, 2)}`);

const knowledgeHistory = await kbDataClient?.getRequiredKnowledgeBaseDocumentEntries();

const agentOutcome = await agentRunnable.withConfig({ tags: [AGENT_NODE_TAG] }).invoke(
{
...state,
knowledge_history: `${KNOWLEDGE_HISTORY_PREFIX}\n${
knowledgeHistory?.length
? JSON.stringify(knowledgeHistory.map((e) => e.text))
: NO_KNOWLEDGE_HISTORY
}`,
// prepend any user prompt (gemini)
input: formatLatestUserMessage(state.input, state.llmType),
chat_history: state.messages, // TODO: Message de-dupe with ...state spread
},
config
);
const agentOutcome = await agentRunnable
.withConfig({ tags: [AGENT_NODE_TAG], signal: config?.signal })
.invoke(
{
...state,
knowledge_history: `${KNOWLEDGE_HISTORY_PREFIX}\n${
knowledgeHistory?.length
? JSON.stringify(knowledgeHistory.map((e) => e.text))
: NO_KNOWLEDGE_HISTORY
}`,
// prepend any user prompt (gemini)
input: formatLatestUserMessage(state.input, state.llmType),
chat_history: state.messages, // TODO: Message de-dupe with ...state spread
},
config
);

return {
agentOutcome,
Expand Down

0 comments on commit 02df3c9

Please sign in to comment.