[8.x] [Security Assistant] Fix abort stream OpenAI issue (#203193) (#203230)

# Backport

This will backport the following commits from `main` to `8.x`:
- [[Security Assistant] Fix abort stream OpenAI issue (#203193)](https://github.com/elastic/kibana/pull/203193)
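
The substance of the fix is in the `on_llm_end` handler (second file in the diff below): the stream is now finalized only when OpenAI reports a `finish_reason` of `stop`, or no `finish_reason` at all, which is what an aborted stream looks like; any other value (for example `function_call`) means a tool-calling round trip is still in flight. A minimal standalone sketch of that decision, assuming the event shape shown in the diff (`resolveFinalText` and `GenerationLike` are illustrative names, not part of the PR):

```ts
interface GenerationLike {
  text?: string;
  generationInfo?: { finish_reason?: string };
}

// Decide what to finalize the stream with, or return null to keep it open.
// OpenAI sets finish_reason to 'stop' on normal completion, to values such
// as 'function_call' when a tool call follows, and omits it entirely when
// the stream is aborted mid-generation.
const resolveFinalText = (
  generation: GenerationLike | undefined,
  accumulatedChunks: string
): string | null => {
  const finishReason = generation?.generationInfo?.finish_reason;
  if (finishReason && finishReason !== 'stop') {
    return null; // mid tool call: more output is coming, keep streaming
  }
  // An aborted stream carries no final text, so fall back to the chunks
  // streamed so far.
  return generation?.text?.length ? generation.text : accumulatedChunks;
};
```

In `streamGraph` itself the same check simply guards the existing `handleStreamEnd` call, as the diff shows.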


### Questions?
Please refer to the [Backport tool documentation](https://github.com/sqren/backport).


Co-authored-by: Steph Milovic <[email protected]>
kibanamachine and stephmilovic authored Dec 6, 2024
1 parent da176d3 commit b5dcebb
Showing 3 changed files with 137 additions and 2 deletions.
@@ -117,6 +117,131 @@ describe('streamGraph', () => {
        );
      });
    });
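    // The three tests added below cover the new on_llm_end handling: a
    // non-'stop' finish_reason keeps the stream open; a missing finish_reason
    // (the abort case) ends the stream with the generation's text when there
    // is one, and otherwise with the chunks accumulated so far.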
    it('on_llm_end events with finish_reason != stop should not end the stream', async () => {
      mockStreamEvents.mockReturnValue({
        next: jest
          .fn()
          .mockResolvedValueOnce({
            value: {
              name: 'ActionsClientChatOpenAI',
              event: 'on_llm_stream',
              data: { chunk: { message: { content: 'content' } } },
              tags: [AGENT_NODE_TAG],
            },
            done: false,
          })
          .mockResolvedValueOnce({
            value: {
              name: 'ActionsClientChatOpenAI',
              event: 'on_llm_end',
              data: {
                output: {
                  generations: [[{ generationInfo: { finish_reason: 'function_call' }, text: '' }]],
                },
              },
              tags: [AGENT_NODE_TAG],
            },
          })
          .mockResolvedValue({
            done: true,
          }),
        return: jest.fn(),
      });

      const response = await streamGraph(requestArgs);

      expect(response).toBe(mockResponseWithHeaders);
      expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
      await waitFor(() => {
        expect(mockOnLlmResponse).not.toHaveBeenCalled();
      });
    });
    it('on_llm_end events without a finish_reason should end the stream', async () => {
      mockStreamEvents.mockReturnValue({
        next: jest
          .fn()
          .mockResolvedValueOnce({
            value: {
              name: 'ActionsClientChatOpenAI',
              event: 'on_llm_stream',
              data: { chunk: { message: { content: 'content' } } },
              tags: [AGENT_NODE_TAG],
            },
            done: false,
          })
          .mockResolvedValueOnce({
            value: {
              name: 'ActionsClientChatOpenAI',
              event: 'on_llm_end',
              data: {
                output: {
                  generations: [[{ generationInfo: {}, text: 'final message' }]],
                },
              },
              tags: [AGENT_NODE_TAG],
            },
          })
          .mockResolvedValue({
            done: true,
          }),
        return: jest.fn(),
      });

      const response = await streamGraph(requestArgs);

      expect(response).toBe(mockResponseWithHeaders);
      expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
      await waitFor(() => {
        expect(mockOnLlmResponse).toHaveBeenCalledWith(
          'final message',
          { transactionId: 'transactionId', traceId: 'traceId' },
          false
        );
      });
    });
    it('on_llm_end events is called with chunks if there is no final text value', async () => {
      mockStreamEvents.mockReturnValue({
        next: jest
          .fn()
          .mockResolvedValueOnce({
            value: {
              name: 'ActionsClientChatOpenAI',
              event: 'on_llm_stream',
              data: { chunk: { message: { content: 'content' } } },
              tags: [AGENT_NODE_TAG],
            },
            done: false,
          })
          .mockResolvedValueOnce({
            value: {
              name: 'ActionsClientChatOpenAI',
              event: 'on_llm_end',
              data: {
                output: {
                  generations: [[{ generationInfo: {}, text: '' }]],
                },
              },
              tags: [AGENT_NODE_TAG],
            },
          })
          .mockResolvedValue({
            done: true,
          }),
        return: jest.fn(),
      });

      const response = await streamGraph(requestArgs);

      expect(response).toBe(mockResponseWithHeaders);
      expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
      await waitFor(() => {
        expect(mockOnLlmResponse).toHaveBeenCalledWith(
          'content',
          { transactionId: 'transactionId', traceId: 'traceId' },
          false
        );
      });
    });
  });

  describe('Tool Calling Agent and Structured Chat Agent streaming', () => {
@@ -160,7 +160,16 @@ export const streamGraph = async ({
           finalMessage += msg.content;
         }
       } else if (event.event === 'on_llm_end' && !didEnd) {
-        handleStreamEnd(event.data.output?.generations[0][0]?.text ?? finalMessage);
+        const generation = event.data.output?.generations[0][0];
+        if (
+          // no finish_reason means the stream was aborted
+          !generation?.generationInfo?.finish_reason ||
+          generation?.generationInfo?.finish_reason === 'stop'
+        ) {
+          handleStreamEnd(
+            generation?.text && generation?.text.length ? generation?.text : finalMessage
+          );
+        }
       }
     }
   }
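
Note the fallback in the added branch: `finalMessage` accumulates the `on_llm_stream` chunks (first context line of the hunk), so an aborted stream whose terminal generation has empty `text` is still persisted with everything streamed so far.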
@@ -173,9 +173,10 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
     // we need to pass it like this or streaming does not work for bedrock
     createLlmInstance,
     logger,
-    signal: abortSignal,
     tools,
     replacements,
+    // some chat models (bedrock) require a signal to be passed on agent invoke rather than the signal passed to the chat model
+    ...(llmType === 'bedrock' ? { signal: abortSignal } : {}),
   });
   const inputs: GraphInputs = {
     responseLanguage,
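The last change makes the abort signal provider-conditional. Per the in-diff comment, Bedrock chat models want the signal on the agent invocation, while for other providers it travels with the chat model itself, so passing it at invoke time as well appears to have contributed to the OpenAI abort issue. A minimal sketch of the conditional-spread pattern, with illustrative names (`LlmType` and `buildGraphParams` are assumptions, not the PR's API):

```ts
type LlmType = 'openai' | 'bedrock' | 'gemini'; // illustrative union

interface GraphParams {
  signal?: AbortSignal;
  // ...other graph inputs elided
}

// Spread the signal in only for providers that expect it at invoke time;
// spreading an empty object is a no-op, so everyone else gets no `signal`
// key at all rather than an undefined one.
const buildGraphParams = (llmType: LlmType, abortSignal: AbortSignal): GraphParams => ({
  ...(llmType === 'bedrock' ? { signal: abortSignal } : {}),
});
```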
