Skip to content

Commit

Permalink
[JS] feat: add managed identity auth support for AOAI and update AI Search sample (#1725)
Browse files Browse the repository at this point in the history

## Linked issues

closes: #1714 #1665 #1664 #1666 

## Details

- add managed identity auth support for AOAI and update AI Search sample
- validate bot tenant id on incoming activity
- fix citations parsing error when title is null
  • Loading branch information
aacebo authored Jun 11, 2024
1 parent 37917eb commit a441b82
Show file tree
Hide file tree
Showing 21 changed files with 267 additions and 210 deletions.
182 changes: 96 additions & 86 deletions js/packages/teams-ai/src/AI.ts
Original file line number Diff line number Diff line change
Expand Up @@ -347,106 +347,116 @@ export class AI<TState extends TurnState = TurnState> {
* @returns {Promise<boolean>} True if the plan was completely executed, otherwise false.
*/
public async run(context: TurnContext, state: TState, start_time?: number, step_count?: number): Promise<boolean> {
// Initialize start time and action count
const { max_steps, max_time } = this._options;
if (start_time === undefined) {
start_time = Date.now();
}
if (step_count === undefined) {
step_count = 0;
}

// Review input on first loop
let plan: Plan | undefined =
step_count == 0 ? await this._options.moderator.reviewInput(context, state) : undefined;

// Generate plan
if (!plan) {
if (step_count == 0) {
plan = await this._options.planner.beginTask(context, state, this);
} else {
plan = await this._options.planner.continueTask(context, state, this);
try {
// Initialize start time and action count
const { max_steps, max_time } = this._options;
if (start_time === undefined) {
start_time = Date.now();
}
if (step_count === undefined) {
step_count = 0;
}

// Review the plans output
plan = await this._options.moderator.reviewOutput(context, state, plan);
}
// Review input on first loop
let plan: Plan | undefined =
step_count == 0 ? await this._options.moderator.reviewInput(context, state) : undefined;

// Process generated plan
let completed = false;
const response = await this._actions
.get(AI.PlanReadyActionName)!
.handler(context, state, plan, AI.PlanReadyActionName);
if (response == AI.StopCommandName) {
return false;
}
// Generate plan
if (!plan) {
if (step_count == 0) {
plan = await this._options.planner.beginTask(context, state, this);
} else {
plan = await this._options.planner.continueTask(context, state, this);
}

// Run predicted commands
// - If the plan ends on a SAY command then the plan is considered complete, otherwise we'll loop
completed = true;
let should_loop = false;
for (let i = 0; i < plan.commands.length; i++) {
// Check for timeout
if (Date.now() - start_time! > max_time || ++step_count! > max_steps) {
completed = false;
const parameters: actions.TooManyStepsParameters = {
max_steps,
max_time,
start_time: start_time!,
step_count: step_count!
};
await this._actions
.get(AI.TooManyStepsActionName)!
.handler(context, state, parameters, AI.TooManyStepsActionName);
break;
// Review the plans output
plan = await this._options.moderator.reviewOutput(context, state, plan);
}

let output: string;
const cmd = plan.commands[i];
switch (cmd.type) {
case 'DO': {
const { action } = cmd as PredictedDoCommand;
if (this._actions.has(action)) {
// Call action handler
const handler = this._actions.get(action)!.handler;
output = await this._actions
.get(AI.DoCommandActionName)!
.handler(context, state, { handler, ...(cmd as PredictedDoCommand) }, action);
should_loop = output.length > 0;
state.temp.actionOutputs[action] = output;
} else {
// Redirect to UnknownAction handler
output = await this._actions.get(AI.UnknownActionName)!.handler(context, state, plan, action);
}
// Process generated plan
let completed = false;
const response = await this._actions
.get(AI.PlanReadyActionName)!
.handler(context, state, plan, AI.PlanReadyActionName);
if (response == AI.StopCommandName) {
return false;
}

// Run predicted commands
// - If the plan ends on a SAY command then the plan is considered complete, otherwise we'll loop
completed = true;
let should_loop = false;
for (let i = 0; i < plan.commands.length; i++) {
// Check for timeout
if (Date.now() - start_time! > max_time || ++step_count! > max_steps) {
completed = false;
const parameters: actions.TooManyStepsParameters = {
max_steps,
max_time,
start_time: start_time!,
step_count: step_count!
};
await this._actions
.get(AI.TooManyStepsActionName)!
.handler(context, state, parameters, AI.TooManyStepsActionName);
break;
}
case 'SAY':
should_loop = false;
output = await this._actions
.get(AI.SayCommandActionName)!
.handler(context, state, cmd, AI.SayCommandActionName);

let output: string;
const cmd = plan.commands[i];
switch (cmd.type) {
case 'DO': {
const { action } = cmd as PredictedDoCommand;
if (this._actions.has(action)) {
// Call action handler
const handler = this._actions.get(action)!.handler;
output = await this._actions
.get(AI.DoCommandActionName)!
.handler(context, state, { handler, ...(cmd as PredictedDoCommand) }, action);
should_loop = output.length > 0;
state.temp.actionOutputs[action] = output;
} else {
// Redirect to UnknownAction handler
output = await this._actions.get(AI.UnknownActionName)!.handler(context, state, plan, action);
}
break;
}
case 'SAY':
should_loop = false;
output = await this._actions
.get(AI.SayCommandActionName)!
.handler(context, state, cmd, AI.SayCommandActionName);
break;
default:
throw new Error(`AI.run(): unknown command of '${cmd.type}' predicted.`);
}

// Check for stop command
if (output == AI.StopCommandName) {
completed = false;
break;
default:
throw new Error(`AI.run(): unknown command of '${cmd.type}' predicted.`);
}
}

// Check for stop command
if (output == AI.StopCommandName) {
completed = false;
break;
// Copy the actions output to the input
state.temp.lastOutput = output;
state.temp.input = output;
state.temp.inputFiles = [];
}

// Copy the actions output to the input
state.temp.lastOutput = output;
state.temp.input = output;
state.temp.inputFiles = [];
}
// Check for looping
if (completed && should_loop && this._options.allow_looping) {
return await this.run(context, state, start_time, step_count);
}

// Check for looping
if (completed && should_loop && this._options.allow_looping) {
return await this.run(context, state, start_time, step_count);
} else {
return completed;
} catch (err) {
const onHttpError = this._actions.get(AI.HttpErrorActionName);

if (onHttpError) {
await onHttpError.handler(context, state, err, AI.HttpErrorActionName);
}

return false;
}
}
}
10 changes: 8 additions & 2 deletions js/packages/teams-ai/src/actions/HttpError.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@ import { httpError } from './HttpError';
describe('actions.httpError', () => {
const handler = httpError();

it('should throw', async () => {
it('should throw default error', async () => {
assert.rejects(async () => {
await handler();
await handler({} as any, {} as any);
}, 'An AI http request failed');
});

it('should throw given error', async () => {
assert.rejects(async () => {
await handler({} as any, {} as any, new Error('a given error'));
}, 'a given error');
});
});
10 changes: 7 additions & 3 deletions js/packages/teams-ai/src/actions/HttpError.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@
* Licensed under the MIT License.
*/

import { TurnContext } from 'botbuilder-core';

import { TurnState } from '../TurnState';

/**
* @private
*/
export function httpError() {
return async (): Promise<string> => {
throw new Error(`An AI http request failed`);
export function httpError<TState extends TurnState = TurnState>() {
return async (_context: TurnContext, _state: TState, err?: Error): Promise<string> => {
throw err || new Error(`An AI http request failed`);
};
}
8 changes: 5 additions & 3 deletions js/packages/teams-ai/src/actions/SayCommand.ts
Original file line number Diff line number Diff line change
Expand Up @@ -173,15 +173,17 @@ export function sayCommand<TState extends TurnState = TurnState>(feedbackLoopEna

if (data.response.context && data.response.context.citations.length > 0) {
citations = data.response.context!.citations.map((citation, i) => {
return {
const clientCitation: ClientCitation = {
'@type': 'Claim',
position: `${i + 1}`,
appearance: {
'@type': 'DigitalDocument',
name: citation.title,
name: citation.title || `Document #${i + 1}`,
abstract: Utilities.snippet(citation.content, 500)
}
} as ClientCitation;
};

return clientCitation;
});
}

Expand Down
72 changes: 47 additions & 25 deletions js/packages/teams-ai/src/models/OpenAIModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ export interface AzureOpenAIModelOptions extends BaseOpenAIModelOptions {
/**
* API key to use when making requests to Azure OpenAI.
*/
azureApiKey: string;
azureApiKey?: string;

/**
* Default name of the Azure OpenAI deployment (model) to use.
Expand All @@ -148,6 +148,12 @@ export interface AzureOpenAIModelOptions extends BaseOpenAIModelOptions {
* Optional. Version of the API being called. Defaults to `2023-05-15`.
*/
azureApiVersion?: string;

/**
* Optional. A function that returns an access token for Microsoft Entra (formerly known as Azure Active Directory),
* which will be invoked on every request.
*/
azureADTokenProvider?: () => Promise<string>;
}

/**
Expand All @@ -170,7 +176,7 @@ export class OpenAIModel implements PromptCompletionModel {
*/
public constructor(options: OpenAIModelOptions | AzureOpenAIModelOptions | OpenAILikeModelOptions) {
// Check for azure config
if ((options as AzureOpenAIModelOptions).azureApiKey) {
if ('azureApiKey' in options || 'azureADTokenProvider' in options) {
this._useAzure = true;
this.options = Object.assign(
{
Expand Down Expand Up @@ -377,41 +383,57 @@ export class OpenAIModel implements PromptCompletionModel {
// Initialize request config
const requestConfig: AxiosRequestConfig = Object.assign({}, this.options.requestConfig);

// Initialize request headers
if (!requestConfig.headers) {
requestConfig.headers = {};
}

if (!requestConfig.headers['Content-Type']) {
requestConfig.headers['Content-Type'] = 'application/json';
}

if (!requestConfig.headers['User-Agent']) {
requestConfig.headers['User-Agent'] = this.UserAgent;
}
if (this._useAzure) {
const options = this.options as AzureOpenAIModelOptions;
requestConfig.headers['api-key'] = options.azureApiKey;
} else if ((this.options as OpenAIModelOptions).apiKey) {
const options = this.options as OpenAIModelOptions;
requestConfig.headers['Authorization'] = `Bearer ${options.apiKey}`;
if (options.organization) {
requestConfig.headers['OpenAI-Organization'] = options.organization;

if ('apiKey' in this.options) {
requestConfig.headers['api-key'] = this.options.apiKey || '';
}

if ('azureApiKey' in this.options || 'azureADTokenProvider' in this.options) {
let apiKey = this.options.azureApiKey;

if (!apiKey && this.options.azureADTokenProvider) {
apiKey = await this.options.azureADTokenProvider();
}

requestConfig.headers['Authorization'] = `Bearer ${apiKey}`;
}

// Send request
const response = await this._httpClient.post(url, body, requestConfig);

// Check for rate limit error
if (
response.status == 429 &&
Array.isArray(this.options.retryPolicy) &&
retryCount < this.options.retryPolicy.length
) {
const delay = this.options.retryPolicy[retryCount];
await new Promise((resolve) => setTimeout(resolve, delay));
return this.post(url, body, retryCount + 1);
} else {
return response;
if ('organization' in this.options && this.options.organization) {
requestConfig.headers['OpenAI-Organization'] = this.options.organization;
}

try {
const res = await this._httpClient.post(url, body, requestConfig);

// Check for rate limit error
if (
res.status == 429 &&
Array.isArray(this.options.retryPolicy) &&
retryCount < this.options.retryPolicy.length
) {
const delay = this.options.retryPolicy[retryCount];
await new Promise((resolve) => setTimeout(resolve, delay));
return this.post(url, body, retryCount + 1);
}

return res;
} catch (err) {
if (this.options.logRequests) {
console.error(Colorize.error(err as Error));
}

throw err;
}
}
}
6 changes: 3 additions & 3 deletions js/packages/teams-ai/src/prompts/Message.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,17 @@ export interface Citation {
/**
* The title of the citation.
*/
title: string;
title: string | null;

/**
* The URL of the citation.
*/
url: string;
url: string | null;

/**
* The filepath of the document.
*/
filepath: string;
filepath: string | null;
}

export interface MessageContext {
Expand Down
Loading

0 comments on commit a441b82

Please sign in to comment.