From 7cb8cf7e0689a9c70cb77e6d831b5ff256ee3ed2 Mon Sep 17 00:00:00 2001 From: bracesproul Date: Tue, 23 Apr 2024 15:00:25 -0700 Subject: [PATCH] implemented some more code --- js/src/client.ts | 53 ++++++++++++- js/src/evaluation/runner.ts | 149 +++++++++++++++++++++++++++++++++--- 2 files changed, 191 insertions(+), 11 deletions(-) diff --git a/js/src/client.ts b/js/src/client.ts index 749f02afa..703113b48 100644 --- a/js/src/client.ts +++ b/js/src/client.ts @@ -33,7 +33,11 @@ import { getRuntimeEnvironment, } from "./utils/env.js"; -import { RunEvaluator } from "./evaluation/evaluator.js"; +import { + EvaluationResult, + EvaluationResults, + RunEvaluator, +} from "./evaluation/evaluator.js"; import { __version__ } from "./index.js"; interface ClientConfig { @@ -2426,4 +2430,51 @@ export class Client { yield* tokens; } } + + private _selectEvalResults( + results: EvaluationResult | EvaluationResults + ): Array { + let results_: Array; + if ("results" in results) { + results_ = results.results; + } else { + results_ = [results]; + } + return results_; + } + + public async logEvaluationFeedback( + evaluatorResponse: EvaluationResult | EvaluationResults, + run?: Run, + sourceInfo?: { [key: string]: any } + ): Promise { + const results: Array = + this._selectEvalResults(evaluatorResponse); + for (const res of results) { + let sourceInfo_ = sourceInfo || {}; + if (res.evaluatorInfo) { + sourceInfo_ = { ...res.evaluatorInfo, ...sourceInfo_ }; + } + let runId_: string | undefined = undefined; + if (res.targetRunId) { + runId_ = res.targetRunId; + } else if (run) { + runId_ = run.id; + } else { + throw new Error("No run ID provided"); + } + + await this.createFeedback(runId_, res.key, { + score: res.score, + value: res.value, + comment: res.comment, + correction: res.correction, + sourceInfo: sourceInfo_, + sourceRunId: res.sourceRunId, + feedbackConfig: res.feedbackConfig as FeedbackConfig | undefined, + feedbackSourceType: "model", + }); + } + return results; + } } diff --git a/js/src/evaluation/runner.ts b/js/src/evaluation/runner.ts index 36ff7e207..3da3e64a9 100644 --- a/js/src/evaluation/runner.ts +++ b/js/src/evaluation/runner.ts @@ -1,8 +1,13 @@ -import { Client } from "../index.js"; -import { Example, KVMap, Run, TracerSession } from "../schemas.js"; -import { traceable } from "../traceable.js"; +import { Client, RunTree, RunTreeConfig } from "../index.js"; +import { BaseRun, Example, KVMap, Run, TracerSession } from "../schemas.js"; +import { + TraceableFunction, + isTraceableFunction, + traceable, +} from "../traceable.js"; import { getGitInfo } from "../utils/_git.js"; import { isUUIDv4 } from "../utils/_uuid.js"; +import { AsyncCaller } from "../utils/async_caller.js"; import { getLangChainEnvVarsMetadata } from "../utils/env.js"; import { randomName } from "./_random_name.js"; import { @@ -433,30 +438,95 @@ class _ExperimentManager extends _ExperimentManagerMixin { // Private methods - _predict( + /** + * Run the target function on the examples. + * @param {TargetT} target The target function to evaluate. + * @param options + * @returns {AsyncGenerator<_ForwardResults>} An async generator of the results. + */ + async *_predict( target: TargetT, options?: { maxConcurrency?: number; } ): AsyncGenerator<_ForwardResults> { - throw new Error("Not implemented"); + const fn = wrapFunctionAndEnsureTraceable(target); + const maxConcurrency = options?.maxConcurrency ?? 0; + + if (maxConcurrency === 0) { + for await (const example of this.examples) { + yield _forward( + fn, + example, + this.experimentName, + this._metadata, + this.client + ); + } + } else { + const caller = new AsyncCaller({ + maxConcurrency, + }); + + const futures: Array> = []; + + for await (const example of this.examples) { + futures.push( + caller.call( + _forward, + fn, + example, + this.experimentName, + this._metadata, + this.client + ) + ); + } + + for (const future of futures) { + yield await future; + } + } + + // Close out the project. + this._end(); } - _runEvaluators( + async _runEvaluators( evaluators: Array, currentResults: ExperimentResultRow - ): ExperimentResultRow { - throw new Error("Not implemented"); + ): Promise { + const { run, example, evaluationResults } = currentResults; + for (const evaluator of evaluators) { + try { + const evaluatorResponse = await evaluator.evaluateRun(run, example); + evaluationResults.results.push( + ...(await this.client.logEvaluationFeedback(evaluatorResponse, run)) + ); + } catch (e) { + console.error( + `Error running evaluator ${evaluator.evaluateRun.name} on run ${ + run.id + }: ${JSON.stringify(e, null, 2)}` + ); + } + } + + return { + run, + example, + evaluationResults, + }; } - _score( + async *_score( evaluators: Array, maxConcurrency?: number ): AsyncIterable { throw new Error("Not implemented"); } - _applySummaryEvaluators( + async *_applySummaryEvaluators( summaryEvaluators: Array ): AsyncGenerator { throw new Error("Not implemented"); @@ -494,6 +564,48 @@ class _ExperimentManager extends _ExperimentManagerMixin { } } +async function _forward( + fn: (...args: any[]) => any, // TODO fix this type. What is `rh.SupportsLangsmithExtra`? + example: Example, + experimentName: string, + metadata: Record, + client: Client +): Promise<_ForwardResults> { + let run: BaseRun | null = null; + + const _getRun = (r: RunTree): void => { + run = r; + }; + + try { + fn(example.inputs, { + reference_example_id: example.id, + on_end: _getRun, + project_name: experimentName, + metadata: { + ...metadata, + example_version: example.modified_at + ? new Date(example.modified_at).toISOString() + : new Date(example.created_at).toISOString(), + }, + client, + }); + } catch (e) { + console.error( + `Error running target function: ${JSON.stringify(e, null, 2)}` + ); + } + + if (!run) { + throw new Error("Run not created by target function."); + } + + return { + run, + example, + }; +} + function _resolveData( data: DataT, options: { @@ -581,3 +693,20 @@ async function* asyncTee( yield* iterators; } + +interface SupportsLangSmithExtra { + (target: TargetT, langSmithExtra?: Partial): R; +} + +function wrapFunctionAndEnsureTraceable(target: TargetT) { + if (typeof target === "function") { + if (isTraceableFunction(target)) { + return target as SupportsLangSmithExtra>; + } else { + return traceable(target, { + name: "target", + }); + } + } + throw new Error("Target must be runnable function"); +}