
Commit

use sdk for tools
bracesproul committed Jul 19, 2024
1 parent 1579171 commit 317080d
Showing 5 changed files with 36 additions and 87 deletions.
2 changes: 1 addition & 1 deletion libs/langchain-ollama/package.json
@@ -36,7 +36,7 @@
   "license": "MIT",
   "dependencies": {
     "@langchain/core": ">0.2.17 <0.3.0",
-    "ollama": "^0.5.2",
+    "ollama": "^0.5.6",
     "uuid": "^10.0.0"
   },
   "devDependencies": {
52 changes: 11 additions & 41 deletions libs/langchain-ollama/src/chat_models.ts
@@ -21,6 +21,7 @@ import type {
   ChatRequest as OllamaChatRequest,
   ChatResponse as OllamaChatResponse,
   Message as OllamaMessage,
+  Tool as OllamaTool,
 } from "ollama";
 import { StructuredToolInterface } from "@langchain/core/tools";
 import { Runnable, RunnableToolLike } from "@langchain/core/runnables";
@@ -316,7 +317,7 @@ export class ChatOllama

   invocationParams(
     options?: this["ParsedCallOptions"]
-  ): Omit<OllamaChatRequest, "messages"> & { tools?: ToolDefinition[] } {
+  ): Omit<OllamaChatRequest, "messages"> {
     if (options?.tool_choice) {
       throw new Error("Tool choice is not supported for ChatOllama.");
     }
@@ -357,7 +358,9 @@ export class ChatOllama
         penalize_newline: this.penalizeNewline,
         stop: options?.stop,
       },
-      tools: options?.tools?.map(convertToOpenAITool),
+      tools: options?.tools?.length
+        ? (options.tools.map(convertToOpenAITool) as OllamaTool[])
+        : undefined,
     };
   }

@@ -449,14 +452,11 @@ export class ChatOllama
     };

     if (params.tools && params.tools.length > 0) {
-      const toolResult = await this._generateRawApi(
-        {
-          ...params,
-          tools: params.tools,
-          messages: ollamaMessages,
-        },
-        options
-      );
+      const toolResult = await this.client.chat({
+        ...params,
+        messages: ollamaMessages,
+        stream: false, // Ollama currently does not support streaming with tools
+      });

       const { message: responseMessage, ...rest } = toolResult;
       usageMetadata.input_tokens += rest.prompt_eval_count ?? 0;
@@ -495,9 +495,7 @@

       yield new ChatGenerationChunk({
         text: responseMessage.content ?? "",
-        message: new AIMessageChunk({
-          content: responseMessage.content ?? "",
-        }),
+        message: convertOllamaMessagesToLangChain(responseMessage),
       });
       await runManager?.handleLLMNewToken(responseMessage.content ?? "");
     }
@@ -512,32 +510,4 @@
         }),
       });
     }
-
-  async _generateRawApi(
-    request: OllamaChatRequest & { tools: ToolDefinition[] },
-    options?: this["ParsedCallOptions"]
-  ): Promise<OllamaChatResponse> {
-    return this.caller.callWithOptions(
-      { signal: options?.signal },
-      async () => {
-        const streamRes = await fetch(`${this.baseUrl}/api/chat`, {
-          method: "POST",
-          headers: {
-            "Content-Type": "application/json",
-          },
-          body: JSON.stringify({
-            ...request,
-            stream: false,
-          }),
-        });
-
-        // Ensure the response is ok
-        if (!streamRes.ok) {
-          throw new Error(`HTTP error! status: ${streamRes.status}`);
-        }
-        const streamResJson = await streamRes.json();
-        return streamResJson;
-      }
-    );
-  }
 }
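
Note (not part of the commit): a minimal sketch of how the SDK-backed tool-calling path above would be exercised from user code. The model name, tool schema, and prompt are illustrative assumptions; any StructuredTool bound via bindTools flows through invocationParams (which converts it with convertToOpenAITool) and is then forwarded to the ollama SDK's client.chat() call with stream: false.

import { ChatOllama } from "@langchain/ollama";
import { tool } from "@langchain/core/tools";
import { z } from "zod";

// Hypothetical tool for illustration only.
const getWeather = tool(async ({ location }) => `It is sunny in ${location}.`, {
  name: "get_weather",
  description: "Get the current weather for a location.",
  schema: z.object({ location: z.string() }),
});

// Assumes a locally pulled, tool-capable model.
const model = new ChatOllama({ model: "llama3.1" });

// The bound tools reach the ollama SDK via this.client.chat({ ..., stream: false }).
const result = await model
  .bindTools([getWeather])
  .invoke("What is the weather in San Francisco?");
console.log(result.tool_calls);
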
2 changes: 0 additions & 2 deletions libs/langchain-ollama/src/tests/chat_models-tools.int.test.ts
@@ -84,7 +84,6 @@ test("Ollama can call withStructuredOutput", async () => {
   });

   const result = await model.invoke(messageHistory);
-  console.log("WSO", result);
   expect(result).toBeDefined();
   expect(result.location).toBeDefined();
   expect(result.location).not.toBe("");
@@ -100,7 +99,6 @@ test("Ollama can call withStructuredOutput includeRaw", async () => {
   });

   const result = await model.invoke(messageHistory);
-  console.log("WSO", result);
   expect(result).toBeDefined();
   expect(result.parsed.location).toBeDefined();
   expect(result.parsed.location).not.toBe("");
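
Note (not part of the commit): roughly the shape these integration tests exercise. The schema, model name, and prompt below are illustrative assumptions.

import { ChatOllama } from "@langchain/ollama";
import { z } from "zod";

const weatherSchema = z.object({
  location: z.string().describe("The location that was asked about"),
});

const model = new ChatOllama({ model: "llama3.1" });

// Without includeRaw the parsed object is returned directly (result.location);
// with includeRaw: true the result is { raw, parsed } (result.parsed.location).
const structured = model.withStructuredOutput(weatherSchema, {
  name: "get_weather",
  includeRaw: true,
});
const result = await structured.invoke("What is the weather in Paris?");
console.log(result.parsed.location);
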
55 changes: 18 additions & 37 deletions libs/langchain-ollama/src/utils.ts
@@ -8,31 +8,18 @@ import {
   ToolMessage,
   UsageMetadata,
 } from "@langchain/core/messages";
-import type { Message as OllamaMessage } from "ollama";
+import type {
+  Message as OllamaMessage,
+  ToolCall as OllamaToolCall,
+} from "ollama";
 import { v4 as uuidv4 } from "uuid";

-export interface OllamaToolCall {
-  type?: "function";
-  id?: string;
-  function: {
-    name: string;
-    // eslint-disable-next-line @typescript-eslint/no-explicit-any
-    arguments: Record<string, any>;
-  };
-}
-
-export interface OllamaMessageWithTools extends Omit<OllamaMessage, "content"> {
-  tool_calls?: OllamaToolCall[];
-  content?: string;
-  tool_call_id?: string;
-}
-
 export function convertOllamaMessagesToLangChain(
-  messages: OllamaMessageWithTools,
-  extra: {
+  messages: OllamaMessage,
+  extra?: {
     // eslint-disable-next-line @typescript-eslint/no-explicit-any
-    responseMetadata: Record<string, any>;
-    usageMetadata: UsageMetadata;
+    responseMetadata?: Record<string, any>;
+    usageMetadata?: UsageMetadata;
   }
 ): AIMessageChunk {
   return new AIMessageChunk({
@@ -42,10 +29,10 @@ export function convertOllamaMessagesToLangChain(
       args: JSON.stringify(tc.function.arguments),
       type: "tool_call_chunk",
       index: 0,
-      id: tc.id ?? uuidv4(),
+      id: uuidv4(),
     })),
-    response_metadata: extra.responseMetadata,
-    usage_metadata: extra.usageMetadata,
+    response_metadata: extra?.responseMetadata,
+    usage_metadata: extra?.usageMetadata,
   });
 }

@@ -54,9 +41,7 @@ function extractBase64FromDataUrl(dataUrl: string): string {
   return match ? match[1] : "";
 }

-function convertAMessagesToOllama(
-  messages: AIMessage
-): OllamaMessageWithTools[] {
+function convertAMessagesToOllama(messages: AIMessage): OllamaMessage[] {
   if (typeof messages.content === "string") {
     return [
       {
@@ -73,7 +58,7 @@ function convertAMessagesToOllama(
       role: "assistant",
       content: c.text,
     }));
-  let toolCallMsgs: OllamaMessageWithTools | undefined;
+  let toolCallMsgs: OllamaMessage | undefined;

   if (
     messages.content.find((c) => c.type === "tool_use") &&
@@ -95,6 +80,7 @@
       toolCallMsgs = {
         role: "assistant",
         tool_calls: toolCalls,
+        content: "",
       };
     }
   } else if (
@@ -111,7 +97,7 @@

 function convertHumanGenericMessagesToOllama(
   message: HumanMessage
-): OllamaMessageWithTools[] {
+): OllamaMessage[] {
   if (typeof message.content === "string") {
     return [
       {
@@ -145,9 +131,7 @@ function convertHumanGenericMessagesToOllama(
   });
 }

-function convertSystemMessageToOllama(
-  message: SystemMessage
-): OllamaMessageWithTools[] {
+function convertSystemMessageToOllama(message: SystemMessage): OllamaMessage[] {
   if (typeof message.content === "string") {
     return [
       {
@@ -173,15 +157,12 @@
   }
 }

-function convertToolMessageToOllama(
-  message: ToolMessage
-): OllamaMessageWithTools[] {
+function convertToolMessageToOllama(message: ToolMessage): OllamaMessage[] {
   if (typeof message.content !== "string") {
     throw new Error("Non string tool message content is not supported");
   }
   return [
     {
-      tool_call_id: message.tool_call_id,
       role: "tool",
       content: message.content,
     },
@@ -190,7 +171,7 @@

 export function convertToOllamaMessages(
   messages: BaseMessage[]
-): OllamaMessageWithTools[] {
+): OllamaMessage[] {
   return messages.flatMap((msg) => {
     if (["human", "generic"].includes(msg._getType())) {
       return convertHumanGenericMessagesToOllama(msg);
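
Note (not part of the commit): with the local interfaces removed, these converters consume and produce the ollama SDK's own Message/ToolCall types. A rough sketch of the now-optional `extra` argument, using a made-up assistant message of the shape the SDK returns for a tool call:

// Inside libs/langchain-ollama/src (utils.ts is not a public package entrypoint).
import type { Message as OllamaMessage } from "ollama";
import { convertOllamaMessagesToLangChain } from "./utils.js";

const responseMessage: OllamaMessage = {
  role: "assistant",
  content: "",
  tool_calls: [
    { function: { name: "get_weather", arguments: { location: "Paris" } } },
  ],
};

// `extra` can now be omitted, as in the streaming path of chat_models.ts above.
const chunk = convertOllamaMessagesToLangChain(responseMessage);
// chunk.tool_call_chunks?.[0].args === '{"location":"Paris"}'
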
12 changes: 6 additions & 6 deletions yarn.lock
@@ -11501,7 +11501,7 @@ __metadata:
   resolution: "@langchain/ollama@workspace:libs/langchain-ollama"
   dependencies:
     "@jest/globals": ^29.5.0
-    "@langchain/core": ">0.2.14 <0.3.0"
+    "@langchain/core": ">0.2.17 <0.3.0"
     "@langchain/scripts": ~0.0.14
     "@langchain/standard-tests": 0.0.0
     "@swc/core": ^1.3.90
@@ -11519,7 +11519,7 @@
     eslint-plugin-prettier: ^4.2.1
     jest: ^29.5.0
     jest-environment-node: ^29.6.4
-    ollama: ^0.5.2
+    ollama: ^0.5.6
     prettier: ^2.8.3
     release-it: ^15.10.1
     rollup: ^4.5.2
@@ -33065,12 +33065,12 @@
   languageName: node
   linkType: hard

-"ollama@npm:^0.5.2":
-  version: 0.5.2
-  resolution: "ollama@npm:0.5.2"
+"ollama@npm:^0.5.6":
+  version: 0.5.6
+  resolution: "ollama@npm:0.5.6"
   dependencies:
     whatwg-fetch: ^3.6.20
-  checksum: d824825fcf52dba24e4c99eb7ec1f95075dc2e3c89d89991af7d633a67f4484f43386dd83e70b757ab1d695d7a15c27f0ed87c0ac91d47509836fbec138db85a
+  checksum: f7aafe4f0cf5e3fee9f5be7501733d3ab4ea0b02e0aafacdae90cb5a8babfa4bb4543d47fab152b5424084d3331185a09e584a5d3c74e2cefcf017dc5964f520
   languageName: node
   linkType: hard

