Skip to content

Commit

Permalink
refactor: improve speed of text trimming (#157)
Browse files Browse the repository at this point in the history
* refactor: improve speed of text trimming

I added a benchmark for a heavy case of pruning. The changes in this PR
reduce the runtime from ~17 seconds to ~0.2 seconds. There are more
optimizations that can be done, but an ~80x speedup is a good start.

Refs microsoft/vscode-copilot-release#4985

* up version
  • Loading branch information
connor4312 authored Feb 28, 2025
1 parent a887ff7 commit ef03f8a
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 35 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,5 @@ dist
.pnp.*

src/base/htmlTracerSrc.ts

*.cpuprofile
14 changes: 12 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@vscode/prompt-tsx",
"version": "0.3.0-alpha.19",
"version": "0.3.0-alpha.20",
"description": "Declare LLM prompts with TSX",
"main": "./dist/base/index.js",
"types": "./dist/base/index.d.ts",
Expand All @@ -13,6 +13,7 @@
"watch:base": "tsc --watch --sourceMap --preserveWatchOutput",
"test": "vscode-test",
"test:unit": "cross-env IS_OUTSIDE_VSCODE=1 mocha --import=tsx -u tdd \"src/base/test/**/*.test.{ts,tsx}\"",
"test:bench": "tsx ./src/base/test/renderer.bench.tsx",
"prettier": "prettier --list-different --write --cache .",
"prepare": "tsx ./build/postinstall.ts"
},
Expand All @@ -38,6 +39,7 @@
"mocha": "^10.2.0",
"preact": "^10.24.2",
"prettier": "^2.8.8",
"tinybench": "^3.1.1",
"tsx": "^4.19.1",
"typescript": "^5.6.2"
}
Expand Down
58 changes: 34 additions & 24 deletions src/base/materialized.ts
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,14 @@ export class MaterializedContainer implements IMaterializedContainer {
}
}

/** Removes the node in the tree with the lowest priority. */
removeLowestPriorityChild(): void {
if (this.has(ContainerFlags.IsLegacyPrioritization)) {
removeLowestPriorityLegacy(this);
} else {
removeLowestPriorityChild(this);
}
/**
* Removes the node in the tree with the lowest priority. Returns the
* list of nodes that were removed.
*/
removeLowestPriorityChild(): MaterializedNode[] {
const removed: MaterializedNode[] = [];
removeLowestPriorityChild(this, removed);
return removed;
}
}

Expand Down Expand Up @@ -256,11 +257,11 @@ export class MaterializedChatMessage implements IMaterializedNode {
return replaced;
}

