diff --git a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/use_current_conversation/index.tsx b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/use_current_conversation/index.tsx index 267c39c402a1c..ab5e5532a19cb 100644 --- a/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/use_current_conversation/index.tsx +++ b/x-pack/platform/packages/shared/kbn-elastic-assistant/impl/assistant/use_current_conversation/index.tsx @@ -265,18 +265,24 @@ export const useCurrentConversation = ({ } const newSystemPrompt = getDefaultNewSystemPrompt(allSystemPrompts); + let conversation: Partial = {}; + if (currentConversation?.apiConfig) { + const { defaultSystemPromptId: _, ...restApiConfig } = currentConversation?.apiConfig; + conversation = + restApiConfig.actionTypeId != null + ? { + apiConfig: { + ...restApiConfig, + ...(newSystemPrompt?.id != null + ? { defaultSystemPromptId: newSystemPrompt.id } + : {}), + }, + } + : {}; + } const newConversation = await createConversation({ title: NEW_CHAT, - ...(currentConversation?.apiConfig != null && - currentConversation?.apiConfig?.actionTypeId != null - ? { - apiConfig: { - connectorId: currentConversation.apiConfig.connectorId, - actionTypeId: currentConversation.apiConfig.actionTypeId, - ...(newSystemPrompt?.id != null ? { defaultSystemPromptId: newSystemPrompt.id } : {}), - }, - } - : {}), + ...conversation, }); if (newConversation) { diff --git a/x-pack/plugins/stack_connectors/public/connector_types/openai/constants.tsx b/x-pack/plugins/stack_connectors/public/connector_types/openai/constants.tsx index 5f4238e52af78..a24db86804f95 100644 --- a/x-pack/plugins/stack_connectors/public/connector_types/openai/constants.tsx +++ b/x-pack/plugins/stack_connectors/public/connector_types/openai/constants.tsx @@ -11,23 +11,48 @@ import { FormattedMessage } from '@kbn/i18n-react'; import { EuiLink } from '@elastic/eui'; import { DEFAULT_OPENAI_MODEL, OpenAiProviderType } from '../../../common/openai/constants'; import * as i18n from './translations'; +import { Config } from './types'; export const DEFAULT_URL = 'https://api.openai.com/v1/chat/completions' as const; export const DEFAULT_URL_AZURE = 'https://{your-resource-name}.openai.azure.com/openai/deployments/{deployment-id}/chat/completions?api-version={api-version}' as const; -export const DEFAULT_BODY = `{ +const DEFAULT_BODY = `{ "messages": [{ "role":"user", "content":"Hello world" }] }`; -export const DEFAULT_BODY_AZURE = `{ +const DEFAULT_BODY_AZURE = `{ "messages": [{ "role":"user", "content":"Hello world" }] }`; +const DEFAULT_BODY_OTHER = (defaultModel: string) => `{ + "model": "${defaultModel}", + "messages": [{ + "role":"user", + "content":"Hello world" + }] +}`; + +export const getDefaultBody = (config?: Config) => { + if (!config) { + // default to OpenAiProviderType.OpenAi sample data + return DEFAULT_BODY; + } + if (config?.apiProvider === OpenAiProviderType.Other) { + // update sample data if Other (OpenAI Compatible Service) + return config.defaultModel ? DEFAULT_BODY_OTHER(config.defaultModel) : DEFAULT_BODY; + } + if (config?.apiProvider === OpenAiProviderType.AzureAi) { + // update sample data if AzureAi + return DEFAULT_BODY_AZURE; + } + // default to OpenAiProviderType.OpenAi sample data + return DEFAULT_BODY; +}; export const openAiConfig: ConfigFieldSchema[] = [ { diff --git a/x-pack/plugins/stack_connectors/public/connector_types/openai/params.test.tsx b/x-pack/plugins/stack_connectors/public/connector_types/openai/params.test.tsx index 7539cc6bf6373..c03582ba0b229 100644 --- a/x-pack/plugins/stack_connectors/public/connector_types/openai/params.test.tsx +++ b/x-pack/plugins/stack_connectors/public/connector_types/openai/params.test.tsx @@ -9,7 +9,7 @@ import React from 'react'; import { fireEvent, render } from '@testing-library/react'; import ParamsFields from './params'; import { OpenAiProviderType, SUB_ACTION } from '../../../common/openai/constants'; -import { DEFAULT_BODY, DEFAULT_BODY_AZURE, DEFAULT_URL } from './constants'; +import { DEFAULT_URL, getDefaultBody } from './constants'; const messageVariables = [ { @@ -73,14 +73,15 @@ describe('Gen AI Params Fields renders', () => { ); expect(editAction).toHaveBeenCalledTimes(2); expect(editAction).toHaveBeenCalledWith('subAction', SUB_ACTION.RUN, 0); + const body = getDefaultBody(actionConnector.config); if (apiProvider === OpenAiProviderType.OpenAi) { - expect(editAction).toHaveBeenCalledWith('subActionParams', { body: DEFAULT_BODY }, 0); + expect(editAction).toHaveBeenCalledWith('subActionParams', { body }, 0); } if (apiProvider === OpenAiProviderType.AzureAi) { - expect(editAction).toHaveBeenCalledWith('subActionParams', { body: DEFAULT_BODY_AZURE }, 0); + expect(editAction).toHaveBeenCalledWith('subActionParams', { body }, 0); } if (apiProvider === OpenAiProviderType.Other) { - expect(editAction).toHaveBeenCalledWith('subActionParams', { body: DEFAULT_BODY }, 0); + expect(editAction).toHaveBeenCalledWith('subActionParams', { body }, 0); } } ); diff --git a/x-pack/plugins/stack_connectors/public/connector_types/openai/params.tsx b/x-pack/plugins/stack_connectors/public/connector_types/openai/params.tsx index ad4398482d2c8..000abfa4872be 100644 --- a/x-pack/plugins/stack_connectors/public/connector_types/openai/params.tsx +++ b/x-pack/plugins/stack_connectors/public/connector_types/openai/params.tsx @@ -12,8 +12,8 @@ import { ActionConnectorMode, JsonEditorWithMessageVariables, } from '@kbn/triggers-actions-ui-plugin/public'; -import { OpenAiProviderType, SUB_ACTION } from '../../../common/openai/constants'; -import { DEFAULT_BODY, DEFAULT_BODY_AZURE } from './constants'; +import { SUB_ACTION } from '../../../common/openai/constants'; +import { getDefaultBody } from './constants'; import { OpenAIActionConnector, ActionParams } from './types'; const ParamsFields: React.FunctionComponent> = ({ @@ -41,16 +41,10 @@ const ParamsFields: React.FunctionComponent> = ( useEffect(() => { if (!subActionParams) { - // default to OpenAiProviderType.OpenAi sample data - let sampleBody = DEFAULT_BODY; - - if (typedActionConnector?.config?.apiProvider === OpenAiProviderType.AzureAi) { - // update sample data if AzureAi - sampleBody = DEFAULT_BODY_AZURE; - } + const sampleBody = getDefaultBody(typedActionConnector?.config); editAction('subActionParams', { body: sampleBody }, index); } - }, [typedActionConnector?.config?.apiProvider, editAction, index, subActionParams]); + }, [typedActionConnector?.config, editAction, index, subActionParams]); const editSubActionParams = useCallback( (params: ActionParams['subActionParams']) => { diff --git a/x-pack/plugins/stack_connectors/public/connector_types/openai/types.ts b/x-pack/plugins/stack_connectors/public/connector_types/openai/types.ts index 3ba19c04d13a7..ea37fee0de879 100644 --- a/x-pack/plugins/stack_connectors/public/connector_types/openai/types.ts +++ b/x-pack/plugins/stack_connectors/public/connector_types/openai/types.ts @@ -18,6 +18,7 @@ export interface ActionParams { export interface Config { apiProvider: OpenAiProviderType; apiUrl: string; + defaultModel?: string; } export interface Secrets {