diff --git a/packages/app/src/assets/node_images/delegate_function_call_node.png b/packages/app/src/assets/node_images/delegate_function_call_node.png
new file mode 100644
index 000000000..feeecb4bc
Binary files /dev/null and b/packages/app/src/assets/node_images/delegate_function_call_node.png differ
diff --git a/packages/app/src/components/RenderDataValue.tsx b/packages/app/src/components/RenderDataValue.tsx
index 0c1db44d5..4aee6fec4 100644
--- a/packages/app/src/components/RenderDataValue.tsx
+++ b/packages/app/src/components/RenderDataValue.tsx
@@ -142,13 +142,26 @@ const scalarRenderers: {
assistant
{messageContent}
- {message.function_call && (
-
-
Function Call:
+ {message.function_calls ? (
+
+
Function Calls:
-
+ {message.function_calls.map((fc, i) => (
+
+
+
+ ))}
+ ) : (
+ message.function_call && (
+
+ )
)}
))
diff --git a/packages/app/src/components/editors/custom/ToolCallHandlersEditor.tsx b/packages/app/src/components/editors/custom/ToolCallHandlersEditor.tsx
index 95b8f9cb9..7a265d379 100644
--- a/packages/app/src/components/editors/custom/ToolCallHandlersEditor.tsx
+++ b/packages/app/src/components/editors/custom/ToolCallHandlersEditor.tsx
@@ -172,7 +172,7 @@ export const ToolCallHandlers: FC = ({
}
isDisabled={isDisabled}
isReadOnly={isReadonly}
- placeholder="Tool ID"
+ placeholder="Tool/Function"
style={{ marginRight: '8px' }}
/>
=> {
return {
@@ -143,5 +144,6 @@ export const useBuiltInNodeImages = (): Record => {
listGraphs: listGraphsNodeImage,
graphReference: graphReferenceNodeImage,
callGraph: callGraphNodeImage,
+ delegateFunctionCall: delegateFunctionCallNodeImage,
};
};
diff --git a/packages/core/src/model/DataValue.ts b/packages/core/src/model/DataValue.ts
index bf1630909..7bbb2db4f 100644
--- a/packages/core/src/model/DataValue.ts
+++ b/packages/core/src/model/DataValue.ts
@@ -21,16 +21,26 @@ export type UserChatMessage = {
message: ChatMessageMessagePart | ChatMessageMessagePart[];
};
+export type AssistantChatMessageFunctionCall = {
+ id: string | undefined;
+ name: string;
+ arguments: string; // JSON string
+};
+
+export type ParsedAssistantChatMessageFunctionCall = {
+ id: string | undefined;
+ name: string;
+ arguments: Record;
+};
+
export type AssistantChatMessage = {
type: 'assistant';
message: ChatMessageMessagePart | ChatMessageMessagePart[];
- function_call:
- | {
- id: string | undefined;
- name: string;
- arguments: string; // JSON string
- }
- | undefined;
+
+ /** @deprecated use function_calls instead */
+ function_call: AssistantChatMessageFunctionCall | undefined;
+
+ function_calls: AssistantChatMessageFunctionCall[] | undefined;
};
export type FunctionResponseChatMessage = {
diff --git a/packages/core/src/model/Nodes.ts b/packages/core/src/model/Nodes.ts
index 9b9ee6f59..12ce3b07d 100644
--- a/packages/core/src/model/Nodes.ts
+++ b/packages/core/src/model/Nodes.ts
@@ -210,6 +210,7 @@ import { graphReferenceNode } from './nodes/GraphReferenceNode.js';
export * from './nodes/GraphReferenceNode.js';
import { callGraphNode } from './nodes/CallGraphNode.js';
+import { delegateFunctionCallNode } from './nodes/DelegateFunctionCallNode.js';
export * from './nodes/CallGraphNode.js';
export const registerBuiltInNodes = (registry: NodeRegistration) => {
@@ -283,7 +284,8 @@ export const registerBuiltInNodes = (registry: NodeRegistration) => {
.register(replaceDatasetNode)
.register(listGraphsNode)
.register(graphReferenceNode)
- .register(callGraphNode);
+ .register(callGraphNode)
+ .register(delegateFunctionCallNode);
};
let globalRivetNodeRegistry = registerBuiltInNodes(new NodeRegistration());
diff --git a/packages/core/src/model/nodes/AssembleMessageNode.ts b/packages/core/src/model/nodes/AssembleMessageNode.ts
index ff1c6b559..03d2e4fff 100644
--- a/packages/core/src/model/nodes/AssembleMessageNode.ts
+++ b/packages/core/src/model/nodes/AssembleMessageNode.ts
@@ -193,6 +193,7 @@ export class AssembleMessageNodeImpl extends NodeImpl {
type,
message: [],
function_call: undefined, // Not supported yet in Assemble Message node
+ function_calls: undefined, // Not supported yet in Assemble Message node
}),
)
.with(
diff --git a/packages/core/src/model/nodes/CallGraphNode.ts b/packages/core/src/model/nodes/CallGraphNode.ts
index 99e08e4af..f5a5ed381 100644
--- a/packages/core/src/model/nodes/CallGraphNode.ts
+++ b/packages/core/src/model/nodes/CallGraphNode.ts
@@ -88,7 +88,7 @@ export class CallGraphNodeImpl extends NodeImpl {
static getUIData(): NodeUIData {
return {
infoBoxBody: dedent`
- Gets a reference to another graph, that can be used to pass around graphs to call using a Call Graph node.
+ Calls another graph and passes inputs to it. Use in combination with the Graph Reference node to call dynamic graphs.
`,
infoBoxTitle: 'Call Graph Node',
contextMenuTitle: 'Call Graph',
diff --git a/packages/core/src/model/nodes/ChatNode.ts b/packages/core/src/model/nodes/ChatNode.ts
index d9657375d..890841bc9 100644
--- a/packages/core/src/model/nodes/ChatNode.ts
+++ b/packages/core/src/model/nodes/ChatNode.ts
@@ -97,7 +97,7 @@ export class ChatNodeImpl extends NodeImpl {
width: 200,
},
data: {
- model: 'gpt-3.5-turbo',
+ model: 'gpt-4o-mini',
useModelInput: false,
temperature: 0.5,
@@ -962,11 +962,20 @@ export class ChatNodeImpl extends NodeImpl {
id: functionCalls[0][0]!.id,
}
: undefined,
+ function_calls: functionCalls[0]
+ ? functionCalls[0].map((fc) => ({
+ name: fc.name,
+ arguments: fc.arguments,
+ id: fc.id,
+ }))
+ : undefined,
},
],
};
}
+ console.dir({ output });
+
const endTime = Date.now();
if (responseChoicesParts.length === 0 && functionCalls.length === 0) {
diff --git a/packages/core/src/model/nodes/DelegateFunctionCallNode.ts b/packages/core/src/model/nodes/DelegateFunctionCallNode.ts
new file mode 100644
index 000000000..f473c2a14
--- /dev/null
+++ b/packages/core/src/model/nodes/DelegateFunctionCallNode.ts
@@ -0,0 +1,182 @@
+import { nanoid } from 'nanoid';
+import type {
+ AssistantChatMessageFunctionCall,
+ DataValue,
+ GptFunction,
+ ParsedAssistantChatMessageFunctionCall,
+} from '../DataValue.js';
+import type { ChartNode, NodeId, NodeInputDefinition, NodeOutputDefinition, PortId } from '../NodeBase.js';
+import type { GraphId } from '../NodeGraph.js';
+import { NodeImpl, type NodeBody, type NodeUIData } from '../NodeImpl.js';
+import { dedent } from 'ts-dedent';
+import type { EditorDefinition } from '../EditorDefinition.js';
+import type { RivetUIContext } from '../RivetUIContext.js';
+import { nodeDefinition } from '../NodeDefinition.js';
+import type { InternalProcessContext } from '../ProcessContext.js';
+import type { Inputs, Outputs } from '../GraphProcessor.js';
+import { coerceType, coerceTypeOptional } from '../../utils/coerceType.js';
+
+export type DelegateFunctionCallNode = ChartNode<'delegateFunctionCall', DelegateFunctionCallNodeData>;
+
+export type DelegateFunctionCallNodeData = {
+ handlers: { key: string; value: GraphId }[];
+ unknownHandler: GraphId | undefined;
+};
+
+export class DelegateFunctionCallNodeImpl extends NodeImpl {
+ static create(): DelegateFunctionCallNode {
+ const chartNode: DelegateFunctionCallNode = {
+ type: 'delegateFunctionCall',
+ title: 'Delegate Function Call',
+ id: nanoid() as NodeId,
+ visualData: {
+ x: 0,
+ y: 0,
+ width: 325,
+ },
+ data: {
+ handlers: [],
+ unknownHandler: undefined,
+ },
+ };
+
+ return chartNode;
+ }
+
+ getInputDefinitions(): NodeInputDefinition[] {
+ const inputs: NodeInputDefinition[] = [];
+
+ inputs.push({
+ id: 'function-call' as PortId,
+ dataType: 'object',
+ title: 'Function Call',
+ coerced: true,
+ required: true,
+ description: 'The function call to delegate to a subgraph.',
+ });
+
+ return inputs;
+ }
+
+ getOutputDefinitions(): NodeOutputDefinition[] {
+ const outputs: NodeOutputDefinition[] = [];
+
+ outputs.push({
+ id: 'output' as PortId,
+ dataType: 'string',
+ title: 'Output',
+ description: 'The output of the function call.',
+ });
+
+ outputs.push({
+ id: 'message' as PortId,
+ dataType: 'object',
+ title: 'Message Output',
+ description: 'Maps the output for use directly with an Assemble Prompt node and GPT.',
+ });
+
+ return outputs;
+ }
+
+ static getUIData(): NodeUIData {
+ return {
+ infoBoxBody: dedent`
+ Handles a function call by delegating it to a different subgraph depending on the function call.
+ `,
+ infoBoxTitle: 'Delegate Function Call Node',
+ contextMenuTitle: 'Delegate Function Call',
+ group: ['Advanced'],
+ };
+ }
+
+ getEditors(): EditorDefinition[] {
+ return [
+ {
+ type: 'custom',
+ customEditorId: 'ToolCallHandlers',
+ label: 'Handlers',
+ dataKey: 'handlers',
+ },
+ {
+ type: 'graphSelector',
+ dataKey: 'unknownHandler',
+ label: 'Unknown Handler',
+ helperMessage: 'The subgraph to delegate to if the function call does not match any handlers.',
+ },
+ ];
+ }
+
+ getBody(context: RivetUIContext): NodeBody {
+ if (this.data.handlers.length === 0) {
+ return 'No handlers defined';
+ }
+
+ const lines = ['Handlers:'];
+
+ this.data.handlers.forEach(({ key, value }) => {
+ const subgraphName = context.project.graphs[value]?.metadata!.name! ?? 'Unknown Subgraph';
+ lines.push(` ${key || '(MISSING!)'} -> ${subgraphName}`);
+ });
+
+ return lines.join('\n');
+ }
+
+ async process(inputs: Inputs, context: InternalProcessContext): Promise {
+ const functionCall = coerceType(
+ inputs['function-call' as PortId],
+ 'object',
+ ) as ParsedAssistantChatMessageFunctionCall;
+
+ let handler = this.data.handlers.find((handler) => handler.key === functionCall.name);
+
+ if (!handler) {
+ if (this.data.unknownHandler) {
+ handler = { key: undefined!, value: this.data.unknownHandler };
+ } else {
+ throw new Error(`No handler found for function call: ${functionCall.name}`);
+ }
+ }
+
+ const subgraphInputs: Record = {
+ _function_name: {
+ type: 'string',
+ value: functionCall.name,
+ },
+ _arguments: {
+ type: 'object',
+ value: functionCall.arguments,
+ },
+ };
+
+ for (const [argName, argument] of Object.entries(functionCall.arguments)) {
+ subgraphInputs[argName] = {
+ type: 'any',
+ value: argument,
+ };
+ }
+
+ const handlerGraphId = handler.value;
+ const subprocessor = context.createSubProcessor(handlerGraphId, { signal: context.signal });
+
+ const outputs = await subprocessor.processGraph(context, subgraphInputs, context.contextValues);
+
+ const outputString = coerceTypeOptional(outputs.output, 'string') ?? '';
+
+ return {
+ ['output' as PortId]: {
+ type: 'string',
+ value: outputString,
+ },
+ ['message' as PortId]: {
+ type: 'chat-message',
+ value: {
+ type: 'function',
+ message: outputString,
+ name: functionCall.id ?? '',
+ },
+ },
+ };
+ }
+}
+
+export const delegateFunctionCallNode = nodeDefinition(DelegateFunctionCallNodeImpl, 'Delegate Function Call');
diff --git a/packages/core/src/model/nodes/PromptNode.ts b/packages/core/src/model/nodes/PromptNode.ts
index b848cb17f..390d33286 100644
--- a/packages/core/src/model/nodes/PromptNode.ts
+++ b/packages/core/src/model/nodes/PromptNode.ts
@@ -10,6 +10,7 @@ import { NodeImpl, type NodeUIData } from '../NodeImpl.js';
import { nodeDefinition } from '../NodeDefinition.js';
import {
type AssistantChatMessage,
+ type AssistantChatMessageFunctionCall,
type ChatMessage,
type EditorDefinition,
type Inputs,
@@ -245,7 +246,8 @@ export class PromptNodeImpl extends NodeImpl {
return {
type,
message: outputValue,
- function_call: functionCall as AssistantChatMessage['function_call'],
+ function_call: functionCall as AssistantChatMessageFunctionCall,
+ function_calls: functionCall ? [functionCall as AssistantChatMessageFunctionCall] : undefined,
};
})
.with(
diff --git a/packages/core/src/utils/chatMessageToOpenAIChatCompletionMessage.ts b/packages/core/src/utils/chatMessageToOpenAIChatCompletionMessage.ts
index a301c123c..3bc4638cd 100644
--- a/packages/core/src/utils/chatMessageToOpenAIChatCompletionMessage.ts
+++ b/packages/core/src/utils/chatMessageToOpenAIChatCompletionMessage.ts
@@ -53,15 +53,21 @@ export async function chatMessageToOpenAIChatCompletionMessage(
role: m.type,
content: onlyStringContent(m),
- tool_calls: m.function_call
- ? [
- {
- id: m.function_call.id ?? 'unknown_function_call',
- type: 'function',
- function: m.function_call,
- },
- ]
- : undefined,
+ tool_calls: m.function_calls
+ ? m.function_calls.map((fc) => ({
+ id: fc.id ?? 'unknown_function_call',
+ type: 'function',
+ function: fc,
+ }))
+ : m.function_call
+ ? [
+ {
+ id: m.function_call.id ?? 'unknown_function_call',
+ type: 'function',
+ function: m.function_call,
+ },
+ ]
+ : undefined,
}),
)
.with(
diff --git a/packages/core/src/utils/openai.ts b/packages/core/src/utils/openai.ts
index 05e2980e2..aac2cab99 100644
--- a/packages/core/src/utils/openai.ts
+++ b/packages/core/src/utils/openai.ts
@@ -142,6 +142,22 @@ export const openaiModels = {
},
displayName: 'GPT-4o',
},
+ 'gpt-4o-mini': {
+ maxTokens: 128000,
+ cost: {
+ prompt: 0.00015,
+ completion: 0.00075,
+ },
+ displayName: 'GPT-4o mini',
+ },
+ 'gpt-4o-mini-2024-07-18': {
+ maxTokens: 128000,
+ cost: {
+ prompt: 0.00015,
+ completion: 0.00075,
+ },
+ displayName: 'GPT-4o mini (2024-07-18)',
+ },
'local-model': {
maxTokens: Number.MAX_SAFE_INTEGER,
cost: {