Skip to content

Commit

Permalink
chore: add legacy prioritization for migration
Browse files Browse the repository at this point in the history
  • Loading branch information
connor4312 committed Oct 9, 2024
1 parent bd18e35 commit 09caf7e
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 7 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@vscode/prompt-tsx",
"version": "0.3.0-alpha",
"version": "0.3.0-alpha.1",
"description": "Declare LLM prompts with TSX",
"main": "./dist/base/index.js",
"types": "./dist/base/index.d.ts",
Expand Down
64 changes: 59 additions & 5 deletions src/base/materialized.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ export class MaterializedContainer implements IMaterializedNode {
public readonly priority: number,
public readonly children: MaterializedNode[],
public readonly metadata: PromptMetadata[],
public readonly isLegacyPrioritization = false,
) { }

/** @inheritdoc */
Expand Down Expand Up @@ -75,7 +76,11 @@ export class MaterializedContainer implements IMaterializedNode {

	/** Removes the node in the tree with the lowest priority. */
	removeLowestPriorityChild(): void {
		if (this.isLegacyPrioritization) {
			// Legacy (0.2.x) mode: prune the single lowest-priority text chunk
			// anywhere in this subtree, ignoring container nesting levels.
			removeLowestPriorityLegacy(this);
		} else {
			// Default mode: delegate to the level-aware removal over direct children.
			removeLowestPriorityChild(this.children);
		}
	}
}

Expand All @@ -92,8 +97,8 @@ export class MaterializedChatMessageTextChunk {
return this._upperBound(tokenizer);
}

private readonly _upperBound = once((tokenizer: ITokenizer) => {
return tokenizer.tokenLength(this.text);
private readonly _upperBound = once(async (tokenizer: ITokenizer) => {
return await tokenizer.tokenLength(this.text) + (this.lineBreakBefore ? 1 : 0);
});
}

Expand Down Expand Up @@ -131,7 +136,7 @@ export class MaterializedChatMessage implements IMaterializedNode {
this.onChunksChange();
}

private onChunksChange() {
onChunksChange() {
this._tokenCount.clear();
this._upperBound.clear();
this._text.clear();
Expand All @@ -145,7 +150,7 @@ export class MaterializedChatMessage implements IMaterializedNode {
let total = await this._baseMessageTokenCount(tokenizer)
await Promise.all(this.children.map(async (chunk) => {
const amt = await chunk.upperBoundTokenCount(tokenizer);
total += amt + (chunk instanceof MaterializedChatMessageTextChunk && chunk.lineBreakBefore ? 1 : 0);
total += amt;
}));
return total;
});
Expand Down Expand Up @@ -231,6 +236,55 @@ function* textChunks(node: MaterializedNode): Generator<MaterializedChatMessageT
}
}

/**
 * Legacy (0.2.x-era) pruning: finds and removes the text chunk with the
 * lowest priority anywhere in the tree — a single global comparison that
 * ignores container nesting — then removes any ancestors left empty by the
 * removal and invalidates cached token counts along the way.
 *
 * Throws if the tree contains no text chunk at all.
 */
function removeLowestPriorityLegacy(root: MaterializedNode) {
	// Lowest-priority chunk found so far, together with the root-to-parent
	// chain of containers/messages leading to it (needed for cleanup below).
	let lowest: undefined | {
		chain: (MaterializedContainer | MaterializedChatMessage)[],
		node: MaterializedChatMessageTextChunk;
	};

	// Depth-first scan of the whole tree. The strict `<` comparison means the
	// earliest chunk in document order wins among equal priorities.
	function findLowestInTree(node: MaterializedNode, chain: (MaterializedContainer | MaterializedChatMessage)[]) {
		if (node instanceof MaterializedChatMessageTextChunk) {
			if (!lowest || node.priority < lowest.node.priority) {
				// slice() snapshots the chain, since it's mutated as we recurse.
				lowest = { chain: chain.slice(), node };
			}
		} else {
			chain.push(node);
			for (const child of node.children) {
				findLowestInTree(child, chain);
			}
			chain.pop();
		}
	}

	findLowestInTree(root, []);

	if (!lowest) {
		throw new Error('No lowest priority node found');
	}

	// Remove the chunk, then walk up the chain removing each ancestor that the
	// removal left empty; stop at the first ancestor that still has children.
	// Messages get their cached token counts invalidated as we go.
	let needle: MaterializedNode = lowest.node;
	let i = lowest.chain.length - 1;
	for (; i >= 0; i--) {
		const node = lowest.chain[i];
		node.children.splice(node.children.indexOf(needle), 1);
		if (node instanceof MaterializedChatMessage) {
			node.onChunksChange();
		}
		if (node.children.length > 0) {
			break;
		}

		needle = node;
	}

	// Invalidate cached state on the remaining ancestors up to the root.
	// NOTE(review): the node we broke on above is visited by both loops, so
	// its caches are cleared twice — presumably harmless since onChunksChange
	// only clears caches, but worth confirming.
	for (; i >= 0; i--) {
		const node = lowest.chain[i];
		if (node instanceof MaterializedChatMessage) {
			node.onChunksChange();
		}
	}
}

