diff --git a/README.md b/README.md index a2c6439..999cf0e 100644 --- a/README.md +++ b/README.md @@ -273,6 +273,24 @@ There are a few similar properties which control budget allocation you might find It's important to note that all of the `flex*` properties allow for cooperative use of the token budget for a prompt, but have no effect on the prioritization and pruning logic undertaken once all elements are rendered. +### Local Priority Limits + +`prompt-tsx` provides a `TokenLimit` element that can be used to set a hard cap on the number of tokens that can be consumed by a prompt or part of a prompt. Using it is fairly straightforward: + +```tsx +class PromptWithLimit extends PromptElement { + render() { + return ( + + {/* Your elements here! */} + + ); + } +} +``` + +`TokenLimit` subtrees are pruned before the prompt gets pruned. As you would expect, the `PromptSizing` of child elements inside of a limit reflects the reduced budget. If the `TokenLimit` would be assigned a `tokenBudget` smaller than its maximum via the usual distribution rules, then that smaller budget is given to its child elements instead (but pruning to the `max` value still happens). + +### Expandable Text + +The tools provided by `flex*` attributes are good, but sometimes you may still end up with unused space in your token budget that you'd like to utilize. We provide a special `` element that can be used in this case. It takes a callback that can return a text string. 
diff --git a/package.json b/package.json index 2d905c9..578377a 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@vscode/prompt-tsx", - "version": "0.3.0-alpha.12", + "version": "0.3.0-alpha.13", "description": "Declare LLM prompts with TSX", "main": "./dist/base/index.js", "types": "./dist/base/index.d.ts", diff --git a/src/base/materialized.ts b/src/base/materialized.ts index 90d0f13..60ddcb6 100644 --- a/src/base/materialized.ts +++ b/src/base/materialized.ts @@ -54,12 +54,9 @@ export class MaterializedContainer implements IMaterializedNode { let total = 0; await Promise.all( this.children.map(async child => { - // note: this method is not called when the container is inside a chat - // message, because in that case the chat message generates the text - // and counts that. - assertContainerOrChatMessage(child); - - const amt = await child.tokenCount(tokenizer); + const amt = isContainerType(child) + ? await child.tokenCount(tokenizer) + : await child.upperBoundTokenCount(tokenizer); total += amt; }) ); @@ -92,6 +89,13 @@ export class MaterializedContainer implements IMaterializedNode { return allMetadata(this); } + /** + * Finds a node in the tree by ID. + */ + findById(nodeId: number): MaterializedContainer | MaterializedChatMessage | undefined { + return findNodeById(nodeId, this); + } + /** * Gets the chat messages the container holds. */ @@ -200,6 +204,13 @@ export class MaterializedChatMessage implements IMaterializedNode { this._text.clear(); } + /** + * Finds a node in the tree by ID. 
+ */ + findById(nodeId: number): MaterializedContainer | MaterializedChatMessage | undefined { + return findNodeById(nodeId, this); + } + private readonly _tokenCount = once(async (tokenizer: ITokenizer) => { return tokenizer.countMessageTokens(this.toChatMessage()); }); @@ -359,7 +370,12 @@ function removeLowestPriorityLegacy(root: MaterializedNode) { function removeLowestPriorityChild(node: MaterializedContainer | MaterializedChatMessage) { let lowest: | undefined - | { chain: (MaterializedContainer | MaterializedChatMessage)[]; index: number; value: MaterializedNode; lowestNested?: number }; + | { + chain: (MaterializedContainer | MaterializedChatMessage)[]; + index: number; + value: MaterializedNode; + lowestNested?: number; + }; // In *most* cases the chain is always [node], but it can be longer if // the `passPriority` is used. We need to keep track of the chain to @@ -457,3 +473,21 @@ function replaceNode( } } } + +function findNodeById( + nodeId: number, + container: MaterializedContainer | MaterializedChatMessage +): MaterializedContainer | MaterializedChatMessage | undefined { + if (container.id === nodeId) { + return container; + } + + for (const child of container.children) { + if (isContainerType(child)) { + const inner = findNodeById(nodeId, child); + if (inner) { + return inner; + } + } + } +} diff --git a/src/base/promptElements.tsx b/src/base/promptElements.tsx index 96f272c..b15946d 100644 --- a/src/base/promptElements.tsx +++ b/src/base/promptElements.tsx @@ -2,11 +2,16 @@ * Copyright (c) Microsoft Corporation and GitHub. All rights reserved. 
*--------------------------------------------------------------------------------------------*/ -import type { CancellationToken, LanguageModelPromptTsxPart, LanguageModelTextPart, LanguageModelToolResult } from 'vscode'; +import type { + CancellationToken, + LanguageModelPromptTsxPart, + LanguageModelTextPart, + LanguageModelToolResult, +} from 'vscode'; import { contentType } from '.'; import { ChatRole } from './openai'; import { PromptElement } from './promptElement'; -import { BasePromptElementProps, PromptPiece, PromptSizing } from './types'; +import { BasePromptElementProps, PromptElementProps, PromptPiece, PromptSizing } from './types'; import { PromptElementJSON } from './jsonTypes'; export type ChatMessagePromptElement = SystemMessage | UserMessage | AssistantMessage; @@ -221,7 +226,7 @@ export interface PrioritizedListProps extends BasePromptElementProps { * Priority of the list element. * All rendered elements in this list receive a priority that is offset from this value. */ - priority: number; + priority?: number; /** * If `true`, assign higher priority to elements declared earlier in this list. */ @@ -233,7 +238,7 @@ export interface PrioritizedListProps extends BasePromptElementProps { */ export class PrioritizedList extends PromptElement { override render() { - const children = this.props.children; + const { children, priority = 0, descending } = this.props; if (!children) { return; } @@ -245,18 +250,18 @@ export class PrioritizedList extends PromptElement { return; } - const priority = this.props.descending + const thisPriority = descending ? 
// First element in array of children has highest priority - this.props.priority - i + priority - i : // Last element in array of children has highest priority - this.props.priority - children.length + i; + priority - children.length + i; if (typeof child !== 'object') { - return {child}; + return {child}; } child.props ??= {}; - child.props.priority = priority; + child.props.priority = thisPriority; return child; })} @@ -284,15 +289,19 @@ export class ToolResult extends PromptElement { // note: future updates to content types should be handled here for backwards compatibility const vscode = require('vscode'); // TODO proper way to handle types here? - return <> - {this.props.data.content.map(part => { - if (part instanceof vscode.LanguageModelTextPart) { - return (part as LanguageModelTextPart).value; - } else if (part instanceof vscode.LanguageModelPromptTsxPart) { - return ; - } - })} - ; + return ( + <> + {this.props.data.content.map(part => { + if (part instanceof vscode.LanguageModelTextPart) { + return (part as LanguageModelTextPart).value; + } else if (part instanceof vscode.LanguageModelPromptTsxPart) { + return ( + + ); + } + })} + + ); } } @@ -335,3 +344,18 @@ export class Expandable extends PromptElement { return <>{await this.props.value(sizing)}; } } + +export interface TokenLimitProps extends BasePromptElementProps { + max: number; +} + +/** + * An element that ensures its children don't exceed a certain number of + * tokens, given by its `max` prop. Its contents are pruned to fit within the budget before + * the overall prompt pruning is run. 
+ */ +export class TokenLimit extends PromptElement { + render(): PromptPiece { + return <>{this.props.children}; + } +} diff --git a/src/base/promptRenderer.ts b/src/base/promptRenderer.ts index 519dadf..ac7c879 100644 --- a/src/base/promptRenderer.ts +++ b/src/base/promptRenderer.ts @@ -23,6 +23,8 @@ import { isChatMessagePromptElement, LegacyPrioritization, TextChunk, + TokenLimit, + TokenLimitProps, ToolMessage, } from './promptElements'; import { PromptMetadata, PromptReference } from './results'; @@ -84,6 +86,7 @@ export class PromptRenderer

{ private readonly _ignoredFiles: URI[] = []; private readonly _growables: { initialConsume: number; elem: PromptTreeElement }[] = []; private readonly _root = new PromptTreeElement(null, 0); + private readonly _tokenLimits: { limit: number; id: number }[] = []; /** Epoch used to tracing the order in which elements render. */ public tracer: ITracer | undefined = undefined; @@ -124,6 +127,7 @@ export class PromptRenderer

{ { element: QueueItem, P>; promptElementInstance: PromptElement; + tokenLimit: number | undefined; }[] >(); for (const [i, element] of pieces.entries()) { @@ -141,6 +145,11 @@ export class PromptRenderer

{ } const promptElement = this.createElement(element); + let tokenLimit: number | undefined; + if (promptElement instanceof TokenLimit) { + tokenLimit = (element.props as unknown as TokenLimitProps).max; + this._tokenLimits.push({ limit: tokenLimit, id: element.node.id }); + } element.node.setObj(promptElement); // Prepare rendering @@ -151,7 +160,7 @@ export class PromptRenderer

{ promptElements.set(flexGroupValue, flexGroup); } - flexGroup.push({ element, promptElementInstance: promptElement }); + flexGroup.push({ element, promptElementInstance: promptElement, tokenLimit }); } if (promptElements.size === 0) { @@ -189,15 +198,37 @@ export class PromptRenderer

{ // Calculate the flex basis for dividing the budget amongst siblings in this group. let flexBasisSum = 0; for (const { element } of promptElements) { - // todo@connor4312: remove `flex` after transition - flexBasisSum += (element.props.flex || element.props.flexBasis) ?? 1; + flexBasisSum += element.props.flexBasis ?? 1; } + let constantTokenLimits = 0; + // For elements that limit their token usage and would use less than we + // otherwise would assign to them, 'cap' their usage at the limit and + // remove their share directly from the budget in distribution. + const useConstantLimitsForIndex = promptElements.map(e => { + if (e.tokenLimit === undefined) { + return false; + } + + const flexBasis = e.element.props.flexBasis ?? 1; + const proportion = flexBasis / flexBasisSum; + const proportionateUsage = Math.floor(sizing.remainingTokenBudget * proportion); + if (proportionateUsage < e.tokenLimit) { + return false; + } + + flexBasisSum -= flexBasis; + constantTokenLimits += e.tokenLimit; + return true; + }); + // Finally calculate the final sizing for each element in this group. - const elementSizings: PromptSizing[] = promptElements.map(e => { + const elementSizings: PromptSizing[] = promptElements.map((e, i) => { const proportion = (e.element.props.flexBasis ?? 1) / flexBasisSum; return { - tokenBudget: Math.floor(sizing.remainingTokenBudget * proportion), + tokenBudget: useConstantLimitsForIndex[i] + ? e.tokenLimit! + : Math.floor((sizing.remainingTokenBudget - constantTokenLimits) * proportion), endpoint: sizing.endpoint, countTokens: (text, cancellation) => this._tokenizer.tokenLength(text, cancellation), }; @@ -399,35 +430,49 @@ export class PromptRenderer

{ * around with budgets. It should be side-effect-free. */ private async _getFinalElementTree(tokenBudget: number, token: CancellationToken | undefined) { - // Trim the elements to fit within the token budget. We check the "lower bound" - // first because that's much more cache-friendly as we remove elements. - const container = this._root.materialize() as MaterializedContainer; - const initialTokenCount = await container.tokenCount(this._tokenizer); - if (initialTokenCount < tokenBudget) { - const didChange = await this._grow(container, initialTokenCount, tokenBudget, token); - - // if nothing grew, we already counted tokens so we can safely return - if (!didChange) { - return { container, allMetadata: [...container.allMetadata()], removed: 0 }; + const root = this._root.materialize() as MaterializedContainer; + const allMetadata = [...root.allMetadata()]; + const limits = [{ limit: tokenBudget, id: this._root.id }, ...this._tokenLimits]; + let removed = 0; + + for (let i = limits.length - 1; i >= 0; i--) { + const limit = limits[i]; + if (limit.limit > tokenBudget) { + continue; } - } - const allMetadata = [...container.allMetadata()]; - let removed = 0; - while ( - (await container.upperBoundTokenCount(this._tokenizer)) > tokenBudget && - (await container.tokenCount(this._tokenizer)) > tokenBudget - ) { - container.removeLowestPriorityChild(); - removed++; + const container = root.findById(limit.id); + if (!container) { + continue; + } + + const initialTokenCount = await container.tokenCount(this._tokenizer); + if (initialTokenCount < limit.limit) { + const didChange = await this._grow(container, initialTokenCount, limit.limit, token); + + // if nothing grew, we already counted tokens so we can safely return + if (!didChange) { + continue; + } + } + + // Trim the elements to fit within the token budget. We check the "lower bound" + // first because that's much more cache-friendly as we remove elements. 
+ while ( + (await container.upperBoundTokenCount(this._tokenizer)) > limit.limit && + (await container.tokenCount(this._tokenizer)) > limit.limit + ) { + container.removeLowestPriorityChild(); + removed++; + } } - return { container, allMetadata, removed }; + return { container: root, allMetadata, removed }; } /** Grows all Expandable elements, returns if any changes were made. */ private async _grow( - tree: MaterializedContainer, + tree: MaterializedContainer | MaterializedChatMessage, tokensUsed: number, tokenBudget: number, token: CancellationToken | undefined @@ -437,6 +482,10 @@ export class PromptRenderer

{ } for (const growable of this._growables) { + if (!tree.findById(growable.elem.id)) { + continue; // not in this subtree + } + const obj = growable.elem.getObj(); if (!(obj instanceof Expandable)) { throw new Error('unreachable: expected growable'); diff --git a/src/base/test/renderer.test.tsx b/src/base/test/renderer.test.tsx index 6652888..b8677eb 100644 --- a/src/base/test/renderer.test.tsx +++ b/src/base/test/renderer.test.tsx @@ -14,6 +14,7 @@ import { PrioritizedList, SystemMessage, TextChunk, + TokenLimit, ToolMessage, ToolResult, UserMessage, @@ -1920,6 +1921,206 @@ suite('PromptRenderer', () => { const inst = new PromptRenderer(fakeEndpoint, Wrapper, {}, tokenizer); const res = await inst.render(undefined, undefined); - assert.deepStrictEqual(res.messages.map(m => m.content).join(''), 'hello everyone in the world!'); + assert.deepStrictEqual( + res.messages.map(m => m.content).join(''), + 'hello everyone in the world!' + ); + }); + + suite('TokenLimit', () => { + test('limits tokens within budget', async () => { + class PromptWithLimit extends PromptElement { + render() { + return ( + + outside + + + 12345 + 67890 + extra + + + + ); + } + } + + const inst = new PromptRenderer(fakeEndpoint, PromptWithLimit, {}, tokenizer); + const res = await inst.render(undefined, undefined); + assert.deepStrictEqual(res.messages, [ + { + role: 'user', + content: 'outside\n12345\n67890', + }, + ]); + }); + + test('child elements get lower token limit', async () => { + class Wrapper extends PromptElement<{ expected: number } & BasePromptElementProps> { + render(_: void, sizing: PromptSizing) { + assert.strictEqual(sizing.tokenBudget, this.props.expected); + return <>asdf; + } + } + + class PromptWithLimit extends PromptElement { + render() { + return ( + + + + + + + ); + } + } + + const inst = new PromptRenderer(fakeEndpoint, PromptWithLimit, {}, tokenizer); + const res = await inst.render(undefined, undefined); + assert.deepStrictEqual(res.messages, [ + { + role: 'user', 
+ content: 'asdf\nasdf', + }, + ]); + }); + + test('does not redistribute with higher than proportionate limit', async () => { + class Wrapper extends PromptElement<{ expected: number } & BasePromptElementProps> { + render(_: void, sizing: PromptSizing) { + assert.strictEqual(sizing.tokenBudget, this.props.expected); + return <>asdf; + } + } + + class PromptWithLimit extends PromptElement { + render() { + return ( + + + + + + + ); + } + } + + const inst = new PromptRenderer(fakeEndpoint, PromptWithLimit, {}, tokenizer); + const res = await inst.render(undefined, undefined); + assert.deepStrictEqual(res.messages, [ + { + role: 'user', + content: 'asdf\nasdf', + }, + ]); + }); + + test('works with multiple', async () => { + class Wrapper extends PromptElement<{ expected: number } & BasePromptElementProps> { + render(_: void, sizing: PromptSizing) { + assert.strictEqual(sizing.tokenBudget, this.props.expected); + return <>asdf; + } + } + + class PromptWithLimit extends PromptElement { + render() { + return ( + + + + {/* Included in distribution */} + + + + {/* excluded from distribution because of large size */} + + + + + ); + } + } + + const inst = new PromptRenderer(fakeEndpoint, PromptWithLimit, {}, tokenizer); + const res = await inst.render(undefined, undefined); + assert.deepStrictEqual(res.messages, [ + { + role: 'user', + content: 'asdf\nasdf\nasdf\nasdf', + }, + ]); + }); + + test('limits if nested outer < inner', async () => { + class PromptWithLimit extends PromptElement { + render() { + return ( + + + + 12345 + 67890 + extra + + + + 12345 + 67890 + extra + + + + + ); + } + } + + const inst = new PromptRenderer(fakeEndpoint, PromptWithLimit, {}, tokenizer); + const res = await inst.render(undefined, undefined); + assert.deepStrictEqual(res.messages, [ + { + role: 'user', + content: '12345\n67890\n12345\n67890\nextra', + }, + ]); + }); + + test('limits if nested outer > inner', async () => { + class PromptWithLimit extends PromptElement { + render() { + 
return ( + + + + 12345 + 67890 + extra + + + + 12345 + 67890 + extra + + + + + ); + } + } + + const inst = new PromptRenderer(fakeEndpoint, PromptWithLimit, {}, tokenizer); + const res = await inst.render(undefined, undefined); + assert.deepStrictEqual(res.messages, [ + { + role: 'user', + content: '12345\n67890\nextra\n12345\n67890', + }, + ]); + }); }); }); diff --git a/src/base/types.ts b/src/base/types.ts index 9c162ad..5be1577 100644 --- a/src/base/types.ts +++ b/src/base/types.ts @@ -71,9 +71,6 @@ export interface BasePromptElementProps { */ flexBasis?: number; - /** @deprecated renamed to {@link flexBasis} */ - flex?: number; - /** * If set, sibling elements will be rendered first, followed by this element. The remaining {@link PromptSizing.tokenBudget token budget} from the container will be distributed among the elements with `flexGrow` set. *