Skip to content

Commit

Permalink
core[minor]: Add new FakeStreamingChatModel & update test (langchain-…
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul authored Apr 5, 2024
1 parent ed6b3a9 commit fe75251
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 1 deletion.
3 changes: 2 additions & 1 deletion langchain-core/src/runnables/tests/runnable.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import {
FakeRunnable,
FakeListChatModel,
SingleRunExtractor,
FakeStreamingChatModel,
} from "../../utils/testing/index.js";
import { RunnableSequence, RunnableLambda } from "../base.js";
import { RouterRunnable } from "../router.js";
Expand Down Expand Up @@ -448,7 +449,7 @@ describe("runId config", () => {

test("stream", async () => {
const tracer = new SingleRunExtractor();
const llm = new FakeChatModel({});
const llm = new FakeStreamingChatModel({});
const testId = uuidv4();
const stream = await llm.stream("gg", {
callbacks: [tracer],
Expand Down
77 changes: 77 additions & 0 deletions langchain-core/src/utils/testing/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,83 @@ export class FakeChatModel extends BaseChatModel {
}
}

export class FakeStreamingChatModel extends BaseChatModel {
  // Intended per-chunk delay in ms. NOTE(review): declared and defaulted but
  // never awaited in _streamResponseChunks — confirm whether streaming should
  // actually pause between chunks.
  sleep?: number = 50;

  // Canned responses. The first response's content is echoed by both
  // _generate and _streamResponseChunks; when absent, the first input
  // message's content is echoed instead.
  responses?: BaseMessage[];

  // When set, _generate and _streamResponseChunks throw an Error with this
  // message instead of producing output (for testing error paths).
  thrownErrorString?: string;

  constructor(
    fields: {
      sleep?: number;
      responses?: BaseMessage[];
      thrownErrorString?: string;
    } & BaseLLMParams
  ) {
    super(fields);
    this.sleep = fields.sleep ?? this.sleep;
    this.responses = fields.responses;
    this.thrownErrorString = fields.thrownErrorString;
  }

  _llmType() {
    return "fake";
  }

  /**
   * Returns a single AIMessage echoing the first canned response's content,
   * or the first input message's content when no responses were provided.
   *
   * @throws Error with `thrownErrorString` when that field is set.
   */
  async _generate(
    messages: BaseMessage[],
    _options: this["ParsedCallOptions"],
    _runManager?: CallbackManagerForLLMRun
  ): Promise<ChatResult> {
    if (this.thrownErrorString) {
      throw new Error(this.thrownErrorString);
    }

    // `?.[0]?.` (not `?.[0].`) so an explicitly-passed empty `responses`
    // array falls back to the input messages instead of throwing.
    const content = this.responses?.[0]?.content ?? messages[0].content;
    const generation: ChatResult = {
      generations: [
        {
          // NOTE(review): `text` is always empty even for string content,
          // matching the original behavior — confirm whether it should
          // mirror `content` as the streaming path does.
          text: "",
          message: new AIMessage({
            content,
          }),
        },
      ],
    };

    return generation;
  }

  /**
   * Streams one chunk per canned response (or per input message when no
   * responses were provided). Every chunk carries the same echoed content;
   * `text` mirrors the content only when it is a plain string, since complex
   * (multi-part) content has no single text form.
   *
   * @throws Error with `thrownErrorString` when that field is set.
   */
  async *_streamResponseChunks(
    messages: BaseMessage[],
    _options: this["ParsedCallOptions"],
    _runManager?: CallbackManagerForLLMRun
  ): AsyncGenerator<ChatGenerationChunk> {
    if (this.thrownErrorString) {
      throw new Error(this.thrownErrorString);
    }
    const content = this.responses?.[0]?.content ?? messages[0].content;
    // Collapses the original duplicated string/non-string branches, which
    // differed only in the `text` field.
    const text = typeof content === "string" ? content : "";
    for (const _ of this.responses ?? messages) {
      yield new ChatGenerationChunk({
        text,
        message: new AIMessageChunk({
          content,
        }),
      });
    }
  }
}

export class FakeRetriever extends BaseRetriever {
lc_namespace = ["test", "fake"];

Expand Down

0 comments on commit fe75251

Please sign in to comment.