From eda8fc8aabe1d25462539d1f381d857fa6989837 Mon Sep 17 00:00:00 2001 From: Kibana Machine <42973632+kibanamachine@users.noreply.github.com> Date: Tue, 24 Dec 2024 05:08:11 +1100 Subject: [PATCH] [8.x] OpenAI connector: send default model for "other" openAI provider (#204934) (#205107) # Backport This will backport the following commits from `main` to `8.x`: - [OpenAI connector: send default model for "other" openAI provider (#204934)](https://github.com/elastic/kibana/pull/204934) ### Questions ? Please refer to the [Backport tool documentation](https://github.com/sqren/backport) Co-authored-by: Pierre Gayvallet --- .../openai/lib/other_openai_utils.test.ts | 37 +++++++++++++++++++ .../openai/lib/other_openai_utils.ts | 10 ++++- .../connector_types/openai/lib/utils.test.ts | 14 ++++++- .../connector_types/openai/lib/utils.ts | 2 +- 4 files changed, 58 insertions(+), 5 deletions(-) diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/other_openai_utils.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/other_openai_utils.test.ts index 33722314f5422..1cdcd40b11a30 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/other_openai_utils.test.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/other_openai_utils.test.ts @@ -112,5 +112,42 @@ describe('Other (OpenAI Compatible Service) Utils', () => { const sanitizedBodyString = getRequestWithStreamOption(bodyString, false); expect(sanitizedBodyString).toEqual(bodyString); }); + + it('sets model parameter if specified and not present in the body', () => { + const body = { + messages: [ + { + role: 'user', + content: 'This is a test', + }, + ], + }; + + const sanitizedBodyString = getRequestWithStreamOption(JSON.stringify(body), true, 'llama-3'); + expect(JSON.parse(sanitizedBodyString)).toEqual({ + messages: [{ content: 'This is a test', role: 'user' }], + model: 'llama-3', + stream: true, + }); + }); + + it('does not 
override model parameter if present in the body', () => { const body = { model: 'mistral', messages: [ { role: 'user', content: 'This is a test', }, ], }; const sanitizedBodyString = getRequestWithStreamOption(JSON.stringify(body), true, 'llama-3'); expect(JSON.parse(sanitizedBodyString)).toEqual({ messages: [{ content: 'This is a test', role: 'user' }], model: 'mistral', stream: true, }); }); }); }); diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/other_openai_utils.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/other_openai_utils.ts index 8288e0dba9ad1..0d3fb88ccc739 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/other_openai_utils.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/other_openai_utils.ts @@ -23,13 +23,19 @@ export const sanitizeRequest = (body: string): string => { * The stream parameter is accepted in the ChatCompletion * API and the Completion API only */ -export const getRequestWithStreamOption = (body: string, stream: boolean): string => { +export const getRequestWithStreamOption = ( + body: string, + stream: boolean, + defaultModel?: string +): string => { try { const jsonBody = JSON.parse(body); if (jsonBody) { jsonBody.stream = stream; } - + if (defaultModel && !jsonBody.model) { + jsonBody.model = defaultModel; + } return JSON.stringify(jsonBody); } catch (err) { // swallow the error diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/utils.test.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/utils.test.ts index 142f3a319eeb6..08389a1195706 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/utils.test.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/utils.test.ts @@ -111,9 +111,19 @@ describe('Utils', () => { }); it('calls other_openai_utils getRequestWithStreamOption when provider is Other 
OpenAi', () => { - getRequestWithStreamOption(OpenAiProviderType.Other, OPENAI_CHAT_URL, bodyString, true); + getRequestWithStreamOption( + OpenAiProviderType.Other, + OPENAI_CHAT_URL, + bodyString, + true, + 'default-model' + ); - expect(mockOtherOpenAiGetRequestWithStreamOption).toHaveBeenCalledWith(bodyString, true); + expect(mockOtherOpenAiGetRequestWithStreamOption).toHaveBeenCalledWith( + bodyString, + true, + 'default-model' + ); expect(mockOpenAiGetRequestWithStreamOption).not.toHaveBeenCalled(); expect(mockAzureAiGetRequestWithStreamOption).not.toHaveBeenCalled(); }); diff --git a/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/utils.ts b/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/utils.ts index 3028433656503..ebe1d3bac578e 100644 --- a/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/utils.ts +++ b/x-pack/plugins/stack_connectors/server/connector_types/openai/lib/utils.ts @@ -75,7 +75,7 @@ export function getRequestWithStreamOption( case OpenAiProviderType.AzureAi: return azureAiGetRequestWithStreamOption(url, body, stream); case OpenAiProviderType.Other: - return otherOpenAiGetRequestWithStreamOption(body, stream); + return otherOpenAiGetRequestWithStreamOption(body, stream, defaultModel); default: return body; }