diff --git a/js/src/evaluation/runner.ts b/js/src/evaluation/runner.ts index 35352a2be..f70a15daf 100644 --- a/js/src/evaluation/runner.ts +++ b/js/src/evaluation/runner.ts @@ -100,8 +100,8 @@ interface ExperimentResultRow { */ class ExperimentResults implements AsyncIterableIterator { private manager: _ExperimentManager; - private results: ExperimentResultRow[] = []; - private processedCount = 0; + results: ExperimentResultRow[] = []; + processedCount = 0; private _summaryResults: EvaluationResults; get summaryResults(): EvaluationResults { @@ -110,7 +110,6 @@ class ExperimentResults implements AsyncIterableIterator { constructor(experimentManager: _ExperimentManager) { this.manager = experimentManager; - this.processData(this.manager); } get experimentName(): string { @@ -131,7 +130,7 @@ class ExperimentResults implements AsyncIterableIterator { } } - private async processData(manager: _ExperimentManager): Promise { + async processData(manager: _ExperimentManager): Promise { const results = manager.getResults(); for await (const item of results) { this.results.push(item); @@ -159,7 +158,6 @@ const _isCallable = (target: TargetT | AsyncIterable): boolean => ("invoke" in target && typeof target.invoke === "function") ); - async function _evaluate( target: TargetT | AsyncIterable, fields: { @@ -174,12 +172,12 @@ async function _evaluate( } ): Promise { const client = fields.client ?? new Client(); - const runs = _isCallable(target) ? null : target as AsyncIterable; + const runs = _isCallable(target) ? null : (target as AsyncIterable); const [experiment_, newRuns] = await _resolveExperiment( fields.experiment ?? null, runs, - client, - ) + client + ); let manager = await new _ExperimentManager({ data: fields.data, @@ -204,6 +202,7 @@ async function _evaluate( } // Start consuming the results. const results = new ExperimentResults(manager); + await results.processData(manager); return results; } @@ -359,7 +358,9 @@ class _ExperimentManager extends _ExperimentManagerMixin { get examples(): AsyncIterable { if (this._examples === undefined) { - this._examples = _resolveData(this._data, { client: this.client }); + return _resolveData(this._data, { client: this.client }); + } else { + return this._examples; } return async function* (this: _ExperimentManager) { for await (const example of this._examples!) { @@ -534,12 +535,12 @@ class _ExperimentManager extends _ExperimentManagerMixin { if (!this._summaryResults) { return { results: [] }; } - + const results: EvaluationResult[] = []; for await (const evaluationResults of this._summaryResults) { results.push(...evaluationResults.results); } - + return { results }; } @@ -562,7 +563,7 @@ class _ExperimentManager extends _ExperimentManagerMixin { if (maxConcurrency === 0) { for await (const example of this.examples) { - yield _forward( + yield await _forward( fn, example, this.experimentName, @@ -590,7 +591,7 @@ class _ExperimentManager extends _ExperimentManagerMixin { ); } - for (const future of futures) { + for await (const future of futures) { yield future; } } @@ -761,7 +762,7 @@ class _ExperimentManager extends _ExperimentManagerMixin { } async function _forward( - fn: (...args: any[]) => any, // TODO fix this type. What is `rh.SupportsLangsmithExtra`? + fn: (...args: any[]) => Promise, // TODO fix this type. What is `rh.SupportsLangsmithExtra`? example: Example, experimentName: string, metadata: Record, @@ -774,7 +775,7 @@ async function _forward( }; try { - fn(example.inputs, { + await fn(example.inputs, { reference_example_id: example.id, on_end: _getRun, project_name: experimentName, @@ -928,8 +929,10 @@ function _resolveEvaluators( async function _resolveExperiment( experiment: TracerSession | null, runs: AsyncIterable | null, - client: Client, -): Promise<[TracerSession | string | undefined, AsyncIterable | undefined]> { + client: Client +): Promise< + [TracerSession | string | undefined, AsyncIterable | undefined] +> { // TODO: Remove this, handle outside the manager if (experiment !== null) { if (!experiment.name) { @@ -947,7 +950,9 @@ async function _resolveExperiment( const [runsClone, runsOriginal] = results; const runsCloneIterator = runsClone[Symbol.asyncIterator](); // todo: this is `any`. does it work properly? - const firstRun = await runsCloneIterator.next().then(result => result.value); + const firstRun = await runsCloneIterator + .next() + .then((result) => result.value); const retrievedExperiment = await client.readProject(firstRun.sessionId); if (!retrievedExperiment.name) { throw new Error("Experiment name not found for provided runs."); @@ -956,4 +961,4 @@ async function _resolveExperiment( } return [undefined, undefined]; -} \ No newline at end of file +} diff --git a/js/src/tests/evaluate.int.test.ts b/js/src/tests/evaluate.int.test.ts new file mode 100644 index 000000000..cfec53367 --- /dev/null +++ b/js/src/tests/evaluate.int.test.ts @@ -0,0 +1,16 @@ +import { evaluate } from "../evaluation/runner.js"; + +test("evaluate can evaluate", async () => { + const dummyDatasetName = "ds-somber-yesterday-36"; + const evalFunc = (input: Record) => { + console.log("__input__", input); + return input; + }; + // const evalRunnable = new RunnableLambda({ func: (input: Record) => { + // console.log("input", input); + // }}); + + const evalRes = await evaluate(evalFunc, dummyDatasetName); + console.log(evalRes.results); + expect(evalRes.processedCount).toBeGreaterThan(0); +}); diff --git a/js/src/traceable.ts b/js/src/traceable.ts index b33dd621e..6c8cf4cff 100644 --- a/js/src/traceable.ts +++ b/js/src/traceable.ts @@ -147,6 +147,10 @@ export function traceable any>( return new Promise((resolve, reject) => { void asyncLocalStorage.run(currentRunTree, async () => { try { + const onEnd = args.find((obj) => "on_end" in obj)?.on_end; + if (onEnd) { + onEnd(currentRunTree); + } const rawOutput = await wrappedFunc(...rawInputs); if (isAsyncIterable(rawOutput)) { // eslint-disable-next-line no-inner-declarations