Skip to content

Commit

Permalink
feat: add processInputs / processOutputs to traceable
Browse files Browse the repository at this point in the history
  • Loading branch information
agola11 committed Oct 13, 2024
1 parent e1521a5 commit 3495382
Show file tree
Hide file tree
Showing 2 changed files with 405 additions and 24 deletions.
334 changes: 334 additions & 0 deletions js/src/tests/traceable.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { getAssumedTreeFromCalls } from "./utils/tree.js";
import { mockClient } from "./utils/mock_client.js";
import { Client, overrideFetchImplementation } from "../index.js";
import { AsyncLocalStorageProviderSingleton } from "../singletons/traceable.js";
import { KVMap } from "../schemas.js";

test("basic traceable implementation", async () => {
const { client, callSpy } = mockClient();
Expand Down Expand Up @@ -1129,3 +1130,336 @@ test("traceable continues execution when client throws error", async () => {
expect(errorClient.createRun).toHaveBeenCalled();
expect(errorClient.updateRun).toHaveBeenCalled();
});

test("traceable with processInputs", async () => {
const { client, callSpy } = mockClient();

const processInputs = jest.fn((inputs: Readonly<KVMap>) => {
return { ...inputs, password: "****" };
});

const func = traceable(
async function func(input: { username: string; password: string }) {
// The function should receive the original inputs
expect(input.password).toBe("secret");
return `Welcome, ${input.username}`;
},
{
client,
tracingEnabled: true,
processInputs,
}
);

await func({ username: "user1", password: "secret" });

expect(processInputs).toHaveBeenCalledWith({
username: "user1",
password: "secret",
});
// Verify that the logged inputs have the password masked
expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["func:0"],
edges: [],
data: {
"func:0": {
inputs: {
username: "user1",
password: "****",
},
outputs: { outputs: "Welcome, user1" },
},
},
});
});

test("traceable with processOutputs", async () => {
const { client, callSpy } = mockClient();

const processOutputs = jest.fn((_outputs: Readonly<KVMap>) => {
return { outputs: "Modified Output" };
});

const func = traceable(
async function func(input: string) {
return `Original Output for ${input}`;
},
{
client,
tracingEnabled: true,
processOutputs,
}
);

const result = await func("test");

expect(processOutputs).toHaveBeenCalledWith({
outputs: "Original Output for test",
});
expect(result).toBe("Original Output for test");
// Verify that the tracing data shows the modified output
expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["func:0"],
edges: [],
data: {
"func:0": {
inputs: { input: "test" },
outputs: { outputs: "Modified Output" },
},
},
});
});

test("traceable with processInputs throwing error does not affect invocation", async () => {
const { client, callSpy } = mockClient();

const processInputs = jest.fn((_inputs: Readonly<KVMap>) => {
throw new Error("processInputs error");
});

const func = traceable(
async function func(input: { username: string }) {
// This should not be called
return `Hello, ${input.username}`;
},
{
client,
tracingEnabled: true,
processInputs,
}
);

const result = await func({ username: "user1" });

expect(processInputs).toHaveBeenCalledWith({ username: "user1" });
expect(result).toBe("Hello, user1");

expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["func:0"],
edges: [],
data: {
"func:0": {
inputs: { username: "user1" },
outputs: { outputs: "Hello, user1" },
},
},
});
});

test("traceable with processOutputs throwing error does not affect invocation", async () => {
const { client, callSpy } = mockClient();

const processOutputs = jest.fn((_outputs: Readonly<KVMap>) => {
throw new Error("processOutputs error");
});

const func = traceable(
async function func(input: string) {
return `Original Output for ${input}`;
},
{
client,
tracingEnabled: true,
processOutputs,
}
);

const result = await func("test");

expect(processOutputs).toHaveBeenCalledWith({
outputs: "Original Output for test",
});
expect(result).toBe("Original Output for test");

expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["func:0"],
edges: [],
data: {
"func:0": {
inputs: { input: "test" },
outputs: { outputs: "Original Output for test" },
},
},
});
});

