From af5ac44261f24e00b980945dc47b75a249567b7b Mon Sep 17 00:00:00 2001 From: Connor Peet Date: Wed, 9 Oct 2024 13:38:41 -0700 Subject: [PATCH] add better handling of whitespace, Chunk utility --- README.md | 8 +++ src/base/index.ts | 2 +- src/base/materialized.ts | 72 ++++++++++++++++------- src/base/promptElements.tsx | 26 ++++++++- src/base/promptRenderer.ts | 19 ++++-- src/base/test/materialized.test.ts | 14 ++--- src/base/test/renderer.test.tsx | 92 +++++++++++++++++++++++++++++- src/base/types.ts | 2 +- 8 files changed, 195 insertions(+), 40 deletions(-) diff --git a/README.md b/README.md index e925b0d..aec3d49 100644 --- a/README.md +++ b/README.md @@ -173,6 +173,14 @@ In this case, a very long `userQuery` would get pruned from the output first if ...would be pruned in the order `B->A->D->C`. If two sibling elements share the same priority, the renderer looks ahead at their direct children and picks whichever one has a child with the lowest priority: if the `SystemMessage` and `UserMessage` in the above example did not declare priorities, the pruning order would be `B->D->A->C`. +Continuous text strings and elements can both be pruned from the tree. If you have a set of elements that you want to either be include all the time or none of the time, you can use the simple `Chunk` utility element: + +```html + + The file I'm editing is: + +``` + ### Flex Behavior Wholesale pruning is not always already. Instead, we'd prefer to include as much of the query as possible. To do this, we can use the `flexGrow` property, which allows an element to use the remainder of its parent's token budget when it's rendered. diff --git a/src/base/index.ts b/src/base/index.ts index d1336f4..6351417 100644 --- a/src/base/index.ts +++ b/src/base/index.ts @@ -20,7 +20,7 @@ export * from './tracer'; export * from './tsx-globals'; export * from './types'; -export { AssistantMessage, FunctionMessage, PrioritizedList, PrioritizedListProps, SystemMessage, TextChunk, TextChunkProps, UserMessage } from './promptElements'; +export { AssistantMessage, FunctionMessage, PrioritizedList, PrioritizedListProps, SystemMessage, TextChunk, TextChunkProps, UserMessage, LegacyPrioritization, Chunk } from './promptElements'; export { PromptElement } from './promptElement'; export { MetadataMap, PromptRenderer, QueueItem, RenderPromptResult } from './promptRenderer'; diff --git a/src/base/materialized.ts b/src/base/materialized.ts index 98afc6a..7cfe4ec 100644 --- a/src/base/materialized.ts +++ b/src/base/materialized.ts @@ -23,15 +23,26 @@ export interface IMaterializedNode { export type MaterializedNode = MaterializedContainer | MaterializedChatMessage | MaterializedChatMessageTextChunk; +export const enum ContainerFlags { + /** It's a {@link LegacyPrioritization} instance */ + IsLegacyPrioritization = 1 << 0, + /** It's a {@link Chunk} instance */ + IsChunk = 1 << 1, +} + export class MaterializedContainer implements IMaterializedNode { constructor( public readonly priority: number, public readonly children: MaterializedNode[], public readonly metadata: PromptMetadata[], - public readonly isLegacyPrioritization = false, + public readonly flags: number, ) { } + public has(flag: ContainerFlags) { + return !!(this.flags & flag); + } + /** @inheritdoc */ async tokenCount(tokenizer: ITokenizer): Promise { let total = 0; @@ -67,16 +78,22 @@ export class MaterializedContainer implements IMaterializedNode { /** * Gets the chat messages the container holds. */ - toChatMessages(): ChatMessage[] { - return this.children.flatMap(child => { + *toChatMessages(): Generator { + for (const child of this.children) { assertContainerOrChatMessage(child); - return child instanceof MaterializedContainer ? child.toChatMessages() : [child.toChatMessage()]; - }) + if (child instanceof MaterializedContainer) { + yield* child.toChatMessages(); + } else if (!child.isEmpty) { + // note: empty messages are already removed during pruning, but the + // consumer might themselves have given us empty messages that we should omit. + yield child.toChatMessage(); + } + } } /** Removes the node in the tree with the lowest priority. */ removeLowestPriorityChild(): void { - if (this.isLegacyPrioritization) { + if (this.has(ContainerFlags.IsLegacyPrioritization)) { removeLowestPriorityLegacy(this); } else { removeLowestPriorityChild(this.children); @@ -84,13 +101,19 @@ export class MaterializedContainer implements IMaterializedNode { } } +export const enum LineBreakBefore { + None, + Always, + IfNotTextSibling, +} + /** A chunk of text in a {@link MaterializedChatMessage} */ export class MaterializedChatMessageTextChunk { constructor( public readonly text: string, public readonly priority: number, public readonly metadata: PromptMetadata[] = [], - public readonly lineBreakBefore: boolean, + public readonly lineBreakBefore: LineBreakBefore, ) { } public upperBoundTokenCount(tokenizer: ITokenizer) { @@ -98,7 +121,7 @@ export class MaterializedChatMessageTextChunk { } private readonly _upperBound = once(async (tokenizer: ITokenizer) => { - return await tokenizer.tokenLength(this.text) + (this.lineBreakBefore ? 1 : 0); + return await tokenizer.tokenLength(this.text) + (this.lineBreakBefore !== LineBreakBefore.None ? 1 : 0); }); } @@ -130,6 +153,11 @@ export class MaterializedChatMessage implements IMaterializedNode { return this._text() } + /** Gets whether the message is empty */ + public get isEmpty() { + return !/\S/.test(this.text); + } + /** Remove the lowest priority chunk among this message's children. */ removeLowestPriorityChild() { removeLowestPriorityChild(this.children); @@ -161,14 +189,17 @@ export class MaterializedChatMessage implements IMaterializedNode { private readonly _text = once(() => { let result = ''; - for (const chunk of textChunks(this)) { - if (chunk.lineBreakBefore && result.length && !result.endsWith('\n')) { - result += '\n'; + for (const { text, isTextSibling } of textChunks(this)) { + if (text.lineBreakBefore === LineBreakBefore.Always || (text.lineBreakBefore === LineBreakBefore.IfNotTextSibling && !isTextSibling)) { + if (result.length && !result.endsWith('\n')) { + result += '\n'; + } } - result += chunk.text; + + result += text.text; } - return result; + return result.trim(); }); public toChatMessage(): ChatMessage { @@ -221,17 +252,14 @@ function assertContainerOrChatMessage(v: MaterializedNode): asserts v is Materia } -function* textChunks(node: MaterializedNode): Generator { - if (node instanceof MaterializedChatMessageTextChunk) { - yield node; - return; - } - +function* textChunks(node: MaterializedContainer | MaterializedChatMessage, isTextSibling = false): Generator<{ text: MaterializedChatMessageTextChunk; isTextSibling: boolean }> { for (const child of node.children) { if (child instanceof MaterializedChatMessageTextChunk) { - yield child; + yield { text: child, isTextSibling }; + isTextSibling = true; } else { - yield* textChunks(child); + yield* textChunks(child, isTextSibling); + isTextSibling = false; } } } @@ -309,7 +337,7 @@ function removeLowestPriorityChild(children: MaterializedNode[]) { } const lowest = children[lowestIndex]; - if (lowest instanceof MaterializedChatMessageTextChunk) { + if (lowest instanceof MaterializedChatMessageTextChunk || (lowest instanceof MaterializedContainer && lowest.has(ContainerFlags.IsChunk))) { children.splice(lowestIndex, 1); } else { lowest.removeLowestPriorityChild(); diff --git a/src/base/promptElements.tsx b/src/base/promptElements.tsx index ffd21e1..759e966 100644 --- a/src/base/promptElements.tsx +++ b/src/base/promptElements.tsx @@ -4,7 +4,6 @@ import type { CancellationToken } from 'vscode'; import { contentType } from '.'; -import * as JSONT from './jsonTypes'; import { ChatRole } from './openai'; import { PromptElement } from './promptElement'; import { BasePromptElementProps, PromptPiece, PromptSizing } from './types'; @@ -232,12 +231,22 @@ export class PrioritizedList extends PromptElement { return ( <> {children.map((child, i) => { - child.props ??= {}; - child.props.priority = this.props.descending + if (!child) { + return; + } + + const priority = this.props.descending ? // First element in array of children has highest priority this.props.priority - i : // Last element in array of children has highest priority this.props.priority - children.length + i; + + if (typeof child !== 'object') { + return {child}; + } + + child.props ??= {}; + child.props.priority = priority; return child; })} @@ -283,3 +292,14 @@ export class LegacyPrioritization extends PromptElement { return <>{this.props.children}; } } + +/** + * Marker element that ensures all of its children are either included, or + * not included. This is similar to the `` element, but it is more + * basic and can contain extrinsic children. + */ +export class Chunk extends PromptElement { + render() { + return <>{this.props.children}; + } +} diff --git a/src/base/promptRenderer.ts b/src/base/promptRenderer.ts index 2f35cee..97e55ca 100644 --- a/src/base/promptRenderer.ts +++ b/src/base/promptRenderer.ts @@ -5,10 +5,10 @@ import type { CancellationToken, Progress } from "vscode"; import * as JSONT from './jsonTypes'; import { PromptNodeType } from './jsonTypes'; -import { MaterializedChatMessage, MaterializedChatMessageTextChunk, MaterializedContainer } from './materialized'; +import { ContainerFlags, LineBreakBefore, MaterializedChatMessage, MaterializedChatMessageTextChunk, MaterializedContainer } from './materialized'; import { ChatMessage } from "./openai"; import { PromptElement } from "./promptElement"; -import { AssistantMessage, BaseChatMessage, ChatMessagePromptElement, LegacyPrioritization, TextChunk, ToolMessage, isChatMessagePromptElement } from "./promptElements"; +import { AssistantMessage, BaseChatMessage, ChatMessagePromptElement, Chunk, LegacyPrioritization, TextChunk, ToolMessage, isChatMessagePromptElement } from "./promptElements"; import { PromptMetadata, PromptReference } from "./results"; import { ITokenizer } from "./tokenizer/tokenizer"; import { ITracer } from './tracer'; @@ -253,7 +253,7 @@ export class PromptRenderer