/** Remove the lowest priority chunk among this message's children. */
removeLowestPriorityChild() {
removeLowestPriorityChild(this);
removeLowestPriorityChild(): MaterializedNode[] {
const removed: MaterializedNode[] = [];
removeLowestPriorityChild(this, removed);
return removed;
}

onChunksChange() {
this._tokenCount.clear();
this._upperBound.clear();
Expand Down Expand Up @@ -464,7 +465,7 @@ function* textChunks(
}
}

function removeLowestPriorityLegacy(root: MaterializedNode) {
function removeLowestPriorityLegacy(root: MaterializedNode, removed: MaterializedNode[]) {
let lowest:
| undefined
| {
Expand Down Expand Up @@ -498,10 +499,13 @@ function removeLowestPriorityLegacy(root: MaterializedNode) {
throw new Error('No lowest priority node found');
}

removeNode(lowest.node);
removeNode(lowest.node, removed);
}

function removeLowestPriorityChild(node: MaterializedContainer | MaterializedChatMessage) {
function removeLowestPriorityChild(
node: MaterializedContainer | MaterializedChatMessage,
removed: MaterializedNode[]
) {
let lowest:
| undefined
| {
Expand All @@ -511,6 +515,11 @@ function removeLowestPriorityChild(node: MaterializedContainer | MaterializedCha
lowestNested?: number;
};

if (node instanceof MaterializedContainer && node.has(ContainerFlags.IsLegacyPrioritization)) {
removeLowestPriorityLegacy(node, removed);
return;
}

// In *most* cases the chain is always [node], but it can be longer if
// the `passPriority` is used. We need to keep track of the chain to
// call `onChunksChange` as necessary.
Expand Down Expand Up @@ -539,16 +548,15 @@ function removeLowestPriorityChild(node: MaterializedContainer | MaterializedCha
throw new Error('No lowest priority node found');
}

const containingList = lowest.chain[lowest.chain.length - 1].children;
if (
lowest.value instanceof MaterializedChatMessageTextChunk ||
lowest.value instanceof MaterializedChatMessageImage ||
(lowest.value instanceof MaterializedContainer && lowest.value.has(ContainerFlags.IsChunk)) ||
(isContainerType(lowest.value) && !lowest.value.children.length)
) {
removeNode(lowest.value);
removeNode(lowest.value, removed);
} else {
lowest.value.removeLowestPriorityChild();
removeLowestPriorityChild(lowest.value, removed);
}
}

Expand Down Expand Up @@ -631,7 +639,7 @@ function isKeepWith(
/** Global list of 'keepWiths' currently being removed to avoid recursing indefinitely */
const currentlyBeingRemovedKeepWiths = new Set<number>();

function removeOtherKeepWiths(nodeThatWasRemoved: MaterializedNode) {
function removeOtherKeepWiths(nodeThatWasRemoved: MaterializedNode, removed: MaterializedNode[]) {
const removeKeepWithIds = new Set<number>();
for (const node of forEachNode(nodeThatWasRemoved)) {
if (isKeepWith(node) && !currentlyBeingRemovedKeepWiths.has(node.keepWithId)) {
Expand All @@ -651,15 +659,16 @@ function removeOtherKeepWiths(nodeThatWasRemoved: MaterializedNode) {
const root = getRoot(nodeThatWasRemoved);
for (const node of forEachNode(root)) {
if (isKeepWith(node) && removeKeepWithIds.has(node.keepWithId)) {
removeNode(node);
removeNode(node, removed);
} else if (node instanceof MaterializedChatMessage && node.toolCalls) {
node.toolCalls = filterIfDifferent(
node.toolCalls,
c => !(c.keepWith && removeKeepWithIds.has(c.keepWith.id))
);

if (node.isEmpty) { // may have become empty if it only contained tool calls
removeNode(node);
if (node.isEmpty) {
// may have become empty if it only contained tool calls
removeNode(node, removed);
}
}
}
Expand All @@ -685,7 +694,7 @@ function findNodeById(nodeId: number, container: ContainerType): ContainerType |
}
}

function removeNode(node: MaterializedNode) {
function removeNode(node: MaterializedNode, removed: MaterializedNode[]) {
const parent = node.parent;
if (!parent) {
return; // root
Expand All @@ -697,10 +706,11 @@ function removeNode(node: MaterializedNode) {
}

parent.children.splice(index, 1);
removeOtherKeepWiths(node);
removed.push(node);
removeOtherKeepWiths(node, removed);

if (parent.isEmpty) {
removeNode(parent);
removeNode(parent, removed);
} else {
parent.onChunksChange();
}
Expand Down
28 changes: 20 additions & 8 deletions src/base/promptRenderer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -461,14 +461,26 @@ export class PromptRenderer<P extends BasePromptElementProps> {
}
}

// Trim the elements to fit within the token budget. We check the "lower bound"
// first because that's much more cache-friendly as we remove elements.
while (
(await container.upperBoundTokenCount(this._tokenizer)) > limit.limit &&
(await container.tokenCount(this._tokenizer)) > limit.limit
) {
container.removeLowestPriorityChild();
removed++;
// Trim the elements to fit within the token budget. The "upper bound" count
// is a cachable count derived from the individual token counts of each component.
// The actual token count is <= the upper bound count due to BPE merging of tokens
// at component boundaries.
//
// To avoid excess tokenization, we first calculate the precise token
// usage of the message, and then remove components, subtracting their
// "upper bound" usage from the count until it's <= the budget. We then
// repeat this and refine as necessary, though most of the time we only
// need a single iteration of this.<sup>[citation needed]</sup>
let tokenCount = await container.tokenCount(this._tokenizer);
while (tokenCount > limit.limit) {
while (tokenCount > limit.limit) {
for (const node of container.removeLowestPriorityChild()) {
removed++;
const rmCount = node.upperBoundTokenCount(this._tokenizer);
tokenCount -= typeof rmCount === 'number' ? rmCount : await rmCount;
}
}
tokenCount = await container.tokenCount(this._tokenizer);
}
}

Expand Down
67 changes: 67 additions & 0 deletions src/base/test/renderer.bench.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import { existsSync, readFileSync } from 'fs';
import { Bench } from 'tinybench';
import { Cl100KBaseTokenizer } from '../tokenizer/cl100kBaseTokenizer';
import type * as promptTsx from '..';
import assert = require('assert');

const comparePathVar = 'PROMPT_TSX_COMPARE_PATH';
const tsxComparePath =
process.env[comparePathVar] ||
`${__dirname}/../../../../vscode-copilot/node_modules/@vscode/prompt-tsx`;
const canCompare = existsSync(tsxComparePath);
if (!canCompare) {
console.error(
`$${comparePathVar} was not set / ${tsxComparePath} doesn't exist, so the benchmark will not compare to past behavior`
);
process.exit(1);
}

const numberOfRepeats = 1;
const sampleText = readFileSync(`${__dirname}/renderer.test.tsx`, 'utf-8');
const sampleTextLines = readFileSync(`${__dirname}/renderer.test.tsx`, 'utf-8').split('\n');
const tokenizer = new Cl100KBaseTokenizer();
const bench = new Bench({
name: `trim ${tokenizer.tokenLength(sampleText) * numberOfRepeats}->1k tokens`,
time: 100,
});

async function benchTokenizationTrim({
PromptRenderer,
PromptElement,
UserMessage,
TextChunk,
}: typeof promptTsx) {
const r = await new PromptRenderer(
{ modelMaxPromptTokens: 1000 },
class extends PromptElement {
render() {
return (
<>
{Array.from({ length: numberOfRepeats }, () => (
<UserMessage>
{sampleTextLines.map(l => (
<TextChunk>{l}</TextChunk>
))}
</UserMessage>
))}
</>
);
}
},
{},
tokenizer
).render();
assert(r.tokenCount <= 1000);
assert(r.tokenCount > 100);
}

bench.add('current', () => benchTokenizationTrim(require('..')));
if (canCompare) {
const fn = require(tsxComparePath);
bench.add('previous', () => benchTokenizationTrim(fn));
}

bench.run().then(() => {
console.log(bench.name);
console.table(bench.table());
});

0 comments on commit ef03f8a

Please sign in to comment.