Skip to content

Commit

Permalink
core[minor]: Add base implementation of withStructuredOutput
Browse files Browse the repository at this point in the history
  • Loading branch information
bracesproul committed Jun 13, 2024
1 parent 4f3deb2 commit 82cd2bb
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 7 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/standard-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ jobs:
strategy:
matrix:
package: [anthropic, cohere, google-genai, groq, mistralai]
if: contains(needs.get-changed-files.outputs.changed_files, 'langchain-core/') || contains(needs.get-changed-files.outputs.changed_files, 'libs/langchain-${{ matrix.package }}/')
steps:
- uses: actions/checkout@v4
if: contains(needs.get-changed-files.outputs.changed_files, 'langchain-core/') || contains(needs.get-changed-files.outputs.changed_files, 'libs/langchain-${{ matrix.package }}/')
- name: Use Node.js 18.x
uses: actions/setup-node@v3
with:
Expand Down Expand Up @@ -113,4 +113,4 @@ jobs:
env:
BEDROCK_AWS_REGION: "us-east-1"
BEDROCK_AWS_SECRET_ACCESS_KEY: ${{ secrets.BEDROCK_AWS_SECRET_ACCESS_KEY }}
BEDROCK_AWS_ACCESS_KEY_ID: ${{ secrets.BEDROCK_AWS_ACCESS_KEY_ID }}
BEDROCK_AWS_ACCESS_KEY_ID: ${{ secrets.BEDROCK_AWS_ACCESS_KEY_ID }}
4 changes: 4 additions & 0 deletions langchain-core/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,10 @@ utils/types.cjs
utils/types.js
utils/types.d.ts
utils/types.d.cts
utils/is_openai_tool.cjs
utils/is_openai_tool.js
utils/is_openai_tool.d.ts
utils/is_openai_tool.d.cts
vectorstores.cjs
vectorstores.js
vectorstores.d.ts
Expand Down
1 change: 1 addition & 0 deletions langchain-core/langchain.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ export const config = {
"utils/testing": "utils/testing/index",
"utils/tiktoken": "utils/tiktoken",
"utils/types": "utils/types/index",
"utils/is_openai_tool": "utils/is_openai_tool",
vectorstores: "vectorstores",
},
tsConfigPath: resolve("./tsconfig.json"),
Expand Down
13 changes: 13 additions & 0 deletions langchain-core/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,15 @@
"import": "./utils/types.js",
"require": "./utils/types.cjs"
},
"./utils/is_openai_tool": {
"types": {
"import": "./utils/is_openai_tool.d.ts",
"require": "./utils/is_openai_tool.d.cts",
"default": "./utils/is_openai_tool.d.ts"
},
"import": "./utils/is_openai_tool.js",
"require": "./utils/is_openai_tool.cjs"
},
"./vectorstores": {
"types": {
"import": "./vectorstores.d.ts",
Expand Down Expand Up @@ -810,6 +819,10 @@
"utils/types.js",
"utils/types.d.ts",
"utils/types.d.cts",
"utils/is_openai_tool.cjs",
"utils/is_openai_tool.js",
"utils/is_openai_tool.d.ts",
"utils/is_openai_tool.d.cts",
"vectorstores.cjs",
"vectorstores.js",
"vectorstores.d.ts",
Expand Down
157 changes: 152 additions & 5 deletions langchain-core/src/language_models/chat_models.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import { z } from "zod";
import { zodToJsonSchema } from "zod-to-json-schema";
import {
AIMessage,
type BaseMessage,
BaseMessageChunk,
type BaseMessageLike,
HumanMessage,
coerceMessageLikeToMessage,
AIMessageChunk,
} from "../messages/index.js";
import type { BasePromptValueInterface } from "../prompt_values.js";
import {
Expand All @@ -17,6 +20,8 @@ import {
} from "../outputs.js";
import {
BaseLanguageModel,
StructuredOutputMethodOptions,
ToolDefinition,
type BaseLanguageModelCallOptions,
type BaseLanguageModelInput,
type BaseLanguageModelParams,
Expand All @@ -29,10 +34,16 @@ import {
import type { RunnableConfig } from "../runnables/config.js";
import type { BaseCache } from "../caches.js";
import { StructuredToolInterface } from "../tools.js";
import { Runnable } from "../runnables/base.js";
import {
Runnable,
RunnableLambda,
RunnableSequence,
} from "../runnables/base.js";
import { isStreamEventsHandler } from "../tracers/event_stream.js";
import { isLogStreamHandler } from "../tracers/log_stream.js";
import { concat } from "../utils/stream.js";
import { RunnablePassthrough } from "../runnables/passthrough.js";
import { isZodSchema } from "../utils/types/is_zod_schema.js";

/**
* Represents a serialized chat model.
Expand Down Expand Up @@ -143,12 +154,16 @@ export abstract class BaseChatModel<
* Bind tool-like objects to this chat model.
*
* @param tools A list of tool definitions to bind to this chat model.
* Can be a structured tool or an object matching the provider's
* specific tool schema.
* Can be a structured tool, an OpenAI formatted tool, or an object
* matching the provider's specific tool schema.
* @param kwargs Any additional parameters to bind.
*/
bindTools?(
tools: (StructuredToolInterface | Record<string, unknown>)[],
tools: (
| StructuredToolInterface
| Record<string, unknown>
| ToolDefinition
)[],
kwargs?: Partial<CallOptions>
): Runnable<BaseLanguageModelInput, OutputMessageType, CallOptions>;

Expand Down Expand Up @@ -714,6 +729,138 @@ export abstract class BaseChatModel<
}
return result.content;
}

withStructuredOutput<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends Record<string, any> = Record<string, any>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
>(
outputSchema:
| z.ZodType<RunOutput>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
| Record<string, any>,
config?: StructuredOutputMethodOptions<false>
): Runnable<BaseLanguageModelInput, RunOutput>;

withStructuredOutput<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends Record<string, any> = Record<string, any>
>(
outputSchema:
| z.ZodType<RunOutput>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
| Record<string, any>,
config?: StructuredOutputMethodOptions<true>
): Runnable<BaseLanguageModelInput, { raw: BaseMessage; parsed: RunOutput }>;

withStructuredOutput<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
RunOutput extends Record<string, any> = Record<string, any>
>(
outputSchema:
| z.ZodType<RunOutput>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
| Record<string, any>,
config?: StructuredOutputMethodOptions<boolean>
):
| Runnable<BaseLanguageModelInput, RunOutput>
| Runnable<
BaseLanguageModelInput,
{
raw: BaseMessage;
parsed: RunOutput;
}
> {
if (!("bindTools" in this) || typeof this.bindTools !== "function") {
throw new Error(
`Chat model must implement ".bindTools()" to use withStructuredOutput.`
);
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const schema: z.ZodType<RunOutput> | Record<string, any> = outputSchema;
const name = config?.name;
const description = schema.description ?? "A function available to call.";
const method = config?.method;
const includeRaw = config?.includeRaw;
if (method === "jsonMode") {
throw new Error(
`Base withStructuredOutput implementation only supports "functionCalling" as a method.`
);
}

let functionName = name ?? "extract";
let tools: ToolDefinition[];
if (isZodSchema(schema)) {
tools = [
{
type: "function",
function: {
name: functionName,
description,
parameters: zodToJsonSchema(schema),
},
},
];
} else {
if ("name" in schema) {
functionName = schema.name;
}
tools = [
{
type: "function",
function: {
name: functionName,
description,
parameters: schema,
},
},
];
}

const llm = this.bindTools(tools);
const outputParser = RunnableLambda.from<AIMessageChunk, RunOutput>(
(input: AIMessageChunk): RunOutput => {
if (!input.tool_calls || input.tool_calls.length === 0) {
throw new Error("No tool calls found in the response.");
}
const toolCall = input.tool_calls.find(
(tc) => tc.name === functionName
);
if (!toolCall) {
throw new Error(`No tool call found with name ${functionName}.`);
}
return toolCall.args as RunOutput;
}
);

if (!includeRaw) {
return llm.pipe(outputParser).withConfig({
runName: "StructuredOutput",
}) as Runnable<BaseLanguageModelInput, RunOutput>;
}

const parserAssign = RunnablePassthrough.assign({
// eslint-disable-next-line @typescript-eslint/no-explicit-any
parsed: (input: any, config) => outputParser.invoke(input.raw, config),
});
const parserNone = RunnablePassthrough.assign({
parsed: () => null,
});
const parsedWithFallback = parserAssign.withFallbacks({
fallbacks: [parserNone],
});
return RunnableSequence.from<
BaseLanguageModelInput,
{ raw: BaseMessage; parsed: RunOutput }
>([
{
raw: llm,
},
parsedWithFallback,
]).withConfig({
runName: "StructuredOutputRunnable",
});
}
}

/**
Expand Down Expand Up @@ -750,4 +897,4 @@ export abstract class SimpleChatModel<
],
};
}
}
}
17 changes: 17 additions & 0 deletions langchain-core/src/utils/is_openai_tool.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import { ToolDefinition } from "../language_models/base.js";

export function isOpenAITool(tool: unknown): tool is ToolDefinition {
if (typeof tool !== "object" || !tool) return false;
if (
"type" in tool &&
tool.type === "function" &&
"function" in tool &&
typeof tool.function === "object" &&
tool.function &&
"name" in tool.function &&
"parameters" in tool.function
) {
return true;
}
return false;
}

0 comments on commit 82cd2bb

Please sign in to comment.