{ } // Then finalize the chat messages - const messageResult = container.toChatMessages(); + const messageResult = [...container.toChatMessages()]; const tokenCount = await container.tokenCount(this._tokenizer); const remainingMetadata = [...container.allMetadata()]; @@ -646,11 +646,15 @@ class PromptTreeElement { ); return parent; } else { + let flags = 0; + if (this._obj instanceof LegacyPrioritization) flags |= ContainerFlags.IsLegacyPrioritization; + if (this._obj instanceof Chunk) flags |= ContainerFlags.IsChunk; + return new MaterializedContainer( this._obj?.props.priority || 0, this._children.map(child => child.materialize()), this._metadata, - this._obj instanceof LegacyPrioritization, + flags, ); } } @@ -682,7 +686,12 @@ class PromptText { } public materialize() { - return new MaterializedChatMessageTextChunk(this.text, this.priority ?? Number.MAX_SAFE_INTEGER, this.metadata || [], this.lineBreakBefore || this.childIndex === 0); + const lineBreak = this.lineBreakBefore + ? LineBreakBefore.Always + : this.childIndex === 0 + ? LineBreakBefore.IfNotTextSibling + : LineBreakBefore.None; + return new MaterializedChatMessageTextChunk(this.text, this.priority ?? Number.MAX_SAFE_INTEGER, this.metadata || [], lineBreak); } public toJSON(): JSONT.TextJSON { diff --git a/src/base/test/materialized.test.ts b/src/base/test/materialized.test.ts index 50f3c9d..0aaf53b 100644 --- a/src/base/test/materialized.test.ts +++ b/src/base/test/materialized.test.ts @@ -3,7 +3,7 @@ *--------------------------------------------------------------------------------------------*/ import * as assert from 'assert'; -import { MaterializedChatMessage, MaterializedChatMessageTextChunk, MaterializedContainer } from '../materialized'; +import { LineBreakBefore, MaterializedChatMessage, MaterializedChatMessageTextChunk, MaterializedContainer } from '../materialized'; import { ChatRole } from '../openai'; import { ITokenizer } from '../tokenizer/tokenizer'; class MockTokenizer implements ITokenizer { @@ -17,10 +17,10 @@ class MockTokenizer implements ITokenizer { suite('Materialized', () => { test('should calculate token count correctly', async () => { const tokenizer = new MockTokenizer(); - const child1 = new MaterializedChatMessageTextChunk('Hello', 1, [], false); - const child2 = new MaterializedChatMessageTextChunk('World', 1, [], false); + const child1 = new MaterializedChatMessageTextChunk('Hello', 1, [], LineBreakBefore.None); + const child2 = new MaterializedChatMessageTextChunk('World', 1, [], LineBreakBefore.None); const message = new MaterializedChatMessage(ChatRole.User, 'user', undefined, undefined, 1, 0, [], [child1, child2]); - const container = new MaterializedContainer(1, [message], []); + const container = new MaterializedContainer(1, [message], [], 0); assert.deepStrictEqual(await container.tokenCount(tokenizer), 13); container.removeLowestPriorityChild(); @@ -29,10 +29,10 @@ suite('Materialized', () => { test('should calculate lower bound token count correctly', async () => { const tokenizer = new MockTokenizer(); - const child1 = new MaterializedChatMessageTextChunk('Hello', 1, [], false); - const child2 = new MaterializedChatMessageTextChunk('World', 1, [], false); + const child1 = new MaterializedChatMessageTextChunk('Hello', 1, [], LineBreakBefore.None); + const child2 = new MaterializedChatMessageTextChunk('World', 1, [], LineBreakBefore.None); const message = new MaterializedChatMessage(ChatRole.User, 'user', undefined, undefined, 1, 0, [], [child1, child2]); - const container = new MaterializedContainer(1, [message], []); + const container = new MaterializedContainer(1, [message], [], 0); assert.deepStrictEqual(await container.upperBoundTokenCount(tokenizer), 13); container.removeLowestPriorityChild(); diff --git a/src/base/test/renderer.test.tsx b/src/base/test/renderer.test.tsx index 4b23ea7..1383c07 100644 --- a/src/base/test/renderer.test.tsx +++ b/src/base/test/renderer.test.tsx @@ -8,6 +8,7 @@ import { BaseTokensPerCompletion, ChatMessage, ChatRole } from '../openai'; import { PromptElement } from '../promptElement'; import { AssistantMessage, + Chunk, LegacyPrioritization, PrioritizedList, SystemMessage, @@ -305,6 +306,18 @@ suite('PromptRenderer', () => { , ['a', 'b', 'c']); }); + test('chunks together', async () => { + await assertPruningOrder(<> + + + a + b + + c + + , ['a', 'c']); // 'b' should not get individually removed and cause a change + }); + test('does not scope priorities in fragments', async () => { await assertPruningOrder(<> @@ -1158,6 +1171,83 @@ suite('PromptRenderer', () => { ); }); + test('does not emit empty messages', async () => { + const inst = new PromptRenderer( + fakeEndpoint, + class extends PromptElement { + render() { + return <> + + Hello! + ; + } + }, + {}, + new FakeTokenizer() + ); + const res = await inst.render(undefined, undefined); + assert.deepStrictEqual(res.messages, [ + { + role: 'user', + content: 'Hello!', + } + ]); + }); + + test('does not add a line break in an embedded message', async () => { + class Inner extends PromptElement { + render() { + return <>world; + } + } + const inst = new PromptRenderer( + fakeEndpoint, + class extends PromptElement { + render() { + return <> + Hello ! + ; + } + }, + {}, + new FakeTokenizer() + ); + const res = await inst.render(undefined, undefined); + assert.deepStrictEqual(res.messages, [ + { + role: 'user', + content: 'Hello world!', + } + ]); + }); + + test('adds line break between two nested embedded messages', async () => { + class Inner extends PromptElement { + render() { + return <>world; + } + } + const inst = new PromptRenderer( + fakeEndpoint, + class extends PromptElement { + render() { + return <> + + ; + } + }, + {}, + new FakeTokenizer() + ); + const res = await inst.render(undefined, undefined); + assert.deepStrictEqual(res.messages, [ + { + role: 'user', + content: 'world\nworld', + } + ]); + }); + test('none-grow, greedy-grow, grow elements', async () => { await flexTest(<> @@ -1477,7 +1567,7 @@ suite('PromptRenderer', () => { test('local is pruned when chunk is pruned', async () => { const res = await new PromptRenderer( - { modelMaxPromptTokens: 5 } as any, + { modelMaxPromptTokens: 1 } as any, class extends PromptElement { render() { return diff --git a/src/base/types.ts b/src/base/types.ts index eecba7a..81fa68a 100644 --- a/src/base/types.ts +++ b/src/base/types.ts @@ -83,7 +83,7 @@ export interface PromptElementCtor

{ } export interface RuntimePromptElementProps { - children?: PromptPiece[]; + children?: PromptPieceChild[]; } export type PromptElementProps = T & BasePromptElementProps & RuntimePromptElementProps;