Skip to content

Commit

Permalink
implemented some more code
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul committed Apr 23, 2024
1 parent 9a2cd84 commit 7cb8cf7
Show file tree
Hide file tree
Showing 2 changed files with 191 additions and 11 deletions.
53 changes: 52 additions & 1 deletion js/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ import {
getRuntimeEnvironment,
} from "./utils/env.js";

import { RunEvaluator } from "./evaluation/evaluator.js";
import {
EvaluationResult,
EvaluationResults,
RunEvaluator,
} from "./evaluation/evaluator.js";
import { __version__ } from "./index.js";

interface ClientConfig {
Expand Down Expand Up @@ -2426,4 +2430,51 @@ export class Client {
yield* tokens;
}
}

private _selectEvalResults(
results: EvaluationResult | EvaluationResults
): Array<EvaluationResult> {
let results_: Array<EvaluationResult>;
if ("results" in results) {
results_ = results.results;
} else {
results_ = [results];
}
return results_;
}

public async logEvaluationFeedback(
evaluatorResponse: EvaluationResult | EvaluationResults,
run?: Run,
sourceInfo?: { [key: string]: any }
): Promise<EvaluationResult[]> {
const results: Array<EvaluationResult> =
this._selectEvalResults(evaluatorResponse);
for (const res of results) {
let sourceInfo_ = sourceInfo || {};
if (res.evaluatorInfo) {
sourceInfo_ = { ...res.evaluatorInfo, ...sourceInfo_ };
}
let runId_: string | undefined = undefined;
if (res.targetRunId) {
runId_ = res.targetRunId;
} else if (run) {
runId_ = run.id;
} else {
throw new Error("No run ID provided");
}

await this.createFeedback(runId_, res.key, {
score: res.score,
value: res.value,
comment: res.comment,
correction: res.correction,
sourceInfo: sourceInfo_,
sourceRunId: res.sourceRunId,
feedbackConfig: res.feedbackConfig as FeedbackConfig | undefined,
feedbackSourceType: "model",
});
}
return results;
}
}
149 changes: 139 additions & 10 deletions js/src/evaluation/runner.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import { Client } from "../index.js";
import { Example, KVMap, Run, TracerSession } from "../schemas.js";
import { traceable } from "../traceable.js";
import { Client, RunTree, RunTreeConfig } from "../index.js";
import { BaseRun, Example, KVMap, Run, TracerSession } from "../schemas.js";
import {
TraceableFunction,
isTraceableFunction,
traceable,
} from "../traceable.js";
import { getGitInfo } from "../utils/_git.js";
import { isUUIDv4 } from "../utils/_uuid.js";
import { AsyncCaller } from "../utils/async_caller.js";
import { getLangChainEnvVarsMetadata } from "../utils/env.js";
import { randomName } from "./_random_name.js";
import {
Expand Down Expand Up @@ -433,30 +438,95 @@ class _ExperimentManager extends _ExperimentManagerMixin {

// Private methods

_predict(
/**
* Run the target function on the examples.
* @param {TargetT} target The target function to evaluate.
* @param options
* @returns {AsyncGenerator<_ForwardResults>} An async generator of the results.
*/
async *_predict(
target: TargetT,
options?: {
maxConcurrency?: number;
}
): AsyncGenerator<_ForwardResults> {
throw new Error("Not implemented");
const fn = wrapFunctionAndEnsureTraceable(target);
const maxConcurrency = options?.maxConcurrency ?? 0;

if (maxConcurrency === 0) {
for await (const example of this.examples) {
yield _forward(
fn,
example,
this.experimentName,
this._metadata,
this.client
);
}
} else {
const caller = new AsyncCaller({
maxConcurrency,
});

const futures: Array<Promise<_ForwardResults>> = [];

for await (const example of this.examples) {
futures.push(
caller.call(
_forward,
fn,
example,
this.experimentName,
this._metadata,
this.client
)
);
}

for (const future of futures) {
yield await future;
}
}

// Close out the project.
this._end();
}

_runEvaluators(
async _runEvaluators(
evaluators: Array<RunEvaluator>,
currentResults: ExperimentResultRow
): ExperimentResultRow {
throw new Error("Not implemented");
): Promise<ExperimentResultRow> {
const { run, example, evaluationResults } = currentResults;
for (const evaluator of evaluators) {
try {
const evaluatorResponse = await evaluator.evaluateRun(run, example);
evaluationResults.results.push(
...(await this.client.logEvaluationFeedback(evaluatorResponse, run))
);
} catch (e) {
console.error(
`Error running evaluator ${evaluator.evaluateRun.name} on run ${
run.id
}: ${JSON.stringify(e, null, 2)}`
);
}
}

return {
run,
example,
evaluationResults,
};
}

_score(
async *_score(
evaluators: Array<RunEvaluator>,
maxConcurrency?: number
): AsyncIterable<ExperimentResultRow> {
throw new Error("Not implemented");
}

_applySummaryEvaluators(
async *_applySummaryEvaluators(
summaryEvaluators: Array<SummaryEvaluatorT>
): AsyncGenerator<EvaluationResults> {
throw new Error("Not implemented");
Expand Down Expand Up @@ -494,6 +564,48 @@ class _ExperimentManager extends _ExperimentManagerMixin {
}
}

async function _forward(
fn: (...args: any[]) => any, // TODO fix this type. What is `rh.SupportsLangsmithExtra`?
example: Example,
experimentName: string,
metadata: Record<string, any>,
client: Client
): Promise<_ForwardResults> {
let run: BaseRun | null = null;

const _getRun = (r: RunTree): void => {
run = r;
};

try {
fn(example.inputs, {
reference_example_id: example.id,
on_end: _getRun,
project_name: experimentName,
metadata: {
...metadata,
example_version: example.modified_at
? new Date(example.modified_at).toISOString()
: new Date(example.created_at).toISOString(),
},
client,
});
} catch (e) {
console.error(
`Error running target function: ${JSON.stringify(e, null, 2)}`
);
}

if (!run) {
throw new Error("Run not created by target function.");
}

return {
run,
example,
};
}

function _resolveData(
data: DataT,
options: {
Expand Down Expand Up @@ -581,3 +693,20 @@ async function* asyncTee<T>(

yield* iterators;
}

interface SupportsLangSmithExtra<R> {
(target: TargetT, langSmithExtra?: Partial<RunTreeConfig>): R;
}

function wrapFunctionAndEnsureTraceable(target: TargetT) {
if (typeof target === "function") {
if (isTraceableFunction(target)) {
return target as SupportsLangSmithExtra<ReturnType<typeof target>>;
} else {
return traceable(target, {
name: "target",
});
}
}
throw new Error("Target must be runnable function");
}

0 comments on commit 7cb8cf7

Please sign in to comment.