Skip to content

Commit

Permalink
gpt-4o-mini default, and add Delegate Function Call node, and fix All…
Browse files Browse the repository at this point in the history
… Messages parallel function calling
  • Loading branch information
abrenneke committed Jul 31, 2024
1 parent 1acca66 commit 4c51473
Show file tree
Hide file tree
Showing 14 changed files with 270 additions and 26 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
21 changes: 17 additions & 4 deletions packages/app/src/components/RenderDataValue.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,26 @@ const scalarRenderers: {
<em>assistant</em>
</header>
{messageContent}
{message.function_call && (
<div className="function-call">
<h4>Function Call:</h4>
{message.function_calls ? (
<div className="function-calls">
<h4>Function Calls:</h4>
<div className="pre-wrap">
<RenderDataValue value={inferType(message.function_call)} />
{message.function_calls.map((fc, i) => (
<div key={i}>
<RenderDataValue value={inferType(fc)} />
</div>
))}
</div>
</div>
) : (
message.function_call && (
<div className="function-call">
<h4>Function Call:</h4>
<div className="pre-wrap">
<RenderDataValue value={inferType(message.function_call)} />
</div>
</div>
)
)}
</div>
))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ export const ToolCallHandlers: FC<ToolCallHandlersProps> = ({
}
isDisabled={isDisabled}
isReadOnly={isReadonly}
placeholder="Tool ID"
placeholder="Tool/Function"
style={{ marginRight: '8px' }}
/>
<GraphSelectorSelect
Expand Down
3 changes: 2 additions & 1 deletion packages/app/src/components/nodeStyles.ts
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,8 @@ export const nodeStyles = css`
/* background-color: rgba(1, 1, 1, 0.5); */
}
.node-output .function-call {
.node-output .function-call,
.node-output .function-calls {
h4 {
margin-top: 0;
margin-bottom: 0;
Expand Down
2 changes: 2 additions & 0 deletions packages/app/src/hooks/useBuiltInNodeImages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ import replaceDatasetNodeImage from '../assets/node_images/replace_dataset_node.
import listGraphsNodeImage from '../assets/node_images/list_graphs_node.png';
import graphReferenceNodeImage from '../assets/node_images/graph_reference_node.png';
import callGraphNodeImage from '../assets/node_images/call_graph_node.png';
import delegateFunctionCallNodeImage from '../assets/node_images/delegate_function_call_node.png';

export const useBuiltInNodeImages = (): Record<BuiltInNodeType, string> => {
return {
Expand Down Expand Up @@ -143,5 +144,6 @@ export const useBuiltInNodeImages = (): Record<BuiltInNodeType, string> => {
listGraphs: listGraphsNodeImage,
graphReference: graphReferenceNodeImage,
callGraph: callGraphNodeImage,
delegateFunctionCall: delegateFunctionCallNodeImage,
};
};
24 changes: 17 additions & 7 deletions packages/core/src/model/DataValue.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,26 @@ export type UserChatMessage = {
message: ChatMessageMessagePart | ChatMessageMessagePart[];
};

export type AssistantChatMessageFunctionCall = {
id: string | undefined;
name: string;
arguments: string; // JSON string
};

export type ParsedAssistantChatMessageFunctionCall = {
id: string | undefined;
name: string;
arguments: Record<string, unknown>;
};

export type AssistantChatMessage = {
type: 'assistant';
message: ChatMessageMessagePart | ChatMessageMessagePart[];
function_call:
| {
id: string | undefined;
name: string;
arguments: string; // JSON string
}
| undefined;

/** @deprecated use function_calls instead */
function_call: AssistantChatMessageFunctionCall | undefined;

function_calls: AssistantChatMessageFunctionCall[] | undefined;
};

export type FunctionResponseChatMessage = {
Expand Down
4 changes: 3 additions & 1 deletion packages/core/src/model/Nodes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ import { graphReferenceNode } from './nodes/GraphReferenceNode.js';
export * from './nodes/GraphReferenceNode.js';

import { callGraphNode } from './nodes/CallGraphNode.js';
import { delegateFunctionCallNode } from './nodes/DelegateFunctionCallNode.js';
export * from './nodes/CallGraphNode.js';

export const registerBuiltInNodes = (registry: NodeRegistration) => {
Expand Down Expand Up @@ -283,7 +284,8 @@ export const registerBuiltInNodes = (registry: NodeRegistration) => {
.register(replaceDatasetNode)
.register(listGraphsNode)
.register(graphReferenceNode)
.register(callGraphNode);
.register(callGraphNode)
.register(delegateFunctionCallNode);
};

let globalRivetNodeRegistry = registerBuiltInNodes(new NodeRegistration());
Expand Down
1 change: 1 addition & 0 deletions packages/core/src/model/nodes/AssembleMessageNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ export class AssembleMessageNodeImpl extends NodeImpl<AssembleMessageNode> {
type,
message: [],
function_call: undefined, // Not supported yet in Assemble Message node
function_calls: undefined, // Not supported yet in Assemble Message node
}),
)
.with(
Expand Down
2 changes: 1 addition & 1 deletion packages/core/src/model/nodes/CallGraphNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ export class CallGraphNodeImpl extends NodeImpl<CallGraphNode> {
static getUIData(): NodeUIData {
return {
infoBoxBody: dedent`
Gets a reference to another graph, that can be used to pass around graphs to call using a Call Graph node.
Calls another graph and passes inputs to it. Use in combination with the Graph Reference node to call dynamic graphs.
`,
infoBoxTitle: 'Call Graph Node',
contextMenuTitle: 'Call Graph',
Expand Down
11 changes: 10 additions & 1 deletion packages/core/src/model/nodes/ChatNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ export class ChatNodeImpl extends NodeImpl<ChatNode> {
width: 200,
},
data: {
model: 'gpt-3.5-turbo',
model: 'gpt-4o-mini',
useModelInput: false,

temperature: 0.5,
Expand Down Expand Up @@ -962,11 +962,20 @@ export class ChatNodeImpl extends NodeImpl<ChatNode> {
id: functionCalls[0][0]!.id,
}
: undefined,
function_calls: functionCalls[0]
? functionCalls[0].map((fc) => ({
name: fc.name,
arguments: fc.arguments,
id: fc.id,
}))
: undefined,
},
],
};
}

console.dir({ output });

const endTime = Date.now();

if (responseChoicesParts.length === 0 && functionCalls.length === 0) {
Expand Down
182 changes: 182 additions & 0 deletions packages/core/src/model/nodes/DelegateFunctionCallNode.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
import { nanoid } from 'nanoid';
import type {
AssistantChatMessageFunctionCall,
DataValue,
GptFunction,
ParsedAssistantChatMessageFunctionCall,
} from '../DataValue.js';
import type { ChartNode, NodeId, NodeInputDefinition, NodeOutputDefinition, PortId } from '../NodeBase.js';
import type { GraphId } from '../NodeGraph.js';
import { NodeImpl, type NodeBody, type NodeUIData } from '../NodeImpl.js';
import { dedent } from 'ts-dedent';
import type { EditorDefinition } from '../EditorDefinition.js';
import type { RivetUIContext } from '../RivetUIContext.js';
import { nodeDefinition } from '../NodeDefinition.js';
import type { InternalProcessContext } from '../ProcessContext.js';
import type { Inputs, Outputs } from '../GraphProcessor.js';
import { coerceType, coerceTypeOptional } from '../../utils/coerceType.js';

export type DelegateFunctionCallNode = ChartNode<'delegateFunctionCall', DelegateFunctionCallNodeData>;

export type DelegateFunctionCallNodeData = {
handlers: { key: string; value: GraphId }[];
unknownHandler: GraphId | undefined;
};

export class DelegateFunctionCallNodeImpl extends NodeImpl<DelegateFunctionCallNode> {
static create(): DelegateFunctionCallNode {
const chartNode: DelegateFunctionCallNode = {
type: 'delegateFunctionCall',
title: 'Delegate Function Call',
id: nanoid() as NodeId,
visualData: {
x: 0,
y: 0,
width: 325,
},
data: {
handlers: [],
unknownHandler: undefined,
},
};

return chartNode;
}

getInputDefinitions(): NodeInputDefinition[] {
const inputs: NodeInputDefinition[] = [];

inputs.push({
id: 'function-call' as PortId,
dataType: 'object',
title: 'Function Call',
coerced: true,
required: true,
description: 'The function call to delegate to a subgraph.',
});

return inputs;
}

getOutputDefinitions(): NodeOutputDefinition[] {
const outputs: NodeOutputDefinition[] = [];

outputs.push({
id: 'output' as PortId,
dataType: 'string',
title: 'Output',
description: 'The output of the function call.',
});

outputs.push({
id: 'message' as PortId,
dataType: 'object',
title: 'Message Output',
description: 'Maps the output for use directly with an Assemble Prompt node and GPT.',
});

return outputs;
}

static getUIData(): NodeUIData {
return {
infoBoxBody: dedent`
Handles a function call by delegating it to a different subgraph depending on the function call.
`,
infoBoxTitle: 'Delegate Function Call Node',
contextMenuTitle: 'Delegate Function Call',
group: ['Advanced'],
};
}

getEditors(): EditorDefinition<DelegateFunctionCallNode>[] {
return [
{
type: 'custom',
customEditorId: 'ToolCallHandlers',
label: 'Handlers',
dataKey: 'handlers',
},
{
type: 'graphSelector',
dataKey: 'unknownHandler',
label: 'Unknown Handler',
helperMessage: 'The subgraph to delegate to if the function call does not match any handlers.',
},
];
}

getBody(context: RivetUIContext): NodeBody {
if (this.data.handlers.length === 0) {
return 'No handlers defined';
}

const lines = ['Handlers:'];

this.data.handlers.forEach(({ key, value }) => {
const subgraphName = context.project.graphs[value]?.metadata!.name! ?? 'Unknown Subgraph';
lines.push(` ${key || '(MISSING!)'} -> ${subgraphName}`);
});

return lines.join('\n');
}

async process(inputs: Inputs, context: InternalProcessContext): Promise<Outputs> {
const functionCall = coerceType(
inputs['function-call' as PortId],
'object',
) as ParsedAssistantChatMessageFunctionCall;

let handler = this.data.handlers.find((handler) => handler.key === functionCall.name);

if (!handler) {
if (this.data.unknownHandler) {
handler = { key: undefined!, value: this.data.unknownHandler };
} else {
throw new Error(`No handler found for function call: ${functionCall.name}`);
}
}

const subgraphInputs: Record<string, DataValue> = {
_function_name: {
type: 'string',
value: functionCall.name,
},
_arguments: {
type: 'object',
value: functionCall.arguments,
},
};

for (const [argName, argument] of Object.entries(functionCall.arguments)) {
subgraphInputs[argName] = {
type: 'any',
value: argument,
};
}

const handlerGraphId = handler.value;
const subprocessor = context.createSubProcessor(handlerGraphId, { signal: context.signal });

const outputs = await subprocessor.processGraph(context, subgraphInputs, context.contextValues);

const outputString = coerceTypeOptional(outputs.output, 'string') ?? '';

return {
['output' as PortId]: {
type: 'string',
value: outputString,
},
['message' as PortId]: {
type: 'chat-message',
value: {
type: 'function',
message: outputString,
name: functionCall.id ?? '',
},
},
};
}
}

export const delegateFunctionCallNode = nodeDefinition(DelegateFunctionCallNodeImpl, 'Delegate Function Call');
4 changes: 3 additions & 1 deletion packages/core/src/model/nodes/PromptNode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import { NodeImpl, type NodeUIData } from '../NodeImpl.js';
import { nodeDefinition } from '../NodeDefinition.js';
import {
type AssistantChatMessage,
type AssistantChatMessageFunctionCall,
type ChatMessage,
type EditorDefinition,
type Inputs,
Expand Down Expand Up @@ -245,7 +246,8 @@ export class PromptNodeImpl extends NodeImpl<PromptNode> {
return {
type,
message: outputValue,
function_call: functionCall as AssistantChatMessage['function_call'],
function_call: functionCall as AssistantChatMessageFunctionCall,
function_calls: functionCall ? [functionCall as AssistantChatMessageFunctionCall] : undefined,
};
})
.with(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,21 @@ export async function chatMessageToOpenAIChatCompletionMessage(
role: m.type,
content: onlyStringContent(m),

tool_calls: m.function_call
? [
{
id: m.function_call.id ?? 'unknown_function_call',
type: 'function',
function: m.function_call,
},
]
: undefined,
tool_calls: m.function_calls
? m.function_calls.map((fc) => ({
id: fc.id ?? 'unknown_function_call',
type: 'function',
function: fc,
}))
: m.function_call
? [
{
id: m.function_call.id ?? 'unknown_function_call',
type: 'function',
function: m.function_call,
},
]
: undefined,
}),
)
.with(
Expand Down
Loading

0 comments on commit 4c51473

Please sign in to comment.