From 6e450fbe767d184132a1456eced19d40e2d7db0e Mon Sep 17 00:00:00 2001 From: Connor Peet Date: Wed, 3 Jul 2024 12:41:13 -0700 Subject: [PATCH] fix: ensure text chunk order is retaind (#68) Refs https://github.com/microsoft/vscode-copilot/pull/6583 --- src/base/promptRenderer.ts | 39 +++++++++++++++++++-------------- src/base/test/renderer.test.tsx | 35 +++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 17 deletions(-) diff --git a/src/base/promptRenderer.ts b/src/base/promptRenderer.ts index 53d50d8..deb01ba 100644 --- a/src/base/promptRenderer.ts +++ b/src/base/promptRenderer.ts @@ -95,7 +95,7 @@ export class PromptRenderer

{ private async _processPromptPieces(sizing: PromptSizingContext, pieces: QueueItem, P>[], progress?: Progress, token?: CancellationToken) { // Collect all prompt elements in the next flex group to render, grouping // by the flex order in which they're rendered. - const promptElements = new Map, P>; promptElementInstance: PromptElement }[]>(); + const promptElements = new Map, P>; promptElementInstance: PromptElement }[]>(); for (const [i, element] of pieces.entries()) { // Set any jsx children as the props.children if (Array.isArray(element.children)) { @@ -119,7 +119,7 @@ export class PromptRenderer

{ promptElements.set(flexGroupValue, flexGroup); } - flexGroup.push({ order: i, element, promptElementInstance: promptElement }); + flexGroup.push({ element, promptElementInstance: promptElement }); } const flexGroups = [...promptElements.entries()].sort(([a], [b]) => b - a).map(([_, group]) => group); @@ -171,7 +171,7 @@ export class PromptRenderer

{ })); // Render - for (const [i, { element, promptElementInstance, order }] of promptElements.entries()) { + for (const [i, { element, promptElementInstance }] of promptElements.entries()) { const elementSizing = elementSizings[i]; const template = templates[i]; @@ -316,7 +316,7 @@ export class PromptRenderer

