Skip to content

Commit

Permalink
Fix runnable map, start adding tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 committed Aug 1, 2024
1 parent eda0df5 commit 139b0a4
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 13 deletions.
11 changes: 1 addition & 10 deletions langchain-core/src/runnables/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2082,16 +2082,7 @@ export class RunnableMap<
);
}
);
if (options?.signal) {
promises.push(
new Promise<never>((_, reject) => {
options.signal?.addEventListener("abort", () =>
reject(new Error("Aborted"))
);
})
);
}
await Promise.all(promises);
await raceWithSignal(Promise.all(promises), options?.signal);
} catch (e) {
await runManager?.handleChainError(e);
throw e;
Expand Down
39 changes: 39 additions & 0 deletions langchain-core/src/runnables/tests/runnable_map.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,45 @@ test("Test map inference in a sequence", async () => {
);
});

test("Test invoke with signal", async () => {
const map = RunnableMap.from({
question: new RunnablePassthrough(),
context: async () => {
await new Promise((resolve) => setTimeout(resolve, 500));
return "SOME STUFF";
},
});
const controller = new AbortController();
await expect(async () => {
await Promise.all([
map.invoke("testing", {
signal: controller.signal,
}),
new Promise<void>((resolve) => {
controller.abort();
resolve();
}),
]);
}).rejects.toThrowError();
});

test("Test stream with signal", async () => {
const map = RunnableMap.from({
question: new RunnablePassthrough(),
context: async () => {
await new Promise((resolve) => setTimeout(resolve, 500));
return "SOME STUFF";
},
});
const controller = new AbortController();
await expect(async () => {
const stream = await map.stream("TESTING", { signal: controller.signal });
for await (const _ of stream) {
controller.abort();
}
}).rejects.toThrowError();
});

test("Should not allow mismatched inputs", async () => {
const prompt = ChatPromptTemplate.fromTemplate(
"context: {context}, question: {question}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2120,3 +2120,24 @@ test("Runnable streamEvents method with text/event-stream encoding", async () =>

expect(decoder.decode(events[3])).toEqual("event: end\n\n");
});

test("Runnable streamEvents method should respect passed signal", async () => {
const r = RunnableLambda.from(reverse);

const chain = r
.withConfig({ runName: "1" })
.pipe(r.withConfig({ runName: "2" }))
.pipe(r.withConfig({ runName: "3" }));

const controller = new AbortController();
const eventStream = await chain.streamEvents("hello", {
version: "v2",
signal: controller.signal,
});
await expect(async () => {
for await (const _ of eventStream) {
// Abort after the first chunk
controller.abort();
}
}).rejects.toThrowError();
});
7 changes: 4 additions & 3 deletions langchain-core/src/utils/signal.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ export async function raceWithSignal<T>(
if (signal === undefined) {
return promise;
}
if (signal.aborted) {
throw new Error("AbortError");
}
return Promise.race([
promise,
new Promise<never>((_, reject) => {
// Must be inside of the promise to avoid a race condition
if (signal.aborted) {
return reject(new Error("Aborted"));

Check failure on line 13 in langchain-core/src/utils/signal.ts

View workflow job for this annotation

GitHub Actions / Check linting

Return values from promise executor functions cannot be read
}
signal.addEventListener("abort", () => reject(new Error("Aborted")));
}),
]);
Expand Down

0 comments on commit 139b0a4

Please sign in to comment.