Skip to content

Commit

Permalink
discojs*: rename .unbatch() to .flat()
Browse files Browse the repository at this point in the history
  • Loading branch information
JulienVig committed Nov 14, 2024
1 parent c477bb3 commit cb806c0
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion cli/src/benchmark_gpt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ async function main(args: Required<CLIArguments>): Promise<void> {
task.trainingInformation.maxSequenceLength = contextLength
const dataset = loadText('../datasets/wikitext/wiki.train.tokens')
.map(text => processing.tokenize(tokenizer, text))
.unbatch()
.flat()
.batchWithOverlap(config.blockSize)

const preprocessedDataset = dataset
Expand Down
2 changes: 1 addition & 1 deletion cli/src/train_gpt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async function main(): Promise<void> {

const tokenDataset = new Dataset([data])
.map((text: string) => processing.tokenize(tokenizer, text))
.unbatch()
.flat()
.batchWithOverlap(config.blockSize)
.map((tokens) => [tokens.pop(), tokens.last()] as [List<number>, number])
.repeat()
Expand Down
2 changes: 1 addition & 1 deletion discojs/src/dataset/dataset.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ describe("dataset", () => {
const blockSize = 4

const parsed = new Dataset([expectedTokens])
.unbatch()
.flat()
.batchWithOverlap(blockSize)

// -1 because the last sequence is dropped as there is no next token label
Expand Down
4 changes: 2 additions & 2 deletions discojs/src/dataset/dataset.ts
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ export class Dataset<T> implements AsyncIterable<T> {
);
}

/** Flatten chunks */
unbatch<U>(this: Dataset<Batched<U>>): Dataset<U> {
/** Flatten batches/arrays of elements */
flat<U>(this: Dataset<Batched<U>>): Dataset<U> {
return new Dataset(
async function* (this: Dataset<Batched<U>>) {
for await (const batch of this) yield* batch;
Expand Down
4 changes: 2 additions & 2 deletions discojs/src/processing/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ export async function preprocess<D extends DataType>(

const tokenizer = await models.getTaskTokenizer(t);
return d.map(text => processing.tokenize(tokenizer, text))
.unbatch()
.flat()
.batchWithOverlap(blockSize)
.map((tokens) => [tokens.pop(), tokens.last()]) as
Dataset<DataFormat.ModelEncoded[D]>;
Expand Down Expand Up @@ -101,7 +101,7 @@ export async function preprocessWithoutLabel<D extends DataType>(
const tokenizer = await models.getTaskTokenizer(t);

return d.map(text => processing.tokenize(tokenizer, text))
.unbatch()
.flat()
.batch(blockSize)
}
}
Expand Down
4 changes: 2 additions & 2 deletions discojs/src/validator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ export class Validator<D extends DataType> {
.zip(batch.map(([_, outputs]) => outputs))
.map(([inferred, truth]) => inferred === truth),
)
.unbatch();
.flat();

for await (const e of results) yield e;
}
Expand All @@ -36,7 +36,7 @@ export class Validator<D extends DataType> {
)
.batch(this.task.trainingInformation.batchSize)
.map((batch) => this.#model.predict(batch))
.unbatch();
.flat();

const predictions = await processing.postprocess(
this.task,
Expand Down

0 comments on commit cb806c0

Please sign in to comment.