test: update to reflect new behavior (renderer no longer accounts for base tokens per completions) (#56)
joyceerhl authored Jun 6, 2024
1 parent 3749d49 commit c88008a
Showing 6 changed files with 51 additions and 35 deletions.
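
In practical terms: the renderer previously deducted the fixed per-completion overhead (BaseTokensPerCompletion, exported from src/base/openai.ts) from the endpoint's token budget on the caller's behalf; it no longer does, so the tests below fold that constant into their budgets and expected counts. A minimal caller-side sketch, assuming an endpoint object shaped like the test fakes (import path illustrative):

import { BaseTokensPerCompletion } from './src/base/openai';

// The renderer no longer reserves the per-reply overhead, so subtract it
// yourself when sizing the prompt budget.
const endpoint = {
  modelMaxPromptTokens: 8192 - BaseTokensPerCompletion,
};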
1 change: 1 addition & 0 deletions .npmignore
@@ -12,3 +12,4 @@ tsconfig.json
 dist/base/test/
 *.map
 dist/base/tokenizer/cl100kBaseTokenizer*.*
+dist/base/tokenizer/cl100k_base.tiktoken
14 changes: 1 addition & 13 deletions build/postcompile.ts
@@ -2,19 +2,7 @@
  * Copyright (c) Microsoft Corporation and GitHub. All rights reserved.
  *--------------------------------------------------------------------------------------------*/
 
-import * as fs from 'fs';
-import * as path from 'path';
-
-const REPO_ROOT = path.join(__dirname, '..');
-
-export async function copyStaticAssets(srcpaths: string[], dst: string): Promise<void> {
-  await Promise.all(srcpaths.map(async srcpath => {
-    const src = path.join(REPO_ROOT, srcpath);
-    const dest = path.join(REPO_ROOT, dst, path.basename(srcpath));
-    await fs.promises.mkdir(path.dirname(dest), { recursive: true });
-    await fs.promises.copyFile(src, dest);
-  }));
-}
+import { copyStaticAssets } from './postinstall';
 
 async function main() {
   // Ship the vscodeTypes.d.ts file in the dist bundle
26 changes: 26 additions & 0 deletions build/postinstall.ts
@@ -0,0 +1,26 @@
+/*---------------------------------------------------------------------------------------------
+ * Copyright (c) Microsoft Corporation and GitHub. All rights reserved.
+ *--------------------------------------------------------------------------------------------*/
+
+import * as fs from 'fs';
+import * as path from 'path';
+
+const REPO_ROOT = path.join(__dirname, '..');
+
+export async function copyStaticAssets(srcpaths: string[], dst: string): Promise<void> {
+  await Promise.all(srcpaths.map(async srcpath => {
+    const src = path.join(REPO_ROOT, srcpath);
+    const dest = path.join(REPO_ROOT, dst, path.basename(srcpath));
+    await fs.promises.mkdir(path.dirname(dest), { recursive: true });
+    await fs.promises.copyFile(src, dest);
+  }));
+}
+
+async function main() {
+  // Ship the tiktoken file in the dist bundle
+  await copyStaticAssets([
+    'src/base/tokenizer/cl100k_base.tiktoken',
+  ], 'dist/base/tokenizer');
+}
+
+main();
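
The relocated helper resolves each source against the repo root, creates the destination directory on demand, and keeps only the file's basename. A usage sketch consistent with how build/postcompile.ts now consumes it (the exact d.ts source path is an assumption based on the comment in that file):

import { copyStaticAssets } from './postinstall';

async function main() {
  // Ship the vscodeTypes.d.ts file in the dist bundle:
  // src/base/vscodeTypes.d.ts -> dist/base/vscodeTypes.d.ts
  await copyStaticAssets(['src/base/vscodeTypes.d.ts'], 'dist/base');
}

main();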
3 changes: 2 additions & 1 deletion package.json
@@ -9,7 +9,8 @@
     "compile": "tsc -p tsconfig.json && tsx ./build/postcompile.ts",
     "watch": "tsc --watch --sourceMap",
     "test": "vscode-test",
-    "prettier": "prettier --list-different --write --cache ."
+    "prettier": "prettier --list-different --write --cache .",
+    "prepare": "tsx ./build/postinstall.ts"
   },
   "keywords": [],
   "author": "Microsoft Corporation",
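
Note on the new script: per npm's documented lifecycle, `prepare` runs after a plain local `npm install` and again before `npm pack`/`npm publish`, so the .tiktoken asset is copied into dist/ for development checkouts and published artifacts alike, without requiring a full compile.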
2 changes: 1 addition & 1 deletion src/base/index.ts
@@ -75,7 +75,7 @@ export async function renderPrompt<P extends BasePromptElementProps>(
   mode: 'vscode' | 'none' = 'vscode',
 ): Promise<{ messages: (ChatMessage | LanguageModelChatMessage)[]; tokenCount: number; metadatas: MetadataMap; usedContext: ChatDocumentContext[]; references: PromptReference[] }> {
   let tokenizer = 'countTokens' in tokenizerMetadata
-    ? new AnyTokenizer(tokenizerMetadata.countTokens)
+    ? new AnyTokenizer((text, token) => tokenizerMetadata.countTokens(text, token))
     : tokenizerMetadata;
   const renderer = new PromptRenderer(endpoint, ctor, props, tokenizer);
   let { messages, tokenCount, references } = await renderer.render(progress, token);
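
The arrow wrapper is the substantive fix here: passing tokenizerMetadata.countTokens as a bare reference detaches the method from its receiver, so any `this` access inside the tokenizer fails at call time. A self-contained sketch of the failure mode (the class and field are invented for illustration):

class ExampleTokenizer {
  private separator = /\s+/;

  countTokens(text: string): number {
    // Uses `this`; throws if invoked as a detached function reference.
    return text.split(this.separator).length;
  }
}

const meta = new ExampleTokenizer();

const detached = meta.countTokens;
// detached('what time is it');          // TypeError: `this` is undefined

const wrapped = (text: string) => meta.countTokens(text);
console.log(wrapped('what time is it')); // 4 — the closure preserves the receiver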
40 changes: 20 additions & 20 deletions src/base/test/renderer.test.tsx
@@ -3,7 +3,7 @@
  *--------------------------------------------------------------------------------------------*/
 
 import * as assert from 'assert';
-import { ChatMessage, ChatRole } from '../openai';
+import { BaseTokensPerCompletion, ChatMessage, ChatRole } from '../openai';
 import { PromptElement } from '../promptElement';
 import {
   AssistantMessage,
@@ -26,7 +26,7 @@ import {
 
 suite('PromptRenderer', () => {
   const fakeEndpoint: any = {
-    modelMaxPromptTokens: 8192,
+    modelMaxPromptTokens: 8192 - BaseTokensPerCompletion,
   } satisfies Partial<IChatEndpointInfo>;
   const tokenizer = new Cl100KBaseTokenizer();
 
@@ -98,7 +98,7 @@ suite('PromptRenderer', () => {
           "This late pivot means we don't have time to boil the ocean for the client deliverable.",
       },
     ]);
-    assert.deepStrictEqual(res.tokenCount, 129);
+    assert.deepStrictEqual(res.tokenCount, 129 - BaseTokensPerCompletion);
   });
 
   test('runs async prepare in parallel', async () => {
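
Every expected count in this file shifts by the same constant: the old expectations included the completion priming in the reported total, while the renderer now reports only the prompt's own tokens, e.g.:

// Old: renderer folded the reply-priming overhead into tokenCount.
// assert.deepStrictEqual(res.tokenCount, 129);
// New: tokenCount covers the rendered messages only.
assert.deepStrictEqual(res.tokenCount, 129 - BaseTokensPerCompletion);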
@@ -270,7 +270,7 @@ suite('PromptRenderer', () => {
         { role: 'assistant', content: 'I am terrific, how are you?' },
         { role: 'user', content: 'What time is it?' },
       ]);
-      assert.deepStrictEqual(res.tokenCount, 130);
+      assert.deepStrictEqual(res.tokenCount, 130 - BaseTokensPerCompletion);
     });
 
     test('no shaving at limit', async () => {
@@ -315,11 +315,11 @@
         { role: 'assistant', content: 'I am terrific, how are you?' },
         { role: 'user', content: 'What time is it?' },
       ]);
-      assert.deepStrictEqual(res.tokenCount, 130);
+      assert.deepStrictEqual(res.tokenCount, 130 - BaseTokensPerCompletion);
     });
 
     test('shaving one', async () => {
-      const res = await renderWithMaxPromptTokens(129, Prompt1, {});
+      const res = await renderWithMaxPromptTokens(129 - BaseTokensPerCompletion, Prompt1, {});
       assert.deepStrictEqual(res.messages, [
         {
           role: 'system',
@@ -355,11 +355,11 @@
         { role: 'assistant', content: 'I am terrific, how are you?' },
         { role: 'user', content: 'What time is it?' },
       ]);
-      assert.deepStrictEqual(res.tokenCount, 118);
+      assert.deepStrictEqual(res.tokenCount, 118 - BaseTokensPerCompletion);
     });
 
     test('shaving two', async () => {
-      const res = await renderWithMaxPromptTokens(110, Prompt1, {});
+      const res = await renderWithMaxPromptTokens(110 - BaseTokensPerCompletion, Prompt1, {});
       assert.deepStrictEqual(res.messages, [
         {
           role: 'system',
@@ -390,11 +390,11 @@
         { role: 'assistant', content: 'I am terrific, how are you?' },
         { role: 'user', content: 'What time is it?' },
       ]);
-      assert.deepStrictEqual(res.tokenCount, 102);
+      assert.deepStrictEqual(res.tokenCount, 102 - BaseTokensPerCompletion);
     });
 
     test('shaving a lot', async () => {
-      const res = await renderWithMaxPromptTokens(54, Prompt1, {});
+      const res = await renderWithMaxPromptTokens(54 - BaseTokensPerCompletion, Prompt1, {});
       assert.deepStrictEqual(res.messages, [
         {
           role: 'system',
@@ -413,7 +413,7 @@
         },
         { role: 'user', content: 'What time is it?' },
       ]);
-      assert.deepStrictEqual(res.tokenCount, 53);
+      assert.deepStrictEqual(res.tokenCount, 53 - BaseTokensPerCompletion);
     });
   });
   suite('renders prompts based on dynamic token budget', function () {
@@ -461,7 +461,7 @@
 
     test('passes budget to children based on declared flex', async () => {
       const fakeEndpoint: any = {
-        modelMaxPromptTokens: 100, // Total allowed tokens
+        modelMaxPromptTokens: 100 - BaseTokensPerCompletion, // Total allowed tokens
       } satisfies Partial<IChatEndpointInfo>;
       const inst = new PromptRenderer(
         fakeEndpoint,
@@ -564,7 +564,7 @@
     test('are rendered to chat messages', async () => {
       // First render with large token budget so nothing gets dropped
       const largeTokenBudgetEndpoint: any = {
-        modelMaxPromptTokens: 8192,
+        modelMaxPromptTokens: 8192 - BaseTokensPerCompletion,
       } satisfies Partial<IChatEndpointInfo>;
       const inst1 = new PromptRenderer(
         largeTokenBudgetEndpoint,
@@ -604,13 +604,13 @@
         },
         { role: 'user', content: 'What is your name?' },
       ]);
-      assert.deepStrictEqual(res1.tokenCount, 165);
+      assert.deepStrictEqual(res1.tokenCount, 165 - BaseTokensPerCompletion);
     });
 
     test('are prioritized and fit within token budget', async () => {
       // Render with smaller token budget and ensure that messages are reduced in size
       const smallTokenBudgetEndpoint: any = {
-        modelMaxPromptTokens: 140,
+        modelMaxPromptTokens: 140 - BaseTokensPerCompletion,
       } satisfies Partial<IChatEndpointInfo>;
       const inst2 = new PromptRenderer(
         smallTokenBudgetEndpoint,
@@ -619,7 +619,7 @@
         tokenizer
       );
       const res2 = await inst2.render(undefined, undefined);
-      assert.equal(res2.tokenCount, 120);
+      assert.equal(res2.tokenCount, 120 - BaseTokensPerCompletion);
       assert.deepStrictEqual(res2.messages, [
         {
           role: 'system',
@@ -706,7 +706,7 @@
       }
 
       const smallTokenBudgetEndpoint: any = {
-        modelMaxPromptTokens: 150,
+        modelMaxPromptTokens: 150 - BaseTokensPerCompletion,
       } satisfies Partial<IChatEndpointInfo>;
       const inst2 = new PromptRenderer(
         smallTokenBudgetEndpoint,
@@ -775,7 +775,7 @@ LOW MED 00 01 02 03 04 05 06 07 08 09
       }
 
      const smallTokenBudgetEndpoint: any = {
-        modelMaxPromptTokens: 150,
+        modelMaxPromptTokens: 150 - BaseTokensPerCompletion,
       } satisfies Partial<IChatEndpointInfo>;
       const inst2 = new PromptRenderer(
         smallTokenBudgetEndpoint,
@@ -829,7 +829,7 @@ LOW MED 00 01 02 03 04 05 06 07 08 09
 
     test('reports reference that survived prioritization', async () => {
       const endpoint: any = {
-        modelMaxPromptTokens: 4096,
+        modelMaxPromptTokens: 4096 - BaseTokensPerCompletion,
       } satisfies Partial<IChatEndpointInfo>;
 
       const inst = new PromptRenderer(
@@ -891,7 +891,7 @@ LOW MED 00 01 02 03 04 05 06 07 08 09
       }
 
       const endpoint: any = {
-        modelMaxPromptTokens: 4096,
+        modelMaxPromptTokens: 4096 - BaseTokensPerCompletion,
       } satisfies Partial<IChatEndpointInfo>;
 
       const inst = new PromptRenderer(