test("traceable async generator with processOutputs", async () => {
const { client, callSpy } = mockClient();

const processOutputs = jest.fn((outputs: Readonly<KVMap>) => {
return { outputs: outputs.outputs.map((output: number) => output * 2) };
});

const func = traceable(
async function* func() {
for (let i = 1; i <= 3; i++) {
yield i;
}
},
{
client,
tracingEnabled: true,
processOutputs,
}
);

const results: number[] = [];
for await (const value of func()) {
results.push(value);
}

expect(results).toEqual([1, 2, 3]); // Original values
expect(processOutputs).toHaveBeenCalledWith({ outputs: [1, 2, 3] });

// Tracing data should reflect the processed outputs
expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["func:0"],
edges: [],
data: {
"func:0": {
outputs: { outputs: [2, 4, 6] }, // Processed outputs
},
},
});
});

test("traceable function returning object with async iterable and processOutputs", async () => {
const { client, callSpy } = mockClient();

const processOutputs = jest.fn((outputs: Readonly<KVMap>) => {
return { outputs: outputs.outputs.map((output: number) => output * 2) };
});

const func = traceable(
async function func() {
return {
data: "some data",
stream: (async function* () {
for (let i = 1; i <= 3; i++) {
yield i;
}
})(),
};
},
{
client,
tracingEnabled: true,
processOutputs,
__finalTracedIteratorKey: "stream",
}
);

const result = await func();
expect(result.data).toBe("some data");

const results: number[] = [];
for await (const value of result.stream) {
results.push(value);
}

expect(results).toEqual([1, 2, 3]);
expect(processOutputs).toHaveBeenCalledWith({ outputs: [1, 2, 3] });

expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["func:0"],
edges: [],
data: {
"func:0": {
outputs: { outputs: [2, 4, 6] },
},
},
});
});

test("traceable generator function with processOutputs", async () => {
const { client, callSpy } = mockClient();

const processOutputs = jest.fn((outputs: Readonly<KVMap>) => {
return { outputs: outputs.outputs.map((output: number) => output * 2) };
});

function* func() {
for (let i = 1; i <= 3; i++) {
yield i;
}
}

const tracedFunc = traceable(func, {
client,
tracingEnabled: true,
processOutputs,
});

const results: number[] = [];
for (const value of await tracedFunc()) {
results.push(value);
}

expect(results).toEqual([1, 2, 3]);
expect(processOutputs).toHaveBeenCalledWith({ outputs: [1, 2, 3] });

expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["func:0"],
edges: [],
data: {
"func:0": {
outputs: { outputs: [2, 4, 6] },
},
},
});
});

test("traceable with complex outputs", async () => {
const { client, callSpy } = mockClient();

const processOutputs = jest.fn((outputs: Readonly<KVMap>) => {
return { data: "****", output: outputs.output, nested: outputs.nested };
});

const func = traceable(
async function func(input: string) {
return {
data: "some sensitive data",
output: `Original Output for ${input}`,
nested: {
key: "value",
nestedOutput: `Nested Output for ${input}`,
},
};
},
{
client,
tracingEnabled: true,
processOutputs,
}
);

const result = await func("test");

expect(result).toEqual({
data: "some sensitive data",
output: "Original Output for test",
nested: {
key: "value",
nestedOutput: "Nested Output for test",
},
});

expect(getAssumedTreeFromCalls(callSpy.mock.calls)).toMatchObject({
nodes: ["func:0"],
edges: [],
data: {
"func:0": {
inputs: { input: "test" },
outputs: {
data: "****",
output: "Original Output for test",
nested: {
key: "value",
nestedOutput: "Nested Output for test",
},
},
},
},
});
});
Loading

0 comments on commit 3495382

Please sign in to comment.