Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(JS): add processInputs / processOutputs to traceable + add usage metadata to wrapOpenAI #1095

Merged
merged 7 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
334 changes: 334 additions & 0 deletions js/src/tests/traceable.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { getAssumedTreeFromCalls } from "./utils/tree.js";
import { mockClient } from "./utils/mock_client.js";
import { Client, overrideFetchImplementation } from "../index.js";
import { AsyncLocalStorageProviderSingleton } from "../singletons/traceable.js";
import { KVMap } from "../schemas.js";

test("basic traceable implementation", async () => {
const { client, callSpy } = mockClient();
Expand Down Expand Up @@ -1129,3 +1130,336 @@ test("traceable continues execution when client throws error", async () => {
expect(errorClient.createRun).toHaveBeenCalled();
expect(errorClient.updateRun).toHaveBeenCalled();
});

test("traceable with processInputs", async () => {
const { client, callSpy } = mockClient();

const processInputs = jest.fn((inputs: Readonly<KVMap>) => {
return { ...inputs, password: "****" };
});

const func = traceable(
async function func(input: { username: string; password: string }) {
// The function should receive the original inputs
expect(input.password).toBe("secret");
return `Welcome, ${input.username}`;
},
{
client,
tracingEnabled: true,
processInputs,
}
);

await func({ username: "user1", password: "secret" });

expect(processInputs).toHaveBeenCalledWith({
username: "user1",
password: "secret",
});
// Verify that the logged inputs have the password masked
expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["func:0"],
edges: [],
data: {
"func:0": {
inputs: {
username: "user1",
password: "****",
},
outputs: { outputs: "Welcome, user1" },
},
},
});
});

test("traceable with processOutputs", async () => {
const { client, callSpy } = mockClient();

const processOutputs = jest.fn((_outputs: Readonly<KVMap>) => {
return { outputs: "Modified Output" };
});

const func = traceable(
async function func(input: string) {
return `Original Output for ${input}`;
},
{
client,
tracingEnabled: true,
processOutputs,
}
);

const result = await func("test");

expect(processOutputs).toHaveBeenCalledWith({
outputs: "Original Output for test",
});
expect(result).toBe("Original Output for test");
// Verify that the tracing data shows the modified output
expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["func:0"],
edges: [],
data: {
"func:0": {
inputs: { input: "test" },
outputs: { outputs: "Modified Output" },
},
},
});
});

test("traceable with processInputs throwing error does not affect invocation", async () => {
const { client, callSpy } = mockClient();

const processInputs = jest.fn((_inputs: Readonly<KVMap>) => {
throw new Error("processInputs error");
});

const func = traceable(
async function func(input: { username: string }) {
// This should not be called
return `Hello, ${input.username}`;
},
{
client,
tracingEnabled: true,
processInputs,
}
);

const result = await func({ username: "user1" });

expect(processInputs).toHaveBeenCalledWith({ username: "user1" });
expect(result).toBe("Hello, user1");

expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["func:0"],
edges: [],
data: {
"func:0": {
inputs: { username: "user1" },
outputs: { outputs: "Hello, user1" },
},
},
});
});

test("traceable with processOutputs throwing error does not affect invocation", async () => {
const { client, callSpy } = mockClient();

const processOutputs = jest.fn((_outputs: Readonly<KVMap>) => {
throw new Error("processOutputs error");
});

const func = traceable(
async function func(input: string) {
return `Original Output for ${input}`;
},
{
client,
tracingEnabled: true,
processOutputs,
}
);

const result = await func("test");

expect(processOutputs).toHaveBeenCalledWith({
outputs: "Original Output for test",
});
expect(result).toBe("Original Output for test");

expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["func:0"],
edges: [],
data: {
"func:0": {
inputs: { input: "test" },
outputs: { outputs: "Original Output for test" },
},
},
});
});

test("traceable async generator with processOutputs", async () => {
const { client, callSpy } = mockClient();

const processOutputs = jest.fn((outputs: Readonly<KVMap>) => {
return { outputs: outputs.outputs.map((output: number) => output * 2) };
});

const func = traceable(
async function* func() {
for (let i = 1; i <= 3; i++) {
yield i;
}
},
{
client,
tracingEnabled: true,
processOutputs,
}
);

const results: number[] = [];
for await (const value of func()) {
results.push(value);
}

expect(results).toEqual([1, 2, 3]); // Original values
expect(processOutputs).toHaveBeenCalledWith({ outputs: [1, 2, 3] });

// Tracing data should reflect the processed outputs
expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["func:0"],
edges: [],
data: {
"func:0": {
outputs: { outputs: [2, 4, 6] }, // Processed outputs
},
},
});
});

test("traceable function returning object with async iterable and processOutputs", async () => {
const { client, callSpy } = mockClient();

const processOutputs = jest.fn((outputs: Readonly<KVMap>) => {
return { outputs: outputs.outputs.map((output: number) => output * 2) };
});

const func = traceable(
async function func() {
return {
data: "some data",
stream: (async function* () {
for (let i = 1; i <= 3; i++) {
yield i;
}
})(),
};
},
{
client,
tracingEnabled: true,
processOutputs,
__finalTracedIteratorKey: "stream",
}
);

const result = await func();
expect(result.data).toBe("some data");

const results: number[] = [];
for await (const value of result.stream) {
results.push(value);
}

expect(results).toEqual([1, 2, 3]);
expect(processOutputs).toHaveBeenCalledWith({ outputs: [1, 2, 3] });

expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["func:0"],
edges: [],
data: {
"func:0": {
outputs: { outputs: [2, 4, 6] },
},
},
});
});

test("traceable generator function with processOutputs", async () => {
const { client, callSpy } = mockClient();

const processOutputs = jest.fn((outputs: Readonly<KVMap>) => {
return { outputs: outputs.outputs.map((output: number) => output * 2) };
});

function* func() {
for (let i = 1; i <= 3; i++) {
yield i;
}
}

const tracedFunc = traceable(func, {
client,
tracingEnabled: true,
processOutputs,
});

const results: number[] = [];
for (const value of await tracedFunc()) {
results.push(value);
}

expect(results).toEqual([1, 2, 3]);
expect(processOutputs).toHaveBeenCalledWith({ outputs: [1, 2, 3] });

expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["func:0"],
edges: [],
data: {
"func:0": {
outputs: { outputs: [2, 4, 6] },
},
},
});
});

test("traceable with complex outputs", async () => {
const { client, callSpy } = mockClient();

const processOutputs = jest.fn((outputs: Readonly<KVMap>) => {
return { data: "****", output: outputs.output, nested: outputs.nested };
});

const func = traceable(
async function func(input: string) {
return {
data: "some sensitive data",
output: `Original Output for ${input}`,
nested: {
key: "value",
nestedOutput: `Nested Output for ${input}`,
},
};
},
{
client,
tracingEnabled: true,
processOutputs,
}
);

const result = await func("test");

expect(result).toEqual({
data: "some sensitive data",
output: "Original Output for test",
nested: {
key: "value",
nestedOutput: "Nested Output for test",
},
});

expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["func:0"],
edges: [],
data: {
"func:0": {
inputs: { input: "test" },
outputs: {
data: "****",
output: "Original Output for test",
nested: {
key: "value",
nestedOutput: "Nested Output for test",
},
},
},
},
});
});
Loading
Loading