Skip to content

Commit

Permalink
add better handling of whitespace, Chunk utility
Browse files Browse the repository at this point in the history
  • Loading branch information
connor4312 committed Oct 9, 2024
1 parent 09caf7e commit af5ac44
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 40 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,14 @@ In this case, a very long `userQuery` would get pruned from the output first if

...would be pruned in the order `B->A->D->C`. If two sibling elements share the same priority, the renderer looks ahead at their direct children and picks whichever one has a child with the lowest priority: if the `SystemMessage` and `UserMessage` in the above example did not declare priorities, the pruning order would be `B->D->A->C`.

Continuous text strings and elements can both be pruned from the tree. If you have a set of elements that you want either included in their entirety or not at all, you can use the simple `Chunk` utility element:

```tsx
<Chunk>
	The file I'm editing is: <FileLink file={f} />
</Chunk>
```

### Flex Behavior

Wholesale pruning is not always ideal. Instead, we'd prefer to include as much of the query as possible. To do this, we can use the `flexGrow` property, which allows an element to use the remainder of its parent's token budget when it's rendered.
Expand Down
2 changes: 1 addition & 1 deletion src/base/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ export * from './tracer';
export * from './tsx-globals';
export * from './types';

export { AssistantMessage, Chunk, FunctionMessage, LegacyPrioritization, PrioritizedList, PrioritizedListProps, SystemMessage, TextChunk, TextChunkProps, UserMessage } from './promptElements';

