Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Security Assistant] Abort signal fix #203041

Merged
merged 6 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ export class ActionsClientChatVertexAI extends ChatVertexAI {
tools: data?.tools,
temperature: this.temperature,
...systemInstruction,
signal: options?.signal,
},
},
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ export class ActionsClientChatConnection<Auth> extends ChatConnection<Auth> {
tools: data?.tools,
temperature: this.temperature,
...systemInstruction,
signal: options?.signal,
},
},
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ export interface GetDefaultAssistantGraphParams {
dataClients?: AssistantDataClients;
createLlmInstance: () => BaseChatModel;
logger: Logger;
signal?: AbortSignal;
tools: StructuredTool[];
replacements: Replacements;
}
Expand All @@ -45,6 +46,8 @@ export const getDefaultAssistantGraph = ({
dataClients,
createLlmInstance,
logger,
// some chat models (bedrock) require a signal to be passed on agent invoke rather than the signal passed to the chat model
signal,
tools,
replacements,
}: GetDefaultAssistantGraphParams) => {
Expand Down Expand Up @@ -137,11 +140,19 @@ export const getDefaultAssistantGraph = ({
})
)
.addNode(NodeType.AGENT, (state: AgentState) =>
runAgent({ ...nodeParams, state, agentRunnable, kbDataClient: dataClients?.kbDataClient })
runAgent({
...nodeParams,
config: { signal },
state,
agentRunnable,
kbDataClient: dataClients?.kbDataClient,
})
)
.addNode(NodeType.TOOLS, (state: AgentState) =>
executeTools({ ...nodeParams, config: { signal }, state, tools })
)
.addNode(NodeType.TOOLS, (state: AgentState) => executeTools({ ...nodeParams, state, tools }))
.addNode(NodeType.RESPOND, (state: AgentState) =>
respond({ ...nodeParams, state, model: createLlmInstance() })
respond({ ...nodeParams, config: { signal }, state, model: createLlmInstance() })
)
.addNode(NodeType.MODEL_INPUT, (state: AgentState) => modelInput({ ...nodeParams, state }))
.addEdge(START, NodeType.MODEL_INPUT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,7 @@ export const streamGraph = async ({
finalMessage += msg.content;
}
} else if (event.event === 'on_llm_end' && !didEnd) {
const generations = event.data.output?.generations[0];
if (generations && generations[0]?.generationInfo.finish_reason === 'stop') {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this condition prevented aborted generations from being persisted in the conversation

handleStreamEnd(generations[0]?.text ?? finalMessage);
}
handleStreamEnd(event.data.output?.generations[0][0]?.text ?? finalMessage);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
// we need to pass it like this or streaming does not work for bedrock
createLlmInstance,
logger,
signal: abortSignal,
tools,
replacements,
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,19 @@

import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { StringWithAutocomplete } from '@langchain/core/dist/utils/types';
import { RunnableConfig } from '@langchain/core/runnables';
import { AGENT_NODE_TAG } from './run_agent';
import { AgentState, NodeParamsBase } from '../types';
import { NodeType } from '../constants';

export interface RespondParams extends NodeParamsBase {
state: AgentState;
model: BaseChatModel;
config?: RunnableConfig;
}

export async function respond({
config,
logger,
state,
model,
Expand All @@ -34,7 +37,7 @@ export async function respond({

const responseMessage = await model
// use AGENT_NODE_TAG to identify as agent node for stream parsing
.withConfig({ runName: 'Summarizer', tags: [AGENT_NODE_TAG] })
.withConfig({ runName: 'Summarizer', tags: [AGENT_NODE_TAG], signal: config?.signal })
.invoke([userMessage]);

return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,21 +43,22 @@ export async function runAgent({
logger.debug(() => `${NodeType.AGENT}: Node state:\n${JSON.stringify(state, null, 2)}`);

const knowledgeHistory = await kbDataClient?.getRequiredKnowledgeBaseDocumentEntries();

const agentOutcome = await agentRunnable.withConfig({ tags: [AGENT_NODE_TAG] }).invoke(
{
...state,
knowledge_history: `${KNOWLEDGE_HISTORY_PREFIX}\n${
knowledgeHistory?.length
? JSON.stringify(knowledgeHistory.map((e) => e.text))
: NO_KNOWLEDGE_HISTORY
}`,
// prepend any user prompt (gemini)
input: formatLatestUserMessage(state.input, state.llmType),
chat_history: state.messages, // TODO: Message de-dupe with ...state spread
},
config
);
const agentOutcome = await agentRunnable
.withConfig({ tags: [AGENT_NODE_TAG], signal: config?.signal })
.invoke(
{
...state,
knowledge_history: `${KNOWLEDGE_HISTORY_PREFIX}\n${
knowledgeHistory?.length
? JSON.stringify(knowledgeHistory.map((e) => e.text))
: NO_KNOWLEDGE_HISTORY
}`,
// prepend any user prompt (gemini)
input: formatLatestUserMessage(state.input, state.llmType),
chat_history: state.messages, // TODO: Message de-dupe with ...state spread
},
config
);

return {
agentOutcome,
Expand Down
Loading