diff --git a/qiskit_code_assistant_jupyterlab/handlers.py b/qiskit_code_assistant_jupyterlab/handlers.py index 2d24200..f8aa211 100644 --- a/qiskit_code_assistant_jupyterlab/handlers.py +++ b/qiskit_code_assistant_jupyterlab/handlers.py @@ -88,7 +88,10 @@ def convert_openai(model): class ServiceUrlHandler(APIHandler): @tornado.web.authenticated def get(self): - self.finish(json.dumps({"url": runtime_configs["service_url"]})) + self.finish(json.dumps({ + "url": runtime_configs["service_url"], + "is_openai": runtime_configs["is_openai"] + })) @tornado.web.authenticated def post(self): @@ -102,7 +105,10 @@ def post(self): except (requests.exceptions.JSONDecodeError, KeyError): runtime_configs["is_openai"] = True finally: - self.finish(json.dumps({"url": runtime_configs["service_url"]})) + self.finish(json.dumps({ + "url": runtime_configs["service_url"], + "is_openai": runtime_configs["is_openai"] + })) class TokenHandler(APIHandler): @@ -282,7 +288,7 @@ class FeedbackHandler(APIHandler): @tornado.web.authenticated def post(self): if runtime_configs["is_openai"]: - self.finish(json.dumps({"success": "true"})) + self.finish(json.dumps({"message": "Feedback not supported for this service"})) else: url = url_path_join(runtime_configs["service_url"], "feedback") diff --git a/src/QiskitCompletionProvider.ts b/src/QiskitCompletionProvider.ts index 3f54d0a..a150ffe 100644 --- a/src/QiskitCompletionProvider.ts +++ b/src/QiskitCompletionProvider.ts @@ -39,6 +39,8 @@ const FEEDBACK_COMMAND = 'qiskit-code-assistant:prompt-feedback'; export let lastPrompt: ICompletionReturn | undefined = undefined; +export const wipeLastPrompt = () => (lastPrompt = undefined); + function getInputText(text: string, widget: Widget): string { const cellsContents: string[] = []; diff --git a/src/StatusBarWidget.ts b/src/StatusBarWidget.ts index 836148f..dcf69cb 100644 --- a/src/StatusBarWidget.ts +++ b/src/StatusBarWidget.ts @@ -19,6 +19,7 @@ import { Message } from '@lumino/messaging'; import { refreshIcon } from '@jupyterlab/ui-components'; import { Widget } from '@lumino/widgets'; +import { wipeLastPrompt } from './QiskitCompletionProvider'; import { showDisclaimer } from './service/disclaimer'; import { getCurrentModel, @@ -70,10 +71,11 @@ export class StatusBarWidget extends Widget { async onClick() { await checkAPIToken().then(() => { const modelsList = getModelsList(); + const dropDownList = [...modelsList.map(m => m.display_name)]; InputDialog.getItem({ title: 'Select a Model', - items: [...modelsList.map(m => m.display_name)], - current: getCurrentModel()?.display_name + items: dropDownList, + current: dropDownList.indexOf(getCurrentModel()?.display_name || '') }).then(result => { if (result.button.accept) { const model = modelsList.find(m => m.display_name === result.value); @@ -81,6 +83,7 @@ export class StatusBarWidget extends Widget { if (model) { showDisclaimer(model._id).then(accepted => { if (accepted) { + wipeLastPrompt(); setCurrentModel(model); } }); diff --git a/src/index.ts b/src/index.ts index 82f0c8c..414f90f 100644 --- a/src/index.ts +++ b/src/index.ts @@ -28,7 +28,8 @@ import { StatusBarWidget } from './StatusBarWidget'; import { lastPrompt, QiskitCompletionProvider, - QiskitInlineCompletionProvider + QiskitInlineCompletionProvider, + wipeLastPrompt } from './QiskitCompletionProvider'; import { postServiceUrl } from './service/api'; import { getFeedbackStatusBarWidget, getFeedback } from './service/feedback'; @@ -73,10 +74,21 @@ const plugin: JupyterFrontEndPlugin = { const settings = await settingRegistry.load(plugin.id); console.debug(EXTENSION_ID + ' settings loaded:', settings.composite); - postServiceUrl(settings.composite['serviceUrl'] as string); + let is_openai = false; + + postServiceUrl(settings.composite['serviceUrl'] as string).then( + response => { + is_openai = response.is_openai; + wipeLastPrompt(); + } + ); settings.changed.connect(() => - postServiceUrl(settings.composite['serviceUrl'] as string).then(() => - refreshModelsList() + postServiceUrl(settings.composite['serviceUrl'] as string).then( + response => { + is_openai = response.is_openai; + wipeLastPrompt(); + refreshModelsList(); + } ) ); @@ -87,7 +99,8 @@ const plugin: JupyterFrontEndPlugin = { statusBar.registerStatusItem(EXTENSION_ID + ':feedback', { item: getFeedbackStatusBarWidget(), - align: 'left' + align: 'left', + isActive: () => !is_openai }); const statusBarWidget = new StatusBarWidget(); @@ -104,11 +117,13 @@ const plugin: JupyterFrontEndPlugin = { label: 'Give feedback for the Qiskit Code Assistant', icon: feedbackIcon, execute: () => getFeedback(), - isEnabled: () => lastPrompt !== undefined, + isEnabled: () => !is_openai && lastPrompt !== undefined, isVisible: () => + !is_openai && ['code', 'markdown'].includes( notebookTracker.activeCell?.model.type || '' - ) && lastPrompt !== undefined + ) && + lastPrompt !== undefined }); app.commands.addCommand(CommandIDs.updateApiToken, { diff --git a/src/service/api.ts b/src/service/api.ts index 1fc6c1e..95abaab 100644 --- a/src/service/api.ts +++ b/src/service/api.ts @@ -22,7 +22,8 @@ import { IModelDisclaimer, IModelInfo, IModelPromptResponse, - IResponseMessage + IResponseMessage, + IServiceResponse } from '../utils/schema'; const AUTH_ERROR_CODES = [401, 403, 422]; @@ -40,14 +41,17 @@ async function notifyInvalid(response: Response): Promise { } // POST /service -export async function postServiceUrl(newUrl: string): Promise { +export async function postServiceUrl( + newUrl: string +): Promise { return await requestAPI('service', { method: 'POST', body: JSON.stringify({ url: newUrl }) }).then(response => { if (response.ok) { - response.json().then(json => { + return response.json().then(json => { console.debug('Updated service URL:', json.url); + return json; }); } else { console.error( diff --git a/src/utils/schema.ts b/src/utils/schema.ts index 87df69e..c9b2833 100644 --- a/src/utils/schema.ts +++ b/src/utils/schema.ts @@ -97,3 +97,8 @@ export interface IFeedbackForm { input?: string; output?: string; } + +export interface IServiceResponse { + url: string; + is_openai: boolean; +}