fix: ensure text chunk order is retained (#68)
connor4312 authored Jul 3, 2024
1 parent d25770f commit 6e450fb
Showing 2 changed files with 57 additions and 17 deletions.
39 changes: 22 additions & 17 deletions src/base/promptRenderer.ts
@@ -95,7 +95,7 @@ export class PromptRenderer<P extends BasePromptElementProps> {
private async _processPromptPieces(sizing: PromptSizingContext, pieces: QueueItem<PromptElementCtor<P, any>, P>[], progress?: Progress<ChatResponsePart>, token?: CancellationToken) {
// Collect all prompt elements in the next flex group to render, grouping
// by the flex order in which they're rendered.
- const promptElements = new Map<number, { order: number; element: QueueItem<PromptElementCtor<P, any>, P>; promptElementInstance: PromptElement<any, any> }[]>();
+ const promptElements = new Map<number, { element: QueueItem<PromptElementCtor<P, any>, P>; promptElementInstance: PromptElement<any, any> }[]>();
for (const [i, element] of pieces.entries()) {
// Set any jsx children as the props.children
if (Array.isArray(element.children)) {
@@ -119,7 +119,7 @@ export class PromptRenderer<P extends BasePromptElementProps> {
promptElements.set(flexGroupValue, flexGroup);
}

- flexGroup.push({ order: i, element, promptElementInstance: promptElement });
+ flexGroup.push({ element, promptElementInstance: promptElement });
}

const flexGroups = [...promptElements.entries()].sort(([a], [b]) => b - a).map(([_, group]) => group);
@@ -171,7 +171,7 @@ export class PromptRenderer<P extends BasePromptElementProps> {
}));

// Render
- for (const [i, { element, promptElementInstance, order }] of promptElements.entries()) {
+ for (const [i, { element, promptElementInstance }] of promptElements.entries()) {
const elementSizing = elementSizings[i];
const template = templates[i];

@@ -316,7 +316,7 @@ export class PromptRenderer<P extends BasePromptElementProps> {

private _handlePromptChildren(element: QueueItem<PromptElementCtor<any, any>, P>, pieces: ProcessedPromptPiece[], sizing: PromptSizingContext, progress: Progress<ChatResponsePart> | undefined, token: CancellationToken | undefined) {
if (element.ctor === TextChunk) {
- this._handleExtrinsicTextChunkChildren(element.node.parent!, element.props, pieces);
+ this._handleExtrinsicTextChunkChildren(element.node.parent!, element.node, element.props, pieces);
return;
}

@@ -339,12 +339,12 @@ export class PromptRenderer<P extends BasePromptElementProps> {
return this._processPromptPieces(sizing, todo, progress, token);
}

- private _handleIntrinsic(node: PromptTreeElement, name: string, props: any, children: ProcessedPromptPiece[]): void {
+ private _handleIntrinsic(node: PromptTreeElement, name: string, props: any, children: ProcessedPromptPiece[], sortIndex?: number): void {
switch (name) {
case 'meta':
return this._handleIntrinsicMeta(node, props, children);
case 'br':
- return this._handleIntrinsicLineBreak(node, props, children, props.priority);
+ return this._handleIntrinsicLineBreak(node, props, children, props.priority, sortIndex);
case 'usedContext':
return this._handleIntrinsicUsedContext(node, props, children);
case 'references':
@@ -366,11 +366,11 @@ export class PromptRenderer<P extends BasePromptElementProps> {
this._meta.set(key, props.value);
}

- private _handleIntrinsicLineBreak(node: PromptTreeElement, props: JSX.IntrinsicElements['br'], children: ProcessedPromptPiece[], inheritedPriority?: number) {
+ private _handleIntrinsicLineBreak(node: PromptTreeElement, props: JSX.IntrinsicElements['br'], children: ProcessedPromptPiece[], inheritedPriority?: number, sortIndex?: number) {
if (children.length > 0) {
throw new Error(`<br /> must not have children!`);
}
- node.appendLineBreak(true, inheritedPriority ?? Number.MAX_SAFE_INTEGER);
+ node.appendLineBreak(true, inheritedPriority ?? Number.MAX_SAFE_INTEGER, sortIndex);
}

private _handleIntrinsicUsedContext(node: PromptTreeElement, props: JSX.IntrinsicElements['usedContext'], children: ProcessedPromptPiece[]) {
@@ -398,10 +398,12 @@ export class PromptRenderer<P extends BasePromptElementProps> {

/**
* @param node Parent of the <TextChunk />
+ * @param textChunkNode The <TextChunk /> node. All children are in-order
+ * appended to the parent using the same sort index to ensure order is preserved.
* @param props Props of the <TextChunk />
* @param children Rendered children of the <TextChunk />
*/
- private _handleExtrinsicTextChunkChildren(node: PromptTreeElement, props: BasePromptElementProps, children: ProcessedPromptPiece[]) {
+ private _handleExtrinsicTextChunkChildren(node: PromptTreeElement, textChunkNode: PromptTreeElement, props: BasePromptElementProps, children: ProcessedPromptPiece[]) {
const content: string[] = [];
const references: PromptReference[] = [];

@@ ... @@
// For TextChunks, references must be propagated through the PromptText element that is appended to the node
references.push(...child.props.value);
} else {
- this._handleIntrinsic(node, child.name, child.props, flattenAndReduceArr(child.children));
+ this._handleIntrinsic(node, child.name, child.props, flattenAndReduceArr(child.children), textChunkNode.childIndex);
}
}
}

- node.appendLineBreak(false);
- node.appendStringChild(content.join(''), props?.priority ?? Number.MAX_SAFE_INTEGER, references);
+ node.appendLineBreak(false, undefined, textChunkNode.childIndex);
+ node.appendStringChild(content.join(''), props?.priority ?? Number.MAX_SAFE_INTEGER, references, textChunkNode.childIndex);
}
}

@@ -572,12 +574,12 @@ class PromptTreeElement {
return child;
}

- public appendStringChild(text: string, priority?: number, references?: PromptReference[]) {
- this._children.push(new PromptText(this, text, priority, references));
+ public appendStringChild(text: string, priority?: number, references?: PromptReference[], sortIndex = this._children.length) {
+ this._children.push(new PromptText(this, sortIndex, text, priority, references));
}

- public appendLineBreak(explicit = true, priority?: number): void {
- this._children.push(new PromptLineBreak(this, explicit, priority));
+ public appendLineBreak(explicit = true, priority?: number, sortIndex = this._children.length): void {
+ this._children.push(new PromptLineBreak(this, sortIndex, explicit, priority));
}

public materialize(): { result: MaterializedChatMessage[]; resultChunks: MaterializedChatMessageTextChunk[] } {
@@ -588,6 +590,7 @@ class PromptTreeElement {
}

private _materialize(result: MaterializedChatMessage[], resultChunks: MaterializedChatMessageTextChunk[]): void {
+ this._children.sort((a, b) => a.childIndex - b.childIndex);
if (this._obj instanceof BaseChatMessage) {
if (!this._obj.props.role) {
throw new Error(`Invalid ChatMessage!`);
@@ -637,7 +640,7 @@ class PromptTreeElement {
}
if (this._obj?.insertLineBreakBefore) {
// Add an implicit <br/> before the element
- result.push(new PromptLineBreak(this, false));
+ result.push(new PromptLineBreak(this, 0, false));
}
for (const child of this._children) {
child.collectLeafs(result);
@@ -736,6 +739,7 @@ class PromptText {

constructor(
public readonly parent: PromptTreeElement,
+ public readonly childIndex: number,
public readonly text: string,
public readonly priority?: number,
public readonly references?: PromptReference[]
@@ -753,6 +757,7 @@ class PromptLineBreak {

constructor(
public readonly parent: PromptTreeElement,
+ public readonly childIndex: number,
public readonly isExplicit: boolean,
public readonly priority?: number
) { }
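In short, the renderer change above records a sort index for every child appended to a PromptTreeElement: appendStringChild and appendLineBreak default it to the current child count, the TextChunk handling passes the chunk's childIndex instead, and _materialize sorts children by that index before emitting. A minimal sketch of the scheme, using simplified hypothetical types (Node, Child, append, materialize) rather than the library's real classes:

// Hypothetical, simplified model of the sortIndex bookkeeping above
// (not the library's real classes or API).
interface Child {
  sortIndex: number;
  text: string;
}

class Node {
  private readonly children: Child[] = [];

  // Plain appends keep their insertion position, mirroring the
  // `sortIndex = this._children.length` default in the diff.
  append(text: string, sortIndex: number = this.children.length): void {
    this.children.push({ sortIndex, text });
  }

  // Mirrors _materialize: order children by the recorded index before emitting.
  materialize(): string {
    return [...this.children]
      .sort((a, b) => a.sortIndex - b.sortIndex)
      .map(c => c.text)
      .join('');
  }
}

// A chunk at source position 1 that is rendered last (for example because its
// flexGrow puts it in a later flex pass) still materializes in source order:
const node = new Node();
node.append('a');    // sibling at position 0, appended first
node.append('c', 2); // sibling at position 2, explicit index for the demo
node.append('b', 1); // late-rendered chunk reuses its childIndex of 1
console.log(node.materialize()); // => 'abc'

The new test in src/base/test/renderer.test.tsx below exercises the same scenario through flexGrow, asserting that the rendered message reads 'abcdefghi' even though the chunks render in different passes.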
35 changes: 35 additions & 0 deletions src/base/test/renderer.test.tsx
@@ -169,6 +169,41 @@ suite('PromptRenderer', () => {
]);
});

+ test('maintains element order', async () => {
+ class Prompt2 extends PromptElement<{ content: string } & BasePromptElementProps> {
+ render() {
+ return (
+ <TextChunk>{this.props.content}</TextChunk>
+ );
+ }
+ }
+
+ class Prompt1 extends PromptElement {
+ render() {
+ return (
+ <>
+ <SystemMessage>
+ a
+ <Prompt2 content='b' />
+ c
+ <TextChunk>d</TextChunk>
+ e
+ <TextChunk flexGrow={2}>f</TextChunk>
+ g
+ <Prompt2 content='h' flexGrow={1} />
+ i
+ </SystemMessage>
+ </>
+ );
+ }
+ }
+
+ const inst = new PromptRenderer(fakeEndpoint, Prompt1, {}, tokenizer);
+ const res = await inst.render(undefined, undefined);
+ assert.deepStrictEqual(res.messages.length, 1);
+ assert.deepStrictEqual(res.messages[0].content.replace(/\n/g, ''), 'abcdefghi');
+ });
+
suite('truncates tokens exceeding token budget', async () => {
class Prompt1 extends PromptElement {
render(_: void, sizing: PromptSizing) {