{ private readonly _ignoredFiles: URI[] = []; private readonly _growables: { initialConsume: number; elem: PromptTreeElement }[] = []; private readonly _root = new PromptTreeElement(null, 0); + private readonly _tokenLimits: { limit: number; id: number }[] = []; /** Epoch used to tracing the order in which elements render. */ public tracer: ITracer | undefined = undefined; @@ -124,6 +127,7 @@ export class PromptRenderer
{
{
element: QueueItem {
}
const promptElement = this.createElement(element);
+ let tokenLimit: number | undefined;
+ if (promptElement instanceof TokenLimit) {
+ tokenLimit = (element.props as unknown as TokenLimitProps).max;
+ this._tokenLimits.push({ limit: tokenLimit, id: element.node.id });
+ }
element.node.setObj(promptElement);
// Prepare rendering
@@ -151,7 +160,7 @@ export class PromptRenderer {
promptElements.set(flexGroupValue, flexGroup);
}
- flexGroup.push({ element, promptElementInstance: promptElement });
+ flexGroup.push({ element, promptElementInstance: promptElement, tokenLimit });
}
if (promptElements.size === 0) {
@@ -189,15 +198,37 @@ export class PromptRenderer {
// Calculate the flex basis for dividing the budget amongst siblings in this group.
let flexBasisSum = 0;
for (const { element } of promptElements) {
- // todo@connor4312: remove `flex` after transition
- flexBasisSum += (element.props.flex || element.props.flexBasis) ?? 1;
+ flexBasisSum += element.props.flexBasis ?? 1;
}
+ let constantTokenLimits = 0;
+ // For elements that limit their token usage and would use less than we
+ // otherwise would assign to them, 'cap' their usage at the limit and
+ // remove their share directly from the budget in distribution.
+ const useConstantLimitsForIndex = promptElements.map(e => {
+ if (e.tokenLimit === undefined) {
+ return false;
+ }
+
+ const flexBasis = e.element.props.flexBasis ?? 1;
+ const proportion = flexBasis / flexBasisSum;
+ const proportionateUsage = Math.floor(sizing.remainingTokenBudget * proportion);
+ if (proportionateUsage < e.tokenLimit) {
+ return false;
+ }
+
+ flexBasisSum -= flexBasis;
+ constantTokenLimits += e.tokenLimit;
+ return true;
+ });
+
// Finally calculate the final sizing for each element in this group.
- const elementSizings: PromptSizing[] = promptElements.map(e => {
+ const elementSizings: PromptSizing[] = promptElements.map((e, i) => {
const proportion = (e.element.props.flexBasis ?? 1) / flexBasisSum;
return {
- tokenBudget: Math.floor(sizing.remainingTokenBudget * proportion),
+ tokenBudget: useConstantLimitsForIndex[i]
+ ? e.tokenLimit!
+ : Math.floor((sizing.remainingTokenBudget - constantTokenLimits) * proportion),
endpoint: sizing.endpoint,
countTokens: (text, cancellation) => this._tokenizer.tokenLength(text, cancellation),
};
@@ -399,35 +430,49 @@ export class PromptRenderer {
* around with budgets. It should be side-effect-free.
*/
private async _getFinalElementTree(tokenBudget: number, token: CancellationToken | undefined) {
- // Trim the elements to fit within the token budget. We check the "lower bound"
- // first because that's much more cache-friendly as we remove elements.
- const container = this._root.materialize() as MaterializedContainer;
- const initialTokenCount = await container.tokenCount(this._tokenizer);
- if (initialTokenCount < tokenBudget) {
- const didChange = await this._grow(container, initialTokenCount, tokenBudget, token);
-
- // if nothing grew, we already counted tokens so we can safely return
- if (!didChange) {
- return { container, allMetadata: [...container.allMetadata()], removed: 0 };
+ const root = this._root.materialize() as MaterializedContainer;
+ const allMetadata = [...root.allMetadata()];
+ const limits = [{ limit: tokenBudget, id: this._root.id }, ...this._tokenLimits];
+ let removed = 0;
+
+ for (let i = limits.length - 1; i >= 0; i--) {
+ const limit = limits[i];
+ if (limit.limit > tokenBudget) {
+ continue;
}
- }
- const allMetadata = [...container.allMetadata()];
- let removed = 0;
- while (
- (await container.upperBoundTokenCount(this._tokenizer)) > tokenBudget &&
- (await container.tokenCount(this._tokenizer)) > tokenBudget
- ) {
- container.removeLowestPriorityChild();
- removed++;
+ const container = root.findById(limit.id);
+ if (!container) {
+ continue;
+ }
+
+ const initialTokenCount = await container.tokenCount(this._tokenizer);
+ if (initialTokenCount < limit.limit) {
+ const didChange = await this._grow(container, initialTokenCount, limit.limit, token);
+
+ // if nothing grew, we already counted tokens so we can skip trimming for this limit
+ if (!didChange) {
+ continue;
+ }
+ }
+
+ // Trim the elements to fit within the token budget. We check the "lower bound"
+ // first because that's much more cache-friendly as we remove elements.
+ while (
+ (await container.upperBoundTokenCount(this._tokenizer)) > limit.limit &&
+ (await container.tokenCount(this._tokenizer)) > limit.limit
+ ) {
+ container.removeLowestPriorityChild();
+ removed++;
+ }
}
- return { container, allMetadata, removed };
+ return { container: root, allMetadata, removed };
}
/** Grows all Expandable elements, returns if any changes were made. */
private async _grow(
- tree: MaterializedContainer,
+ tree: MaterializedContainer | MaterializedChatMessage,
tokensUsed: number,
tokenBudget: number,
token: CancellationToken | undefined
@@ -437,6 +482,10 @@ export class PromptRenderer {
}
for (const growable of this._growables) {
+ if (!tree.findById(growable.elem.id)) {
+ continue; // not in this subtree
+ }
+
const obj = growable.elem.getObj();
if (!(obj instanceof Expandable)) {
throw new Error('unreachable: expected growable');
diff --git a/src/base/test/renderer.test.tsx b/src/base/test/renderer.test.tsx
index 6652888..b8677eb 100644
--- a/src/base/test/renderer.test.tsx
+++ b/src/base/test/renderer.test.tsx
@@ -14,6 +14,7 @@ import {
PrioritizedList,
SystemMessage,
TextChunk,
+ TokenLimit,
ToolMessage,
ToolResult,
UserMessage,
@@ -1920,6 +1921,206 @@ suite('PromptRenderer', () => {
const inst = new PromptRenderer(fakeEndpoint, Wrapper, {}, tokenizer);
const res = await inst.render(undefined, undefined);
- assert.deepStrictEqual(res.messages.map(m => m.content).join(''), 'hello everyone in the world!');
+ assert.deepStrictEqual(
+ res.messages.map(m => m.content).join(''),
+ 'hello everyone in the world!'
+ );
+ });
+
+ suite('TokenLimit', () => {
+ test('limits tokens within budget', async () => {
+ class PromptWithLimit extends PromptElement {
+ render() {
+ return (
+