export { PromptElement } from './promptElement';
export { MetadataMap, PromptRenderer, QueueItem, RenderPromptResult } from './promptRenderer';
Expand Down
72 changes: 50 additions & 22 deletions src/base/materialized.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,26 @@ export interface IMaterializedNode {

export type MaterializedNode = MaterializedContainer | MaterializedChatMessage | MaterializedChatMessageTextChunk;

/** Bit flags describing special behavior of a {@link MaterializedContainer}. */
export const enum ContainerFlags {
	/** It's a {@link LegacyPrioritization} instance */
	IsLegacyPrioritization = 1 << 0,
	/** It's a {@link Chunk} instance */
	IsChunk = 1 << 1,
}

export class MaterializedContainer implements IMaterializedNode {

	constructor(
		public readonly priority: number,
		public readonly children: MaterializedNode[],
		public readonly metadata: PromptMetadata[],
		/** Bitset of {@link ContainerFlags} that apply to this container. */
		public readonly flags: number,
	) { }

	/** Returns whether the given {@link ContainerFlags flag} is set on this container. */
	public has(flag: ContainerFlags) {
		return !!(this.flags & flag);
	}

/** @inheritdoc */
async tokenCount(tokenizer: ITokenizer): Promise<number> {
let total = 0;
Expand Down Expand Up @@ -67,38 +78,50 @@ export class MaterializedContainer implements IMaterializedNode {
/**
* Gets the chat messages the container holds.
*/
toChatMessages(): ChatMessage[] {
return this.children.flatMap(child => {
*toChatMessages(): Generator<ChatMessage> {
for (const child of this.children) {
assertContainerOrChatMessage(child);
return child instanceof MaterializedContainer ? child.toChatMessages() : [child.toChatMessage()];
})
if (child instanceof MaterializedContainer) {
yield* child.toChatMessages();
} else if (!child.isEmpty) {
// note: empty messages are already removed during pruning, but the
// consumer might themselves have given us empty messages that we should omit.
yield child.toChatMessage();
}
}
}

/** Removes the node in the tree with the lowest priority. */
removeLowestPriorityChild(): void {
if (this.isLegacyPrioritization) {
if (this.has(ContainerFlags.IsLegacyPrioritization)) {
removeLowestPriorityLegacy(this);
} else {
removeLowestPriorityChild(this.children);
}
}
}

/** Controls whether a line break is emitted before a {@link MaterializedChatMessageTextChunk}. */
export const enum LineBreakBefore {
	/** Never insert a line break before this chunk. */
	None,
	/** Always insert a line break before this chunk (when output is non-empty and doesn't already end in one). */
	Always,
	/** Insert a line break only when the chunk does not directly follow another text chunk. */
	IfNotTextSibling,
}

/** A chunk of text in a {@link MaterializedChatMessage} */
export class MaterializedChatMessageTextChunk {
	constructor(
		public readonly text: string,
		public readonly priority: number,
		public readonly metadata: PromptMetadata[] = [],
		/** When a line break should be emitted before this chunk's text. */
		public readonly lineBreakBefore: LineBreakBefore,
	) { }

	/** Upper bound on this chunk's token count, including any leading line break. */
	public upperBoundTokenCount(tokenizer: ITokenizer) {
		return this._upperBound(tokenizer);
	}

	// Memoized: the chunk is immutable, so the count needs to be computed at most once.
	private readonly _upperBound = once(async (tokenizer: ITokenizer) => {
		// +1 accounts for the line break that may be prepended before the text.
		return await tokenizer.tokenLength(this.text) + (this.lineBreakBefore !== LineBreakBefore.None ? 1 : 0);
	});
}

Expand Down Expand Up @@ -130,6 +153,11 @@ export class MaterializedChatMessage implements IMaterializedNode {
return this._text()
}

/** Gets whether the message is empty */
public get isEmpty() {
return !/\S/.test(this.text);
}

/** Remove the lowest priority chunk among this message's children. */
removeLowestPriorityChild() {
removeLowestPriorityChild(this.children);
Expand Down Expand Up @@ -161,14 +189,17 @@ export class MaterializedChatMessage implements IMaterializedNode {

private readonly _text = once(() => {
let result = '';
for (const chunk of textChunks(this)) {
if (chunk.lineBreakBefore && result.length && !result.endsWith('\n')) {
result += '\n';
for (const { text, isTextSibling } of textChunks(this)) {
if (text.lineBreakBefore === LineBreakBefore.Always || (text.lineBreakBefore === LineBreakBefore.IfNotTextSibling && !isTextSibling)) {
if (result.length && !result.endsWith('\n')) {
result += '\n';
}
}
result += chunk.text;

result += text.text;
}

return result;
return result.trim();
});

public toChatMessage(): ChatMessage {
Expand Down Expand Up @@ -221,17 +252,14 @@ function assertContainerOrChatMessage(v: MaterializedNode): asserts v is Materia
}


function* textChunks(node: MaterializedNode): Generator<MaterializedChatMessageTextChunk> {
if (node instanceof MaterializedChatMessageTextChunk) {
yield node;
return;
}

function* textChunks(node: MaterializedContainer | MaterializedChatMessage, isTextSibling = false): Generator<{ text: MaterializedChatMessageTextChunk; isTextSibling: boolean }> {
for (const child of node.children) {
if (child instanceof MaterializedChatMessageTextChunk) {
yield child;
yield { text: child, isTextSibling };
isTextSibling = true;
} else {
yield* textChunks(child);
yield* textChunks(child, isTextSibling);
isTextSibling = false;
}
}
}
Expand Down Expand Up @@ -309,7 +337,7 @@ function removeLowestPriorityChild(children: MaterializedNode[]) {
}

const lowest = children[lowestIndex];
if (lowest instanceof MaterializedChatMessageTextChunk) {
if (lowest instanceof MaterializedChatMessageTextChunk || (lowest instanceof MaterializedContainer && lowest.has(ContainerFlags.IsChunk))) {
children.splice(lowestIndex, 1);
} else {
lowest.removeLowestPriorityChild();
Expand Down
26 changes: 23 additions & 3 deletions src/base/promptElements.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import type { CancellationToken } from 'vscode';
import { contentType } from '.';
import * as JSONT from './jsonTypes';
import { ChatRole } from './openai';
import { PromptElement } from './promptElement';
import { BasePromptElementProps, PromptPiece, PromptSizing } from './types';
Expand Down Expand Up @@ -232,12 +231,22 @@ export class PrioritizedList extends PromptElement<PrioritizedListProps> {
return (
<>
{children.map((child, i) => {
child.props ??= {};
child.props.priority = this.props.descending
if (!child) {
return;
}

const priority = this.props.descending
? // First element in array of children has highest priority
this.props.priority - i
: // Last element in array of children has highest priority
this.props.priority - children.length + i;

if (typeof child !== 'object') {
return <TextChunk priority={priority}>{child}</TextChunk>;
}

child.props ??= {};
child.props.priority = priority;
return child;
})}
</>
Expand Down Expand Up @@ -283,3 +292,14 @@ export class LegacyPrioritization extends PromptElement {
return <>{this.props.children}</>;
}
}

/**
 * Marker element whose children are pruned together: the renderer either
 * keeps every child or drops them all. Similar to the `<TextChunk />`
 * element, but more basic, and able to contain extrinsic children.
 */
export class Chunk extends PromptElement<BasePromptElementProps> {
	render() {
		const { children } = this.props;
		return <>{children}</>;
	}
}
19 changes: 14 additions & 5 deletions src/base/promptRenderer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import type { CancellationToken, Progress } from "vscode";
import * as JSONT from './jsonTypes';
import { PromptNodeType } from './jsonTypes';
import { MaterializedChatMessage, MaterializedChatMessageTextChunk, MaterializedContainer } from './materialized';
import { ContainerFlags, LineBreakBefore, MaterializedChatMessage, MaterializedChatMessageTextChunk, MaterializedContainer } from './materialized';
import { ChatMessage } from "./openai";
import { PromptElement } from "./promptElement";
import { AssistantMessage, BaseChatMessage, ChatMessagePromptElement, LegacyPrioritization, TextChunk, ToolMessage, isChatMessagePromptElement } from "./promptElements";
import { AssistantMessage, BaseChatMessage, ChatMessagePromptElement, Chunk, LegacyPrioritization, TextChunk, ToolMessage, isChatMessagePromptElement } from "./promptElements";
import { PromptMetadata, PromptReference } from "./results";
import { ITokenizer } from "./tokenizer/tokenizer";
import { ITracer } from './tracer';
Expand Down Expand Up @@ -253,7 +253,7 @@ export class PromptRenderer<P extends BasePromptElementProps> {
}

// Then finalize the chat messages
const messageResult = container.toChatMessages();
const messageResult = [...container.toChatMessages()];
const tokenCount = await container.tokenCount(this._tokenizer);
const remainingMetadata = [...container.allMetadata()];

Expand Down Expand Up @@ -646,11 +646,15 @@ class PromptTreeElement {
);
return parent;
} else {
let flags = 0;
if (this._obj instanceof LegacyPrioritization) flags |= ContainerFlags.IsLegacyPrioritization;
if (this._obj instanceof Chunk) flags |= ContainerFlags.IsChunk;

return new MaterializedContainer(
this._obj?.props.priority || 0,
this._children.map(child => child.materialize()),
this._metadata,
this._obj instanceof LegacyPrioritization,
flags,
);
}
}
Expand Down Expand Up @@ -682,7 +686,12 @@ class PromptText {
}

public materialize() {
return new MaterializedChatMessageTextChunk(this.text, this.priority ?? Number.MAX_SAFE_INTEGER, this.metadata || [], this.lineBreakBefore || this.childIndex === 0);
const lineBreak = this.lineBreakBefore
? LineBreakBefore.Always
: this.childIndex === 0
? LineBreakBefore.IfNotTextSibling
: LineBreakBefore.None;
return new MaterializedChatMessageTextChunk(this.text, this.priority ?? Number.MAX_SAFE_INTEGER, this.metadata || [], lineBreak);
}

public toJSON(): JSONT.TextJSON {
Expand Down
14 changes: 7 additions & 7 deletions src/base/test/materialized.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
*--------------------------------------------------------------------------------------------*/

import * as assert from 'assert';
import { MaterializedChatMessage, MaterializedChatMessageTextChunk, MaterializedContainer } from '../materialized';
import { LineBreakBefore, MaterializedChatMessage, MaterializedChatMessageTextChunk, MaterializedContainer } from '../materialized';
import { ChatRole } from '../openai';
import { ITokenizer } from '../tokenizer/tokenizer';
class MockTokenizer implements ITokenizer {
Expand All @@ -17,10 +17,10 @@ class MockTokenizer implements ITokenizer {
suite('Materialized', () => {
test('should calculate token count correctly', async () => {
const tokenizer = new MockTokenizer();
const child1 = new MaterializedChatMessageTextChunk('Hello', 1, [], false);
const child2 = new MaterializedChatMessageTextChunk('World', 1, [], false);
const child1 = new MaterializedChatMessageTextChunk('Hello', 1, [], LineBreakBefore.None);
const child2 = new MaterializedChatMessageTextChunk('World', 1, [], LineBreakBefore.None);
const message = new MaterializedChatMessage(ChatRole.User, 'user', undefined, undefined, 1, 0, [], [child1, child2]);
const container = new MaterializedContainer(1, [message], []);
const container = new MaterializedContainer(1, [message], [], 0);

assert.deepStrictEqual(await container.tokenCount(tokenizer), 13);
container.removeLowestPriorityChild();
Expand All @@ -29,10 +29,10 @@ suite('Materialized', () => {

test('should calculate lower bound token count correctly', async () => {
const tokenizer = new MockTokenizer();
const child1 = new MaterializedChatMessageTextChunk('Hello', 1, [], false);
const child2 = new MaterializedChatMessageTextChunk('World', 1, [], false);
const child1 = new MaterializedChatMessageTextChunk('Hello', 1, [], LineBreakBefore.None);
const child2 = new MaterializedChatMessageTextChunk('World', 1, [], LineBreakBefore.None);
const message = new MaterializedChatMessage(ChatRole.User, 'user', undefined, undefined, 1, 0, [], [child1, child2]);
const container = new MaterializedContainer(1, [message], []);
const container = new MaterializedContainer(1, [message], [], 0);

assert.deepStrictEqual(await container.upperBoundTokenCount(tokenizer), 13);
container.removeLowestPriorityChild();
Expand Down
Loading

0 comments on commit af5ac44

Please sign in to comment.