diff --git a/app/api/chat/route.ts b/app/api/chat/route.ts index 156ac3a..24bced3 100644 --- a/app/api/chat/route.ts +++ b/app/api/chat/route.ts @@ -7,6 +7,7 @@ import { PromptTemplate } from "@langchain/core/prompts"; import { getAbi } from "@/utils/etherscan"; import { generateToolFromABI } from "@/utils/generateToolFromABI"; import { CustomParser } from "@/utils/CustomParser"; +import { contractCollection } from "@/utils/collections"; export const runtime = "nodejs"; @@ -18,13 +19,11 @@ export async function POST(req: NextRequest) { try { const body = await req.json(); const messages = body.messages ?? []; - const contractAddresses = (body.contractAddresses ?? "").split( - ",", - ) as string[]; + const contracts = await contractCollection.get(); const network = body.network ?? ""; const formattedPreviousMessages = messages.slice(0, -1).map(formatMessage); const currentMessageContent = messages[messages.length - 1].content; - + const contractAddresses = contracts.map(({ address }) => address); const TEMPLATE = `You are to interact with smart contracts on behalf of the user. The smart contract addresses are ${contractAddresses}. You will be provided with functions that represent the functions in the ABI the user can call. Based on the user's prompt, determine what function they are trying to call, and extract the appropriate inputs. Current conversation: @@ -37,11 +36,12 @@ AI:`; const abis = await Promise.all( contractAddresses.map((ca) => getAbi(ca, network)), ); - const tools = abis.flatMap((abi, i) => - JSON.parse(abi) + const tools = abis.flatMap((abi, i) => { + const contract = contracts[i]; + return JSON.parse(abi) .filter((f: any) => f.name && f.type === "function") - .map(generateToolFromABI(contractAddresses[i])), - ); + .map(generateToolFromABI(contract)); + }); const prompt = PromptTemplate.fromTemplate(TEMPLATE); const model = new ChatOpenAI({ diff --git a/app/api/contracts/route.ts b/app/api/contracts/route.ts index c5ad91b..00e6db8 100644 --- a/app/api/contracts/route.ts +++ b/app/api/contracts/route.ts @@ -1,14 +1,11 @@ -import { KVCollection } from "@/utils/kvCollection"; +import { contractCollection } from "@/utils/collections"; import { NextRequest, NextResponse } from "next/server"; export const runtime = "nodejs"; -const collection = new KVCollection<{ address: string; name: string }>( - "contracts:", -); export async function GET() { try { - const contracts = await collection.get(); + const contracts = await contractCollection.get(); return NextResponse.json({ contracts }, { status: 200 }); } catch (e: any) { @@ -18,7 +15,7 @@ export async function GET() { } export async function POST(req: NextRequest) { - const existingContracts = await collection.get(); + const existingContracts = await contractCollection.get(); try { const body = await req.json(); @@ -32,8 +29,8 @@ export async function POST(req: NextRequest) { // TODO: fetch abi and throw if not found - await collection.add({ address: body.address, name: body.name }); - const contracts = await collection.get(); + await contractCollection.add({ address: body.address, name: body.name }); + const contracts = await contractCollection.get(); return NextResponse.json({ contracts }, { status: 200 }); } catch (e: any) { @@ -48,8 +45,8 @@ export async function POST(req: NextRequest) { export async function DELETE(req: NextRequest) { try { const body = await req.json(); - await collection.delete(body.key); - const contracts = await collection.get(); + await contractCollection.delete(body.key); + const contracts = await contractCollection.get(); return NextResponse.json({ contracts }, { status: 200 }); } catch (e: any) { diff --git a/app/api/execute/route.ts b/app/api/execute/route.ts index 18293a9..7939106 100644 --- a/app/api/execute/route.ts +++ b/app/api/execute/route.ts @@ -2,6 +2,7 @@ import { NextRequest, NextResponse } from "next/server"; import { getAbi } from "@/utils/etherscan"; import { generateToolFromABI } from "@/utils/generateToolFromABI"; import { routeBodySchema } from "./schemas"; +import { contractCollection } from "@/utils/collections"; export const runtime = "nodejs"; @@ -23,17 +24,28 @@ export async function POST(req: NextRequest) { const { toolCall, network, didToken } = result.data; - // parse contractAddress from toolCall.name; Should be in format `${contractAddress}-${functionName}` - const contractAddress = toolCall.name.split("-").at(0) as string; + // parse contractAddress from toolCall.name; Should be in format `${contractKey}-${functionName}` + const contractKey = parseInt(toolCall.name.split("-").at(0) as string, 10); + const contracts = await contractCollection.get(); + const contract = contracts.find(({ key }) => contractKey === key); + + if (!contract) { + return NextResponse.json( + { + error: `Unable to find reference ${contractKey}`, + }, + { status: 400 }, + ); + } try { let abi = "[]"; try { - abi = await getAbi(contractAddress, network); + abi = await getAbi(contract.address, network); } catch (e) { return NextResponse.json( { - error: `Could Not retreive ABI for contract ${contractAddress}`, + error: `Could Not retreive ABI for contract ${contract.address}`, }, { status: 400 }, ); @@ -41,13 +53,13 @@ export async function POST(req: NextRequest) { const tools = JSON.parse(abi) .filter((f: any) => f.name && f.type === "function") - .map(generateToolFromABI(contractAddress, didToken)); + .map(generateToolFromABI(contract, didToken)); const tool = tools.find((t: any) => t.name === toolCall.name); if (!tool) { return NextResponse.json( { - error: `Function ${toolCall.name} not found in ${contractAddress}`, + error: `Function ${toolCall.name} not found in ${contract.address}`, }, { status: 404 }, ); diff --git a/utils/collections.ts b/utils/collections.ts new file mode 100644 index 0000000..a3a9d98 --- /dev/null +++ b/utils/collections.ts @@ -0,0 +1,6 @@ +import { KVCollection } from "./kvCollection"; + +export const contractCollection = new KVCollection<{ + address: string; + name: string; +}>("contracts:"); diff --git a/utils/generateToolFromABI.ts b/utils/generateToolFromABI.ts index 37d55e9..69075f5 100644 --- a/utils/generateToolFromABI.ts +++ b/utils/generateToolFromABI.ts @@ -8,10 +8,12 @@ import { TransactionError, NetworkError, SigningError } from "./errors"; const magic = await Magic.init(process.env.MAGIC_SECRET_KEY); export const generateToolFromABI = - (contractAddress: string, didToken?: string) => + ( + contract: { key: number; address: string; name: string }, + didToken?: string, + ) => (func: AbiFunction): any => { let schema: any = {}; - func.inputs.forEach((input) => { if (input.type === "uint256[]") { schema[input.name ?? ""] = z.array(z.number()).describe("description"); @@ -25,8 +27,8 @@ export const generateToolFromABI = }); return new DynamicStructuredTool({ - name: `${contractAddress}-${func.name}`, - description: `Description for ${contractAddress} ${func.name}`, + name: `${contract.key}-${func.name}`, + description: `Description for ${contract.address} ${func.name}`, schema: z.object(schema), func: async (args): Promise => { // This function should return a string according to the link hence the stringifed JSON @@ -47,7 +49,7 @@ export const generateToolFromABI = try { const txReceipt = await getTransactionReceipt({ - contractAddress, + contractAddress: contract.address, functionName: func.name, args: ensuredArgOrder, publicAddress,