function removeLowestPriorityChild(children: MaterializedNode[]) {
if (!children.length) {
return;
Expand Down
13 changes: 13 additions & 0 deletions src/base/promptElements.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -270,3 +270,16 @@ export class ToolResult extends PromptElement<IToolResultProps> {
}
}
}

/**
 * Marker element that uses the legacy global prioritization algorithm (0.2.x
 * of this library) for pruning child elements. This will be removed in
 * the future.
 *
 * @deprecated
 */
export class LegacyPrioritization extends PromptElement {
	/** Pass-through render: children are emitted unchanged inside a fragment. */
	render() {
		return <>{this.props.children}</>;
	}
}
3 changes: 2 additions & 1 deletion src/base/promptRenderer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { PromptNodeType } from './jsonTypes';
import { MaterializedChatMessage, MaterializedChatMessageTextChunk, MaterializedContainer } from './materialized';
import { ChatMessage } from "./openai";
import { PromptElement } from "./promptElement";
import { AssistantMessage, BaseChatMessage, ChatMessagePromptElement, TextChunk, ToolMessage, isChatMessagePromptElement } from "./promptElements";
import { AssistantMessage, BaseChatMessage, ChatMessagePromptElement, LegacyPrioritization, TextChunk, ToolMessage, isChatMessagePromptElement } from "./promptElements";
import { PromptMetadata, PromptReference } from "./results";
import { ITokenizer } from "./tokenizer/tokenizer";
import { ITracer } from './tracer';
Expand Down Expand Up @@ -650,6 +650,7 @@ class PromptTreeElement {
this._obj?.props.priority || 0,
this._children.map(child => child.materialize()),
this._metadata,
this._obj instanceof LegacyPrioritization,
);
}
}
Expand Down
29 changes: 29 additions & 0 deletions src/base/test/renderer.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { BaseTokensPerCompletion, ChatMessage, ChatRole } from '../openai';
import { PromptElement } from '../promptElement';
import {
AssistantMessage,
LegacyPrioritization,
PrioritizedList,
SystemMessage,
TextChunk,
Expand Down Expand Up @@ -407,6 +408,34 @@ suite('PromptRenderer', () => {
</SystemMessage>
</>, ['a', 'b', 'c', 'd']);
});

	// Legacy mode prunes the globally lowest-priority text chunk across the
	// whole tree, ignoring the priorities of the Wrap/message elements that
	// contain it — hence the expected removal order is a(1), c(2), e(5),
	// b(10), d(15) rather than a level-by-level order.
	test('uses legacy prioritization', async () => {
		class Wrap1 extends PromptElement {
			render() {
				return <>
					<TextChunk priority={1}>a</TextChunk>
					<TextChunk priority={10}>b</TextChunk>
				</>
			}
		}
		class Wrap2 extends PromptElement {
			render() {
				return <>
					<TextChunk priority={2}>c</TextChunk>
					<TextChunk priority={15}>d</TextChunk>
				</>
			}
		}
		await assertPruningOrder(<LegacyPrioritization>
			<UserMessage>
				<Wrap1 priority={1} />
				<Wrap2 priority={2} />
			</UserMessage>
			<UserMessage>
				<TextChunk priority={5}>e</TextChunk>
			</UserMessage>
		</LegacyPrioritization>, ['a', 'c', 'e', 'b', 'd']);
	});
});

suite('truncates tokens exceeding token budget', async () => {
Expand Down

0 comments on commit 09caf7e

Please sign in to comment.