{ private _handlePromptChildren(element: QueueItem, P>, pieces: ProcessedPromptPiece[], sizing: PromptSizingContext, progress: Progress | undefined, token: CancellationToken | undefined) { if (element.ctor === TextChunk) { - this._handleExtrinsicTextChunkChildren(element.node.parent!, element.props, pieces); + this._handleExtrinsicTextChunkChildren(element.node.parent!, element.node, element.props, pieces); return; } @@ -339,12 +339,12 @@ export class PromptRenderer

{ return this._processPromptPieces(sizing, todo, progress, token); } - private _handleIntrinsic(node: PromptTreeElement, name: string, props: any, children: ProcessedPromptPiece[]): void { + private _handleIntrinsic(node: PromptTreeElement, name: string, props: any, children: ProcessedPromptPiece[], sortIndex?: number): void { switch (name) { case 'meta': return this._handleIntrinsicMeta(node, props, children); case 'br': - return this._handleIntrinsicLineBreak(node, props, children, props.priority); + return this._handleIntrinsicLineBreak(node, props, children, props.priority, sortIndex); case 'usedContext': return this._handleIntrinsicUsedContext(node, props, children); case 'references': @@ -366,11 +366,11 @@ export class PromptRenderer

{ this._meta.set(key, props.value); } - private _handleIntrinsicLineBreak(node: PromptTreeElement, props: JSX.IntrinsicElements['br'], children: ProcessedPromptPiece[], inheritedPriority?: number) { + private _handleIntrinsicLineBreak(node: PromptTreeElement, props: JSX.IntrinsicElements['br'], children: ProcessedPromptPiece[], inheritedPriority?: number, sortIndex?: number) { if (children.length > 0) { throw new Error(`
must not have children!`); } - node.appendLineBreak(true, inheritedPriority ?? Number.MAX_SAFE_INTEGER); + node.appendLineBreak(true, inheritedPriority ?? Number.MAX_SAFE_INTEGER, sortIndex); } private _handleIntrinsicUsedContext(node: PromptTreeElement, props: JSX.IntrinsicElements['usedContext'], children: ProcessedPromptPiece[]) { @@ -398,10 +398,12 @@ export class PromptRenderer

{ /** * @param node Parent of the + * @param textChunkNode The node. All children are in-order + * appended to the parent using the same sort index to ensure order is preserved. * @param props Props of the * @param children Rendered children of the */ - private _handleExtrinsicTextChunkChildren(node: PromptTreeElement, props: BasePromptElementProps, children: ProcessedPromptPiece[]) { + private _handleExtrinsicTextChunkChildren(node: PromptTreeElement, textChunkNode: PromptTreeElement, props: BasePromptElementProps, children: ProcessedPromptPiece[]) { const content: string[] = []; const references: PromptReference[] = []; @@ -422,13 +424,13 @@ export class PromptRenderer

{ // For TextChunks, references must be propagated through the PromptText element that is appended to the node references.push(...child.props.value); } else { - this._handleIntrinsic(node, child.name, child.props, flattenAndReduceArr(child.children)); + this._handleIntrinsic(node, child.name, child.props, flattenAndReduceArr(child.children), textChunkNode.childIndex); } } } - node.appendLineBreak(false); - node.appendStringChild(content.join(''), props?.priority ?? Number.MAX_SAFE_INTEGER, references); + node.appendLineBreak(false, undefined, textChunkNode.childIndex); + node.appendStringChild(content.join(''), props?.priority ?? Number.MAX_SAFE_INTEGER, references, textChunkNode.childIndex); } } @@ -572,12 +574,12 @@ class PromptTreeElement { return child; } - public appendStringChild(text: string, priority?: number, references?: PromptReference[]) { - this._children.push(new PromptText(this, text, priority, references)); + public appendStringChild(text: string, priority?: number, references?: PromptReference[], sortIndex = this._children.length) { + this._children.push(new PromptText(this, sortIndex, text, priority, references)); } - public appendLineBreak(explicit = true, priority?: number): void { - this._children.push(new PromptLineBreak(this, explicit, priority)); + public appendLineBreak(explicit = true, priority?: number, sortIndex = this._children.length): void { + this._children.push(new PromptLineBreak(this, sortIndex, explicit, priority)); } public materialize(): { result: MaterializedChatMessage[]; resultChunks: MaterializedChatMessageTextChunk[] } { @@ -588,6 +590,7 @@ class PromptTreeElement { } private _materialize(result: MaterializedChatMessage[], resultChunks: MaterializedChatMessageTextChunk[]): void { + this._children.sort((a, b) => a.childIndex - b.childIndex); if (this._obj instanceof BaseChatMessage) { if (!this._obj.props.role) { throw new Error(`Invalid ChatMessage!`); @@ -637,7 +640,7 @@ class PromptTreeElement { } if (this._obj?.insertLineBreakBefore) { // Add an implicit
before the element - result.push(new PromptLineBreak(this, false)); + result.push(new PromptLineBreak(this, 0, false)); } for (const child of this._children) { child.collectLeafs(result); @@ -736,6 +739,7 @@ class PromptText { constructor( public readonly parent: PromptTreeElement, + public readonly childIndex: number, public readonly text: string, public readonly priority?: number, public readonly references?: PromptReference[] @@ -753,6 +757,7 @@ class PromptLineBreak { constructor( public readonly parent: PromptTreeElement, + public readonly childIndex: number, public readonly isExplicit: boolean, public readonly priority?: number ) { } diff --git a/src/base/test/renderer.test.tsx b/src/base/test/renderer.test.tsx index a9747b0..fc18a3b 100644 --- a/src/base/test/renderer.test.tsx +++ b/src/base/test/renderer.test.tsx @@ -169,6 +169,41 @@ suite('PromptRenderer', () => { ]); }); + test('maintains element order', async () => { + class Prompt2 extends PromptElement<{ content: string } & BasePromptElementProps> { + render() { + return ( + {this.props.content} + ); + } + } + + class Prompt1 extends PromptElement { + render() { + return ( + <> + + a + + c + d + e + f + g + + i + + + ); + } + } + + const inst = new PromptRenderer(fakeEndpoint, Prompt1, {}, tokenizer); + const res = await inst.render(undefined, undefined); + assert.deepStrictEqual(res.messages.length, 1); + assert.deepStrictEqual(res.messages[0].content.replace(/\n/g, ''), 'abcdefghi'); + }); + suite('truncates tokens exceeding token budget', async () => { class Prompt1 extends PromptElement { render(_: void, sizing: PromptSizing) {