diff --git a/packages/app/src/assets/node_images/delegate_function_call_node.png b/packages/app/src/assets/node_images/delegate_function_call_node.png new file mode 100644 index 000000000..feeecb4bc Binary files /dev/null and b/packages/app/src/assets/node_images/delegate_function_call_node.png differ diff --git a/packages/app/src/components/RenderDataValue.tsx b/packages/app/src/components/RenderDataValue.tsx index 0c1db44d5..4aee6fec4 100644 --- a/packages/app/src/components/RenderDataValue.tsx +++ b/packages/app/src/components/RenderDataValue.tsx @@ -142,13 +142,26 @@ const scalarRenderers: { assistant {messageContent} - {message.function_call && ( -
-

Function Call:

+ {message.function_calls ? ( +
+

Function Calls:

- + {message.function_calls.map((fc, i) => ( +
+ +
+ ))}
+ ) : ( + message.function_call && ( +
+

Function Call:

+
+ +
+
+ ) )}
)) diff --git a/packages/app/src/components/editors/custom/ToolCallHandlersEditor.tsx b/packages/app/src/components/editors/custom/ToolCallHandlersEditor.tsx index 95b8f9cb9..7a265d379 100644 --- a/packages/app/src/components/editors/custom/ToolCallHandlersEditor.tsx +++ b/packages/app/src/components/editors/custom/ToolCallHandlersEditor.tsx @@ -172,7 +172,7 @@ export const ToolCallHandlers: FC = ({ } isDisabled={isDisabled} isReadOnly={isReadonly} - placeholder="Tool ID" + placeholder="Tool/Function" style={{ marginRight: '8px' }} /> => { return { @@ -143,5 +144,6 @@ export const useBuiltInNodeImages = (): Record => { listGraphs: listGraphsNodeImage, graphReference: graphReferenceNodeImage, callGraph: callGraphNodeImage, + delegateFunctionCall: delegateFunctionCallNodeImage, }; }; diff --git a/packages/core/src/model/DataValue.ts b/packages/core/src/model/DataValue.ts index bf1630909..7bbb2db4f 100644 --- a/packages/core/src/model/DataValue.ts +++ b/packages/core/src/model/DataValue.ts @@ -21,16 +21,26 @@ export type UserChatMessage = { message: ChatMessageMessagePart | ChatMessageMessagePart[]; }; +export type AssistantChatMessageFunctionCall = { + id: string | undefined; + name: string; + arguments: string; // JSON string +}; + +export type ParsedAssistantChatMessageFunctionCall = { + id: string | undefined; + name: string; + arguments: Record; +}; + export type AssistantChatMessage = { type: 'assistant'; message: ChatMessageMessagePart | ChatMessageMessagePart[]; - function_call: - | { - id: string | undefined; - name: string; - arguments: string; // JSON string - } - | undefined; + + /** @deprecated use function_calls instead */ + function_call: AssistantChatMessageFunctionCall | undefined; + + function_calls: AssistantChatMessageFunctionCall[] | undefined; }; export type FunctionResponseChatMessage = { diff --git a/packages/core/src/model/Nodes.ts b/packages/core/src/model/Nodes.ts index 9b9ee6f59..12ce3b07d 100644 --- a/packages/core/src/model/Nodes.ts +++ b/packages/core/src/model/Nodes.ts @@ -210,6 +210,7 @@ import { graphReferenceNode } from './nodes/GraphReferenceNode.js'; export * from './nodes/GraphReferenceNode.js'; import { callGraphNode } from './nodes/CallGraphNode.js'; +import { delegateFunctionCallNode } from './nodes/DelegateFunctionCallNode.js'; export * from './nodes/CallGraphNode.js'; export const registerBuiltInNodes = (registry: NodeRegistration) => { @@ -283,7 +284,8 @@ export const registerBuiltInNodes = (registry: NodeRegistration) => { .register(replaceDatasetNode) .register(listGraphsNode) .register(graphReferenceNode) - .register(callGraphNode); + .register(callGraphNode) + .register(delegateFunctionCallNode); }; let globalRivetNodeRegistry = registerBuiltInNodes(new NodeRegistration()); diff --git a/packages/core/src/model/nodes/AssembleMessageNode.ts b/packages/core/src/model/nodes/AssembleMessageNode.ts index ff1c6b559..03d2e4fff 100644 --- a/packages/core/src/model/nodes/AssembleMessageNode.ts +++ b/packages/core/src/model/nodes/AssembleMessageNode.ts @@ -193,6 +193,7 @@ export class AssembleMessageNodeImpl extends NodeImpl { type, message: [], function_call: undefined, // Not supported yet in Assemble Message node + function_calls: undefined, // Not supported yet in Assemble Message node }), ) .with( diff --git a/packages/core/src/model/nodes/CallGraphNode.ts b/packages/core/src/model/nodes/CallGraphNode.ts index 99e08e4af..f5a5ed381 100644 --- a/packages/core/src/model/nodes/CallGraphNode.ts +++ b/packages/core/src/model/nodes/CallGraphNode.ts @@ -88,7 +88,7 @@ export class CallGraphNodeImpl extends NodeImpl { static getUIData(): NodeUIData { return { infoBoxBody: dedent` - Gets a reference to another graph, that can be used to pass around graphs to call using a Call Graph node. + Calls another graph and passes inputs to it. Use in combination with the Graph Reference node to call dynamic graphs. `, infoBoxTitle: 'Call Graph Node', contextMenuTitle: 'Call Graph', diff --git a/packages/core/src/model/nodes/ChatNode.ts b/packages/core/src/model/nodes/ChatNode.ts index d9657375d..890841bc9 100644 --- a/packages/core/src/model/nodes/ChatNode.ts +++ b/packages/core/src/model/nodes/ChatNode.ts @@ -97,7 +97,7 @@ export class ChatNodeImpl extends NodeImpl { width: 200, }, data: { - model: 'gpt-3.5-turbo', + model: 'gpt-4o-mini', useModelInput: false, temperature: 0.5, @@ -962,11 +962,20 @@ export class ChatNodeImpl extends NodeImpl { id: functionCalls[0][0]!.id, } : undefined, + function_calls: functionCalls[0] + ? functionCalls[0].map((fc) => ({ + name: fc.name, + arguments: fc.arguments, + id: fc.id, + })) + : undefined, }, ], }; } + console.dir({ output }); + const endTime = Date.now(); if (responseChoicesParts.length === 0 && functionCalls.length === 0) { diff --git a/packages/core/src/model/nodes/DelegateFunctionCallNode.ts b/packages/core/src/model/nodes/DelegateFunctionCallNode.ts new file mode 100644 index 000000000..f473c2a14 --- /dev/null +++ b/packages/core/src/model/nodes/DelegateFunctionCallNode.ts @@ -0,0 +1,182 @@ +import { nanoid } from 'nanoid'; +import type { + AssistantChatMessageFunctionCall, + DataValue, + GptFunction, + ParsedAssistantChatMessageFunctionCall, +} from '../DataValue.js'; +import type { ChartNode, NodeId, NodeInputDefinition, NodeOutputDefinition, PortId } from '../NodeBase.js'; +import type { GraphId } from '../NodeGraph.js'; +import { NodeImpl, type NodeBody, type NodeUIData } from '../NodeImpl.js'; +import { dedent } from 'ts-dedent'; +import type { EditorDefinition } from '../EditorDefinition.js'; +import type { RivetUIContext } from '../RivetUIContext.js'; +import { nodeDefinition } from '../NodeDefinition.js'; +import type { InternalProcessContext } from '../ProcessContext.js'; +import type { Inputs, Outputs } from '../GraphProcessor.js'; +import { coerceType, coerceTypeOptional } from '../../utils/coerceType.js'; + +export type DelegateFunctionCallNode = ChartNode<'delegateFunctionCall', DelegateFunctionCallNodeData>; + +export type DelegateFunctionCallNodeData = { + handlers: { key: string; value: GraphId }[]; + unknownHandler: GraphId | undefined; +}; + +export class DelegateFunctionCallNodeImpl extends NodeImpl { + static create(): DelegateFunctionCallNode { + const chartNode: DelegateFunctionCallNode = { + type: 'delegateFunctionCall', + title: 'Delegate Function Call', + id: nanoid() as NodeId, + visualData: { + x: 0, + y: 0, + width: 325, + }, + data: { + handlers: [], + unknownHandler: undefined, + }, + }; + + return chartNode; + } + + getInputDefinitions(): NodeInputDefinition[] { + const inputs: NodeInputDefinition[] = []; + + inputs.push({ + id: 'function-call' as PortId, + dataType: 'object', + title: 'Function Call', + coerced: true, + required: true, + description: 'The function call to delegate to a subgraph.', + }); + + return inputs; + } + + getOutputDefinitions(): NodeOutputDefinition[] { + const outputs: NodeOutputDefinition[] = []; + + outputs.push({ + id: 'output' as PortId, + dataType: 'string', + title: 'Output', + description: 'The output of the function call.', + }); + + outputs.push({ + id: 'message' as PortId, + dataType: 'object', + title: 'Message Output', + description: 'Maps the output for use directly with an Assemble Prompt node and GPT.', + }); + + return outputs; + } + + static getUIData(): NodeUIData { + return { + infoBoxBody: dedent` + Handles a function call by delegating it to a different subgraph depending on the function call. + `, + infoBoxTitle: 'Delegate Function Call Node', + contextMenuTitle: 'Delegate Function Call', + group: ['Advanced'], + }; + } + + getEditors(): EditorDefinition[] { + return [ + { + type: 'custom', + customEditorId: 'ToolCallHandlers', + label: 'Handlers', + dataKey: 'handlers', + }, + { + type: 'graphSelector', + dataKey: 'unknownHandler', + label: 'Unknown Handler', + helperMessage: 'The subgraph to delegate to if the function call does not match any handlers.', + }, + ]; + } + + getBody(context: RivetUIContext): NodeBody { + if (this.data.handlers.length === 0) { + return 'No handlers defined'; + } + + const lines = ['Handlers:']; + + this.data.handlers.forEach(({ key, value }) => { + const subgraphName = context.project.graphs[value]?.metadata!.name! ?? 'Unknown Subgraph'; + lines.push(` ${key || '(MISSING!)'} -> ${subgraphName}`); + }); + + return lines.join('\n'); + } + + async process(inputs: Inputs, context: InternalProcessContext): Promise { + const functionCall = coerceType( + inputs['function-call' as PortId], + 'object', + ) as ParsedAssistantChatMessageFunctionCall; + + let handler = this.data.handlers.find((handler) => handler.key === functionCall.name); + + if (!handler) { + if (this.data.unknownHandler) { + handler = { key: undefined!, value: this.data.unknownHandler }; + } else { + throw new Error(`No handler found for function call: ${functionCall.name}`); + } + } + + const subgraphInputs: Record = { + _function_name: { + type: 'string', + value: functionCall.name, + }, + _arguments: { + type: 'object', + value: functionCall.arguments, + }, + }; + + for (const [argName, argument] of Object.entries(functionCall.arguments)) { + subgraphInputs[argName] = { + type: 'any', + value: argument, + }; + } + + const handlerGraphId = handler.value; + const subprocessor = context.createSubProcessor(handlerGraphId, { signal: context.signal }); + + const outputs = await subprocessor.processGraph(context, subgraphInputs, context.contextValues); + + const outputString = coerceTypeOptional(outputs.output, 'string') ?? ''; + + return { + ['output' as PortId]: { + type: 'string', + value: outputString, + }, + ['message' as PortId]: { + type: 'chat-message', + value: { + type: 'function', + message: outputString, + name: functionCall.id ?? '', + }, + }, + }; + } +} + +export const delegateFunctionCallNode = nodeDefinition(DelegateFunctionCallNodeImpl, 'Delegate Function Call'); diff --git a/packages/core/src/model/nodes/PromptNode.ts b/packages/core/src/model/nodes/PromptNode.ts index b848cb17f..390d33286 100644 --- a/packages/core/src/model/nodes/PromptNode.ts +++ b/packages/core/src/model/nodes/PromptNode.ts @@ -10,6 +10,7 @@ import { NodeImpl, type NodeUIData } from '../NodeImpl.js'; import { nodeDefinition } from '../NodeDefinition.js'; import { type AssistantChatMessage, + type AssistantChatMessageFunctionCall, type ChatMessage, type EditorDefinition, type Inputs, @@ -245,7 +246,8 @@ export class PromptNodeImpl extends NodeImpl { return { type, message: outputValue, - function_call: functionCall as AssistantChatMessage['function_call'], + function_call: functionCall as AssistantChatMessageFunctionCall, + function_calls: functionCall ? [functionCall as AssistantChatMessageFunctionCall] : undefined, }; }) .with( diff --git a/packages/core/src/utils/chatMessageToOpenAIChatCompletionMessage.ts b/packages/core/src/utils/chatMessageToOpenAIChatCompletionMessage.ts index a301c123c..3bc4638cd 100644 --- a/packages/core/src/utils/chatMessageToOpenAIChatCompletionMessage.ts +++ b/packages/core/src/utils/chatMessageToOpenAIChatCompletionMessage.ts @@ -53,15 +53,21 @@ export async function chatMessageToOpenAIChatCompletionMessage( role: m.type, content: onlyStringContent(m), - tool_calls: m.function_call - ? [ - { - id: m.function_call.id ?? 'unknown_function_call', - type: 'function', - function: m.function_call, - }, - ] - : undefined, + tool_calls: m.function_calls + ? m.function_calls.map((fc) => ({ + id: fc.id ?? 'unknown_function_call', + type: 'function', + function: fc, + })) + : m.function_call + ? [ + { + id: m.function_call.id ?? 'unknown_function_call', + type: 'function', + function: m.function_call, + }, + ] + : undefined, }), ) .with( diff --git a/packages/core/src/utils/openai.ts b/packages/core/src/utils/openai.ts index 05e2980e2..aac2cab99 100644 --- a/packages/core/src/utils/openai.ts +++ b/packages/core/src/utils/openai.ts @@ -142,6 +142,22 @@ export const openaiModels = { }, displayName: 'GPT-4o', }, + 'gpt-4o-mini': { + maxTokens: 128000, + cost: { + prompt: 0.00015, + completion: 0.00075, + }, + displayName: 'GPT-4o mini', + }, + 'gpt-4o-mini-2024-07-18': { + maxTokens: 128000, + cost: { + prompt: 0.00015, + completion: 0.00075, + }, + displayName: 'GPT-4o mini (2024-07-18)', + }, 'local-model': { maxTokens: Number.MAX_SAFE_INTEGER, cost: {