diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1ef3f03..84540c7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,9 @@
 # Changelog
 
+## 0.3.0-alpha.5
+
+- **feat:** add `Expandable` elements to the renderer. See the [readme](./README.md#expandable-text) for details.
+
 ## 0.3.0-alpha.4
 
 - **feat:** enhance the `HTMLTracer` to allow consumers to visualize element pruning order
diff --git a/README.md b/README.md
index b3fe031..b541848 100644
--- a/README.md
+++ b/README.md
@@ -183,7 +183,7 @@ Continuous text strings and elements can both be pruned from the tree. If you ha
 
 ### Flex Behavior
 
-Wholesale pruning is not always already. Instead, we'd prefer to include as much of the query as possible. To do this, we can use the `flexGrow` property, which allows an element to use the remainder of its parent's token budget when it's rendered.
+Wholesale pruning is not always ideal. Instead, we'd prefer to include as much of the query as possible. To do this, we can use the `flexGrow` property, which allows an element to use the remainder of its parent's token budget when it's rendered.
 
 `prompt-tsx` provides a utility component that supports this use case: `TextChunk`. Given input text, and optionally a delimiting string or regular expression, it'll include as much of the text as possible to fit within its budget:
 
@@ -230,6 +230,27 @@ There are a few similar properties which control budget allocation you might find
 
 It's important to note that all of the `flex*` properties allow for cooperative use of the token budget for a prompt, but have no effect on the prioritization and pruning logic undertaken once all elements are rendered.
 
+### Expandable Text
+
+The tools provided by `flex*` attributes are good, but sometimes you may still end up with unused space in your token budget that you'd like to utilize. We provide a special `<Expandable />` element that can be used in this case. It takes a callback that can return a text string.
+
+```tsx
+<Expandable value={async sizing => {
+	let data = 'hi';
+	while (true) {
+		const more = getMoreUsefulData();
+		if (await sizing.countTokens(data + more) > sizing.tokenBudget) { break }
+		data += more;
+	}
+	return data;
+}} />
+```
+
+After the prompt is rendered, the renderer sums up the tokens used by all messages. If there is unused budget, then any `<Expandable />` elements' values are called again with their `PromptSizing` increased by the token excess.
+
+If there are multiple `<Expandable />` elements, then they're re-called in the order in which they were initially rendered. Because they're designed to fill up any remaining space, it usually makes sense to have at most one `<Expandable />` element per prompt.
+
 #### Debugging Budgeting
 
 You can set a `tracer` property on the `PromptElement` to debug how your elements are rendered and how this library allocates your budget.
 We include a basic `HTMLTracer` you can use, which can be served on an address:
diff --git a/package-lock.json b/package-lock.json
index d3115be..1c02677 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -1,12 +1,12 @@
 {
   "name": "@vscode/prompt-tsx",
-  "version": "0.3.0-alpha.4",
+  "version": "0.3.0-alpha.5",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "@vscode/prompt-tsx",
-      "version": "0.3.0-alpha.4",
+      "version": "0.3.0-alpha.5",
       "license": "SEE LICENSE IN LICENSE",
       "devDependencies": {
         "@microsoft/tiktokenizer": "^1.0.6",
diff --git a/package.json b/package.json
index dfafe7a..8e1c135 100644
--- a/package.json
+++ b/package.json
@@ -1,6 +1,6 @@
 {
   "name": "@vscode/prompt-tsx",
-  "version": "0.3.0-alpha.4",
+  "version": "0.3.0-alpha.5",
   "description": "Declare LLM prompts with TSX",
   "main": "./dist/base/index.js",
   "types": "./dist/base/index.d.ts",
diff --git a/src/base/materialized.ts b/src/base/materialized.ts
index 797d8a6..b37d102 100644
--- a/src/base/materialized.ts
+++ b/src/base/materialized.ts
@@ -70,6 +70,13 @@ export class MaterializedContainer implements IMaterializedNode {
 		return total;
 	}
 
+	/**
+	 * Replaces a node in the tree with the given one, by its ID.
+	 */
+	replaceNode(nodeId: number, withNode: MaterializedNode): MaterializedNode | undefined {
+		return replaceNode(nodeId, this.children, withNode);
+	}
+
 	/**
 	 * Gets all metadata the container holds.
 	 */
@@ -160,6 +167,18 @@ export class MaterializedChatMessage implements IMaterializedNode {
 		return !/\S/.test(this.text) && !this.toolCalls?.length && !this.toolCallId;
 	}
 
+	/**
+	 * Replaces a node in the tree with the given one, by its ID.
+	 */
+	replaceNode(nodeId: number, withNode: MaterializedNode): MaterializedNode | undefined {
+		const replaced = replaceNode(nodeId, this.children, withNode);
+		if (replaced) {
+			this.onChunksChange();
+		}
+
+		return replaced;
+	}
+
 	/** Remove the lowest priority chunk among this message's children. */
 	removeLowestPriorityChild() {
 		removeLowestPriorityChild(this.children);
@@ -240,13 +259,16 @@ export class MaterializedChatMessage implements IMaterializedNode {
 	}
 }
 
+function isContainerType(node: MaterializedNode): node is MaterializedContainer | MaterializedChatMessage {
+	return !(node instanceof MaterializedChatMessageTextChunk);
+}
+
 function assertContainerOrChatMessage(v: MaterializedNode): asserts v is MaterializedContainer | MaterializedChatMessage {
 	if (!(v instanceof MaterializedContainer) && !(v instanceof MaterializedChatMessage)) {
 		throw new Error(`Cannot have a text node outside a ChatMessage.
Text: "${v.text}"`); } } - function* textChunks(node: MaterializedContainer | MaterializedChatMessage, isTextSibling = false): Generator<{ text: MaterializedChatMessageTextChunk; isTextSibling: boolean }> { for (const child of node.children) { if (child instanceof MaterializedChatMessageTextChunk) { @@ -343,7 +365,7 @@ function removeLowestPriorityChild(children: MaterializedNode[]) { } function getLowestPriorityAmongChildren(node: MaterializedNode): number { - if (node instanceof MaterializedChatMessageTextChunk) { + if (!(isContainerType(node))) { return -1; } @@ -358,10 +380,28 @@ function getLowestPriorityAmongChildren(node: MaterializedNode): number { function* allMetadata(node: MaterializedContainer | MaterializedChatMessage): Generator { yield* node.metadata; for (const child of node.children) { - if (child instanceof MaterializedChatMessageTextChunk) { - yield* child.metadata; - } else { + if (isContainerType(child)) { yield* allMetadata(child); + } else { + yield* child.metadata; + } + } +} + +function replaceNode(nodeId: number, children: MaterializedNode[], withNode: MaterializedNode): MaterializedNode | undefined { + for (let i = 0; i < children.length; i++) { + const child = children[i]; + if (isContainerType(child)) { + if (child.id === nodeId) { + const oldNode = children[i]; + children[i] = withNode; + return oldNode; + } + + const inner = child.replaceNode(nodeId, withNode); + if (inner) { + return inner; + } } } } diff --git a/src/base/promptElements.tsx b/src/base/promptElements.tsx index 759e966..a5551be 100644 --- a/src/base/promptElements.tsx +++ b/src/base/promptElements.tsx @@ -303,3 +303,19 @@ export class Chunk extends PromptElement { return <>{this.props.children}; } } + +export interface ExpandableProps extends BasePromptElementProps { + value: (sizing: PromptSizing) => string | Promise; +} + +/** + * An element that can expand to fill the remaining token budget. Takes + * a `value` function that is initially called with the element's token budget, + * and may be called multiple times with the new token budget as the prompt + * is resized. + */ +export class Expandable extends PromptElement { + async render(_state: void, sizing: PromptSizing): Promise { + return <>{await this.props.value(sizing)}; + } +} diff --git a/src/base/promptRenderer.ts b/src/base/promptRenderer.ts index 82bf021..4119832 100644 --- a/src/base/promptRenderer.ts +++ b/src/base/promptRenderer.ts @@ -8,7 +8,7 @@ import { PromptNodeType } from './jsonTypes'; import { ContainerFlags, LineBreakBefore, MaterializedChatMessage, MaterializedChatMessageTextChunk, MaterializedContainer } from './materialized'; import { ChatMessage } from "./openai"; import { PromptElement } from "./promptElement"; -import { AssistantMessage, BaseChatMessage, ChatMessagePromptElement, Chunk, LegacyPrioritization, TextChunk, ToolMessage, isChatMessagePromptElement } from "./promptElements"; +import { AssistantMessage, BaseChatMessage, ChatMessagePromptElement, Chunk, Expandable, LegacyPrioritization, TextChunk, ToolMessage, isChatMessagePromptElement } from "./promptElements"; import { PromptMetadata, PromptReference } from "./results"; import { ITokenizer } from "./tokenizer/tokenizer"; import { ITracer } from './tracer'; @@ -60,9 +60,9 @@ export class PromptRenderer

{ private readonly _usedContext: ChatDocumentContext[] = []; private readonly _ignoredFiles: URI[] = []; + private readonly _growables: { initialConsume: number; elem: PromptTreeElement }[] = []; private readonly _root = new PromptTreeElement(null, 0); /** Epoch used to tracing the order in which elements render. */ - private _epoch = 0; public tracer: ITracer | undefined = undefined; /** @@ -197,27 +197,52 @@ export class PromptRenderer

{ continue; } - const pieces = flattenAndReduce(template); - - // Compute token budget for the pieces that this child wants to render - const childSizing = new PromptSizingContext(elementSizing.tokenBudget, this._endpoint); - const { tokensConsumed } = await computeTokensConsumedByLiterals(this._tokenizer, element, promptElementInstance, pieces); - childSizing.consume(tokensConsumed); - await this._handlePromptChildren(element, pieces, childSizing, progress, token); + const childConsumption = await this._processPromptRenderPiece( + new PromptSizingContext(elementSizing.tokenBudget, this._endpoint), + element, + promptElementInstance, + template, + progress, + token, + ); + + // Append growables here so that when we go back and expand them we do so in render order. + if (promptElementInstance instanceof Expandable) { + this._growables.push({ initialConsume: childConsumption, elem: element.node }); + } // Tally up the child consumption into the parent context for any subsequent flex group - sizing.consume(childSizing.consumed); + sizing.consume(childConsumption); } } } + private async _processPromptRenderPiece( + elementSizing: PromptSizingContext, + element: QueueItem, any>, + promptElementInstance: PromptElement, + template: PromptPiece, + progress: Progress | undefined, + token: CancellationToken | undefined, + ) { + const pieces = flattenAndReduce(template); + + // Compute token budget for the pieces that this child wants to render + const childSizing = new PromptSizingContext(elementSizing.tokenBudget, this._endpoint); + const { tokensConsumed } = await computeTokensConsumedByLiterals(this._tokenizer, element, promptElementInstance, pieces); + childSizing.consume(tokensConsumed); + await this._handlePromptChildren(element, pieces, childSizing, progress, token); + + // Tally up the child consumption into the parent context for any subsequent flex group + return childSizing.consumed; + } + /** * Renders the prompt element and its children to a JSON-serializable state. * @returns A promise that resolves to an object containing the rendered chat messages and the total token count. * The total token count is guaranteed to be less than or equal to the token budget. */ public async renderElementJSON(token?: CancellationToken): Promise { - this._epoch = 0; await this._processPromptPieces( new PromptSizingContext(this._endpoint.modelMaxPromptTokens, this._endpoint), [{ node: this._root, ctor: this._ctor, props: this._props, children: [] }], @@ -237,7 +262,6 @@ export class PromptRenderer

{ * The total token count is guaranteed to be less than or equal to the token budget. */ public async render(progress?: Progress, token?: CancellationToken): Promise { - this._epoch = 0; // Convert root prompt element to prompt pieces await this._processPromptPieces( new PromptSizingContext(this._endpoint.modelMaxPromptTokens, this._endpoint), @@ -246,12 +270,12 @@ export class PromptRenderer

{ token, ); - const { container, allMetadata, removed } = await this._getFinalElementTree(this._endpoint.modelMaxPromptTokens); + const { container, allMetadata, removed } = await this._getFinalElementTree(this._endpoint.modelMaxPromptTokens, token); this.tracer?.didMaterializeTree?.({ budget: this._endpoint.modelMaxPromptTokens, renderedTree: { container, removed, budget: this._endpoint.modelMaxPromptTokens }, tokenizer: this._tokenizer, - renderTree: budget => this._getFinalElementTree(budget).then(r => ({ ...r, budget })), + renderTree: budget => this._getFinalElementTree(budget, undefined).then(r => ({ ...r, budget })), }); // Then finalize the chat messages @@ -305,10 +329,24 @@ export class PromptRenderer

{ }; } - private async _getFinalElementTree(tokenBudget: number) { + /** + * Note: this may be called multiple times from the tracer as users play + * around with budgets. It should be side-effect-free. + */ + private async _getFinalElementTree(tokenBudget: number, token: CancellationToken | undefined) { // Trim the elements to fit within the token budget. We check the "lower bound" // first because that's much more cache-friendly as we remove elements. const container = this._root.materialize() as MaterializedContainer; + const initialTokenCount = await container.tokenCount(this._tokenizer); + if (initialTokenCount < tokenBudget) { + const didChange = await this._grow(container, initialTokenCount, tokenBudget, token); + + // if nothing grew, we already counted tokens so we can safely return + if (!didChange) { + return { container, allMetadata: [...container.allMetadata()], removed: 0 }; + } + } + const allMetadata = [...container.allMetadata()]; let removed = 0; while ( @@ -322,6 +360,52 @@ export class PromptRenderer

{ return { container, allMetadata, removed }; } + /** Grows all Expandable elements, returns if any changes were made. */ + private async _grow(tree: MaterializedContainer, tokensUsed: number, tokenBudget: number, token: CancellationToken | undefined): Promise { + if (!this._growables.length) { + return false; + } + + for (const growable of this._growables) { + const obj = growable.elem.getObj(); + if (!(obj instanceof Expandable)) { + throw new Error('unreachable: expected growable'); + } + + const tempRoot = new PromptTreeElement(null, 0, growable.elem.id); + // Sizing for the grow is the remaining excess plus the initial consumption, + // since the element consuming the initial amount of tokens will be replaced + const sizing = new PromptSizingContext(tokenBudget - tokensUsed + growable.initialConsume, this._endpoint); + + const newConsumed = await this._processPromptRenderPiece( + sizing, + { node: tempRoot, ctor: this._ctor, props: {}, children: [] }, + obj, + await obj.render(undefined, { + tokenBudget: sizing.tokenBudget, + endpoint: this._endpoint, + countTokens: (text, cancellation) => this._tokenizer.tokenLength(text, cancellation) + }), + undefined, + token, + ); + + const newContainer = tempRoot.materialize() as MaterializedContainer; + const oldContainer = tree.replaceNode(growable.elem.id, newContainer); + if (!oldContainer) { + throw new Error('unreachable: could not find old element to replace'); + } + + tokensUsed -= growable.initialConsume; + tokensUsed += newConsumed; + if (tokensUsed >= tokenBudget) { + break; + } + } + + return true; + } + private _handlePromptChildren(element: QueueItem, P>, pieces: ProcessedPromptPiece[], sizing: PromptSizingContext, progress: Progress | undefined, token: CancellationToken | undefined) { if (element.ctor === TextChunk) { this._handleExtrinsicTextChunkChildren(element.node, element.node, element.props, pieces); @@ -584,7 +668,6 @@ class PromptTreeElement { } public readonly kind = PromptNodeType.Piece; - public readonly id = PromptTreeElement._nextId++; private _obj: PromptElement | null = null; private _state: any | undefined = undefined; @@ -594,12 +677,17 @@ class PromptTreeElement { constructor( public readonly parent: PromptTreeElement | null = null, public readonly childIndex: number, + public readonly id = PromptTreeElement._nextId++, ) { } public setObj(obj: PromptElement) { this._obj = obj; } + public getObj(): PromptElement | null { + return this._obj; + } + public setState(state: any) { this._state = state; } diff --git a/src/base/test/renderer.test.tsx b/src/base/test/renderer.test.tsx index 237b4fb..26c88be 100644 --- a/src/base/test/renderer.test.tsx +++ b/src/base/test/renderer.test.tsx @@ -3,12 +3,13 @@ *--------------------------------------------------------------------------------------------*/ import * as assert from 'assert'; -import { contentType, renderElementJSON, renderPrompt } from '..'; +import { contentType, HTMLTracer, renderElementJSON, renderPrompt } from '..'; import { BaseTokensPerCompletion, ChatMessage, ChatRole } from '../openai'; import { PromptElement } from '../promptElement'; import { AssistantMessage, Chunk, + Expandable, LegacyPrioritization, PrioritizedList, SystemMessage, @@ -1620,4 +1621,162 @@ suite('PromptRenderer', () => { assert.deepStrictEqual(res.metadata.getAll(MyMeta), [new MyMeta(true), new MyMeta(false)]); }); }); + + suite('growable', () => { + test('grows basic', async () => { + const sizingInCalls: number[] = []; + const res = await new PromptRenderer( + { modelMaxPromptTokens: 
50 }, + class extends PromptElement { + render() { + return + { + sizingInCalls.push(sizing.tokenBudget); + let str = 'hi'; + while (await sizing.countTokens(str + 'a') <= sizing.tokenBudget) { + str += 'a'; + } + return str; + }} /> + smaller + ; + } + }, + {}, + tokenizer + ).render(); + + assert.deepStrictEqual(sizingInCalls, [ + 23, + 43, + ]); + assert.strictEqual(res.tokenCount, 50); + assert.deepStrictEqual(res.messages, [{ + role: 'user', + content: 'hiaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\nsmaller', + }]); + }); + + test('grows multiple in render order and uses budget', async () => { + const sizingInCalls: string[] = []; + const res = await new PromptRenderer( + { modelMaxPromptTokens: 50 }, + class extends PromptElement { + render() { + return + { + let str = 'hi'; + while (await sizing.countTokens(str + 'a') < sizing.tokenBudget / 2) { + str += 'a'; + } + sizingInCalls.push(`a=${sizing.tokenBudget}`); + return str; + }} /> + { + let str = 'hi'; + while (await sizing.countTokens(str + 'b') < sizing.tokenBudget / 2) { + str += 'b'; + } + sizingInCalls.push(`b=${sizing.tokenBudget}`); + return str; + }} /> + smaller + ; + } + }, + {}, + tokenizer + ).render(); + + assert.deepStrictEqual(sizingInCalls, [ + 'b=23', + 'a=33', + 'b=26', + 'a=30', + ]); + assert.deepStrictEqual(res.messages, [{ + role: 'user', + content: 'hiaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\nhibbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb\nsmaller', + }]); + assert.strictEqual(res.tokenCount, 34); + }); + + test('stops growing early if over budget', async () => { + const sizingInCalls: string[] = []; + const res = await new PromptRenderer( + { modelMaxPromptTokens: 50 }, + class extends PromptElement { + render() { + return + { + sizingInCalls.push(`a=${sizing.tokenBudget}`); + return 'hi'; + }} /> + { + sizingInCalls.push(`b=${sizing.tokenBudget}`); + if (sizing.tokenBudget < 30) { + return 'hi'; + } + let str = 'hi'; + while (await sizing.countTokens(str + 'a') <= sizing.tokenBudget) { + str += 'a'; + } + return str; + }} /> + smaller + ; + } + }, + {}, + tokenizer + ).render(); + + assert.deepStrictEqual(sizingInCalls, [ + 'b=23', + 'a=43', + 'b=41', + ]); + assert.deepStrictEqual(res.messages, [{ + role: 'user', + content: 'hi\nhiaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa\nsmaller', + }]); + }); + + test('still prunes over budget', async () => { + const sizingInCalls: string[] = []; + const res = await new PromptRenderer( + { modelMaxPromptTokens: 50 }, + class extends PromptElement { + render() { + return + { + sizingInCalls.push(`a=${sizing.tokenBudget}`); + return 'hi'; + }} /> + { + sizingInCalls.push(`b=${sizing.tokenBudget}`); + if (sizing.tokenBudget < 30) { + return 'hi'; + } + return 'hi'.repeat(1000); + }} /> + smaller + ; + } + }, + {}, + tokenizer + ).render(); + + assert.deepStrictEqual(sizingInCalls, [ + 'b=23', + 'a=43', + 'b=41', + ]); + assert.deepStrictEqual(res.messages, [{ + role: 'user', + 
content: 'smaller', + }]); + }); + }); }); diff --git a/src/tracer/index.tsx b/src/tracer/index.tsx index 8bdabe5..d5c993e 100644 --- a/src/tracer/index.tsx +++ b/src/tracer/index.tsx @@ -114,7 +114,7 @@ const App = () => {

{activeTab === 'tokens' - ?

Token changes here will prune elements and re-render 'pure' ones, but the entire prompt is not being re-rendered

+ ?

Token changes here will prune elements and re-render Expandable ones, but the entire prompt is not being re-rendered

:

Changing the render epoch lets you see the order in which elements are rendered and how the token budget is allocated.

}
Used / tokens
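For orientation, here is a minimal consumer-side sketch of the new `Expandable` element, mirroring the README example and the test setup in this diff. It is illustrative only: the import specifiers, the `collectHistory` helper, and the endpoint/tokenizer wiring are assumptions rather than part of this change.

```tsx
import {
	BasePromptElementProps,
	Expandable,
	ITokenizer,
	PromptElement,
	PromptRenderer,
	PromptSizing,
	UserMessage,
} from '@vscode/prompt-tsx'; // assumed re-exports from the package entry point

// Hypothetical source of history entries, newest first.
declare function collectHistory(): string[];

class HistoryPrompt extends PromptElement<BasePromptElementProps> {
	render() {
		return (
			<UserMessage>
				Here is the recent history:
				{/* Called once with its normal budget, and again with any leftover budget after the first render. */}
				<Expandable
					value={async (sizing: PromptSizing) => {
						let text = '';
						for (const turn of collectHistory()) {
							const next = text + turn + '\n';
							// Stop appending once the next entry would exceed the budget given to this element.
							if ((await sizing.countTokens(next)) > sizing.tokenBudget) {
								break;
							}
							text = next;
						}
						return text;
					}}
				/>
			</UserMessage>
		);
	}
}

// Rendered the same way as the tests above: (endpoint, element ctor, props, tokenizer).
async function buildMessages(tokenizer: ITokenizer) {
	const renderer = new PromptRenderer({ modelMaxPromptTokens: 4096 }, HistoryPrompt, {}, tokenizer);
	const { messages, tokenCount } = await renderer.render();
	return { messages, tokenCount };
}
```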