Skip to content

Commit

Permalink
google-common[minor]: Fix streaming tool calls
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul committed Jul 25, 2024
1 parent 88e2bca commit 4c43a75
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 18 deletions.
56 changes: 50 additions & 6 deletions libs/langchain-google-common/src/utils/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { v4 as uuidv4 } from "uuid";
import {
AIMessage,
AIMessageChunk,
AIMessageFields,
AIMessageChunkFields,
BaseMessage,
BaseMessageChunk,
BaseMessageFields,
Expand Down Expand Up @@ -566,7 +566,7 @@ export function chunkToString(chunk: BaseMessageChunk): string {
}

export function partToMessageChunk(part: GeminiPart): BaseMessageChunk {
const fields = partsToBaseMessageFields([part]);
const fields = partsToBaseMessageChunkFields([part]);
if (typeof fields.content === "string") {
return new AIMessageChunk(fields);
} else if (fields.content.every((item) => item.type === "text")) {
Expand Down Expand Up @@ -636,12 +636,13 @@ export function responseToBaseMessageFields(
response: GoogleLLMResponse
): BaseMessageFields {
const parts = responseToParts(response);
return partsToBaseMessageFields(parts);
return partsToBaseMessageChunkFields(parts);
}

export function partsToBaseMessageFields(parts: GeminiPart[]): AIMessageFields {
const fields: AIMessageFields = {
export function partsToBaseMessageChunkFields(parts: GeminiPart[]): AIMessageChunkFields {
const fields: AIMessageChunkFields = {
content: partsToMessageContent(parts),
tool_call_chunks: [],
tool_calls: [],
invalid_tool_calls: [],
};
Expand All @@ -650,6 +651,13 @@ export function partsToBaseMessageFields(parts: GeminiPart[]): AIMessageFields {
if (rawTools.length > 0) {
const tools = toolsRawToTools(rawTools);
for (const tool of tools) {
fields.tool_call_chunks?.push({
name: tool.function.name,
args: tool.function.arguments,
id: tool.id,
type: "tool_call_chunk",
});

try {
fields.tool_calls?.push({
name: tool.function.name,
Expand All @@ -661,7 +669,7 @@ export function partsToBaseMessageFields(parts: GeminiPart[]): AIMessageFields {
} catch (e: any) {
fields.invalid_tool_calls?.push({
name: tool.function.name,
args: JSON.parse(tool.function.arguments),
args: tool.function.arguments,
id: tool.id,
error: e.message,
type: "invalid_tool_call",
Expand All @@ -675,6 +683,42 @@ export function partsToBaseMessageFields(parts: GeminiPart[]): AIMessageFields {
return fields;
}

// export function partsToBaseMessageFields(parts: GeminiPart[]): AIMessageFields {
// const fields: AIMessageFields = {
// content: partsToMessageContent(parts),
// tool_calls: [],
// invalid_tool_calls: [],
// };

// const rawTools = partsToToolsRaw(parts);
// if (rawTools.length > 0) {
// const tools = toolsRawToTools(rawTools);
// for (const tool of tools) {
// try {
// fields.tool_calls?.push({
// name: tool.function.name,
// args: JSON.parse(tool.function.arguments),
// id: tool.id,
// type: "tool_call",
// });
// // eslint-disable-next-line @typescript-eslint/no-explicit-any
// } catch (e: any) {
// fields.invalid_tool_calls?.push({
// name: tool.function.name,
// args: JSON.parse(tool.function.arguments),
// id: tool.id,
// error: e.message,
// type: "invalid_tool_call",
// });
// }
// }
// fields.additional_kwargs = {
// tool_calls: tools,
// };
// }
// return fields;
// }

export function responseToBaseMessage(
response: GoogleLLMResponse
): BaseMessage {
Expand Down
3 changes: 2 additions & 1 deletion libs/langchain-google-vertexai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@
"release-it": "^15.10.1",
"rollup": "^4.5.2",
"ts-jest": "^29.1.0",
"typescript": "<5.2.0"
"typescript": "<5.2.0",
"zod": "^3.22.4"
},
"publishConfig": {
"access": "public"
Expand Down
57 changes: 46 additions & 11 deletions libs/langchain-google-vertexai/src/tests/chat_models.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ import {
} from "@langchain/core/messages";
import { ChatVertexAI } from "../chat_models.js";
import { GeminiTool } from "../types.js";
import { tool } from "@langchain/core/tools";
import { concat } from "@langchain/core/utils/stream";
import { z } from "zod";

describe("GAuth Chat", () => {
test("invoke", async () => {
Expand Down Expand Up @@ -41,7 +44,7 @@ describe("GAuth Chat", () => {
expect(textContent.text).toEqual("2");
*/
} catch (e) {
console.error(e);
// console.error(e);
throw e;
}
});
Expand Down Expand Up @@ -81,7 +84,7 @@ describe("GAuth Chat", () => {
expect(["H", "T"]).toContainEqual(textContent.text);
*/
} catch (e) {
console.error(e);
// console.error(e);
throw e;
}
});
Expand All @@ -108,12 +111,11 @@ describe("GAuth Chat", () => {
const lastChunk = resArray[resArray.length - 1];
expect(lastChunk).toBeDefined();
expect(lastChunk._getType()).toEqual("ai");
const aiChunk = lastChunk as AIMessageChunk;
console.log(aiChunk);

console.log(JSON.stringify(resArray, null, 2));
// const aiChunk = lastChunk as AIMessageChunk;
// console.log(aiChunk);
// console.log(JSON.stringify(resArray, null, 2));
} catch (e) {
console.error(e);
// console.error(e);
throw e;
}
});
Expand Down Expand Up @@ -209,7 +211,7 @@ describe("GAuth Chat", () => {
for await (const chunk of res) {
resArray.push(chunk);
}
console.log(JSON.stringify(resArray, null, 2));
// console.log(JSON.stringify(resArray, null, 2));
});

test("withStructuredOutput", async () => {
Expand Down Expand Up @@ -249,7 +251,7 @@ test("Stream token count usage_metadata", async () => {
res = res.concat(chunk);
}
}
console.log(res);
// console.log(res);
expect(res?.usage_metadata).toBeDefined();
if (!res?.usage_metadata) {
return;
Expand All @@ -276,7 +278,7 @@ test("streamUsage excludes token usage", async () => {
res = res.concat(chunk);
}
}
console.log(res);
// console.log(res);
expect(res?.usage_metadata).not.toBeDefined();
});

Expand All @@ -286,7 +288,7 @@ test("Invoke token count usage_metadata", async () => {
maxOutputTokens: 10,
});
const res = await model.invoke("Why is the sky blue? Be concise.");
console.log(res);
// console.log(res);
expect(res?.usage_metadata).toBeDefined();
if (!res?.usage_metadata) {
return;
Expand Down Expand Up @@ -322,3 +324,36 @@ test("Streaming true constructor param will stream", async () => {

expect(totalTokenCount).toBeGreaterThan(1);
});

test("ChatGoogleGenerativeAI can stream tools", async () => {
const model = new ChatVertexAI({});

const weatherTool = tool((_) => {
return "The weather in San Francisco today is 18 degrees and sunny."
}, {
name: "current_weather_tool",
description: "Get the current weather for a given location.",
schema: z.object({
location: z.string().describe("The location to get the weather for."),
})
})

const modelWithTools = model.bindTools([weatherTool]);
const stream = await modelWithTools.stream("Whats the weather like today in San Francisco?");
let finalChunk: AIMessageChunk | undefined;
for await (const chunk of stream) {
finalChunk = !finalChunk ? chunk : concat(finalChunk, chunk);
}

expect(finalChunk).toBeDefined();
if (!finalChunk) return;

const toolCalls = finalChunk.tool_calls;
expect(toolCalls).toBeDefined();
if (!toolCalls) {
throw new Error("tool_calls not in response");
}
expect(toolCalls.length).toBe(1);
expect(toolCalls[0].name).toBe("current_weather_tool");
expect(toolCalls[0].args).toHaveProperty("location")
});
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class ChatVertexAIStandardIntegrationTests extends ChatModelIntegrationTests<
Cls: ChatVertexAI,
chatModelHasToolCalling: true,
chatModelHasStructuredOutput: true,
invokeResponseType: AIMessageChunk,
constructorArgs: {
model: "gemini-1.5-pro",
},
Expand All @@ -32,6 +33,14 @@ class ChatVertexAIStandardIntegrationTests extends ChatModelIntegrationTests<
"Not implemented."
);
}

async testInvokeMoreComplexTools() {
this.skipTestMessage(
"testInvokeMoreComplexTools",
"ChatVertexAI",
"Google VertexAI does not support tool schemas where the object properties are not defined."
);
}
}

const testClass = new ChatVertexAIStandardIntegrationTests();
Expand Down
1 change: 1 addition & 0 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -11695,6 +11695,7 @@ __metadata:
rollup: ^4.5.2
ts-jest: ^29.1.0
typescript: <5.2.0
zod: ^3.22.4
languageName: unknown
linkType: soft

Expand Down

0 comments on commit 4c43a75

Please sign in to comment.