Skip to content

Commit

Permalink
Add server wallet transaction signing to lifi and fixed get balance
Browse files Browse the repository at this point in the history
  • Loading branch information
RickyRoller committed Jan 7, 2025
1 parent 83cc7d7 commit fa4b49e
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 82 deletions.
2 changes: 1 addition & 1 deletion src/agent/agents/wallet_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ async def transfer_to(self) -> "BaseAgent":
@agent_tool()
async def get_balances(self) -> Dict[str, any]:
"""
Get the balance of all tokens in the wallet.
Get the balance of the ETH in the wallet.
"""
return await self._wallet.get_balances()

Expand Down
49 changes: 32 additions & 17 deletions src/agent/core/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,13 @@ def __init__(
self._entry_agent = entry_agent
self._current_agent: Optional[BaseAgent] = entry_agent

async def _execute_tool(self, tool_call: Dict[str, Any]) -> str | BaseAgent:
async def _execute_tool(
self, tool_call: Dict[str, Any]
) -> Dict[str, str | BaseAgent]:
"""Execute a tool call by finding the appropriate agent and method
Returns:
Union[str, BaseAgent]: Either a JSON serialized result string or a BaseAgent instance
Raises:
ValueError: If no agent is found with the requested tool method
Dict[str, Any]: Result dictionary containing success status and result/error
"""
try:
method_name = tool_call["function"]["name"]
Expand All @@ -69,15 +69,18 @@ async def _execute_tool(self, tool_call: Dict[str, Any]) -> str | BaseAgent:
result = await method(**args)
# Special case for agent transfers
if isinstance(result, BaseAgent):
return result
return result
return {"success": True, "result": result}
return {"success": True, "result": result}

raise ValueError(f"No agent found with tool method: {method_name}")
return {
"success": False,
"error": f"No agent found with tool method: {method_name}",
}

except Exception as e:
error_message = f"Tool execution failed: {str(e)}"
self._debug_log("Tool execution error", error_message)
return error_message
return {"success": False, "error": error_message}

def _debug_log(self, message: str, data: Optional[Any] = None) -> None:
"""Log debug information if debug mode is enabled
Expand Down Expand Up @@ -130,6 +133,7 @@ async def agent_loop(self) -> AsyncGenerator[str, None]:
"""Generate a response from the current agent"""
generation_count = 0
status = "streaming"

while generation_count < 2 and status == "streaming":
generation_count += 1
self._debug_log(f"Generation attempt {generation_count}")
Expand All @@ -150,22 +154,33 @@ async def agent_loop(self) -> AsyncGenerator[str, None]:
tool_calls=chunk["tool_calls"],
)

# Then process each tool call
# Process each tool call
for tool_call in chunk["tool_calls"]:
result = await self._execute_tool(tool_call)
if isinstance(result, BaseAgent):
self._current_agent = result

if not result["success"]:
# Add error to message history
self._message_manager.add_message(
result["error"], "tool", tool_id=tool_call["id"]
)
# Yield specific error and end loop
yield "I'm sorry, something went wrong and I was unable to complete your request"
status = "failed"
return

if isinstance(result["result"], BaseAgent):
self._current_agent = result["result"]
generation_count = 0
self._debug_log(
"Switching to new agent, resetting generation count"
)
result = (
f"Transferring to {result.name}. You may continue."
)
tool_result = f"Transferring to {result['result'].name}. You may continue."
else:
tool_result = result["result"]

# Always add the result to message history, whether success or failure
# Add successful result to message history
self._message_manager.add_message(
result, "tool", tool_id=tool_call["id"]
tool_result, "tool", tool_id=tool_call["id"]
)

except Exception as e:
Expand All @@ -175,4 +190,4 @@ async def agent_loop(self) -> AsyncGenerator[str, None]:
break

if status == "streaming":
yield "I'm sorry, I'm having trouble processing your request. Please try again later."
yield "Request could not be completed. Please try again."
20 changes: 8 additions & 12 deletions src/wallet/adapters/lifi/lifi_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,6 @@ async def swap(self, quote: Dict[str, Any]) -> AsyncGenerator[Dict[str, Any], No
Yields:
Dict containing status updates about the swap progress
Returns:
Final swap result with transaction details
"""
transaction_request = {
"data": quote["transactionRequest"]["data"],
Expand All @@ -124,29 +121,28 @@ async def swap(self, quote: Dict[str, Any]) -> AsyncGenerator[Dict[str, Any], No
"gas": hex(int(quote["estimate"]["gasCosts"][0]["estimate"])),
"type": "0x2",
"nonce": await self._wallet._web3.eth.get_transaction_count(
self._wallet._account.address
self._wallet._wallet_address
),
}

signed_tx = self._wallet._web3.eth.account.sign_transaction(
transaction_request, self._wallet._account.key
response = await self._wallet.sign_transaction(
transaction_request, gas_estimate=False
)

tx_hash = response.get("data", {}).get("hash")
if not tx_hash:
raise QuoteError("No transaction hash")

yield {
"status": "pending",
"message": "Transaction submitted, waiting for confirmation...",
"transaction_hash": tx_hash.hex(),
"transaction_hash": tx_hash,
}

tx_hash = await self._wallet._web3.eth.send_raw_transaction(
signed_tx.raw_transaction
)

receipt = await self._wallet._web3.eth.wait_for_transaction_receipt(
tx_hash, poll_latency=0.5
)

# Return final result
yield {
"status": "success" if receipt["status"] == 1 else "failed",
"message": (
Expand Down
68 changes: 17 additions & 51 deletions src/wallet/wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,74 +136,38 @@ async def transfer(
@wallet_tool(
descriptions={"token_type": "Type of tokens to fetch (erc20, erc721, etc)"}
)
async def get_balances(self, token_type: str = "erc20") -> Dict[str, Any]:
async def get_balances(self) -> Dict[str, Any]:
"""
Check wallet balances for ETH and tokens. Use this method when:
- User wants to check their ETH balance
- User wants to see all their token balances
- User needs to verify their available funds
- User asks about their crypto holdings
- User wants to know how much they own
Check wallet balances for ETH
Common triggers: "check balance", "how much eth do i have", "show my tokens",
"what's in my wallet", "view balance", "check my crypto"
Args:
token_type (str): Type of tokens to fetch (default: "erc20")
Common triggers: "check balance", "how much eth do i have", "view balance", "check my crypto"
Returns:
Dict[str, Any]: Combined ETH and token balances
"""
if not self._account:
raise WalletError("Wallet not initialized with private key")

raise WalletError("Not implemented")
# Get ETH balance
try:
eth_balance_wei = await self._web3.eth.get_balance(self._account.address)
eth_balance_wei = await self._web3.eth.get_balance(self._wallet_address)
eth_balance = self._web3.from_wei(eth_balance_wei, "ether")
except Exception as e:
raise WalletError(f"Failed to fetch ETH balance: {str(e)}")

# Get tracked token balances
token_balances = {}
for token_address in self._tracked_tokens:
try:
token_contract = common_contracts.get_contract("erc20", token_address)
raw_balance = await token_contract.functions.balanceOf(
self._account.address
).call()
# Get token symbol and decimals
symbol = await token_contract.functions.symbol().call()
decimals = await token_contract.functions.decimals().call()
# Convert raw balance to token amount
token_balance = Decimal(raw_balance) / Decimal(10**decimals)
token_balances[symbol] = str(token_balance)
except Exception as e:
raise WalletError(
f"Failed to fetch balance for token {token_address}: {str(e)}"
)

balances = {
"ETH": str(eth_balance),
**token_balances,
}
return f"""
Display the following wallet balances in a bulleted list using markdown
{balances}
"""
return f"ETH: {str(eth_balance)}"

def get_tracked_tokens(self) -> List[str]:
"""Returns list of tracked token addresses"""
return list(self._tracked_tokens)

async def sign_transaction_with_privy(
self, transaction: Dict[str, Any]
async def sign_transaction(
self, transaction: Dict[str, Any], gas_estimate: bool = True
) -> Dict[str, Any]:
"""
Sign and send a transaction using Privy's wallet API.
Args:
transaction (Dict[str, Any]): Transaction parameters following eth_signTransaction format
transaction (Dict[str, Any]): Transaction parameters
gas_estimate (bool): Whether to estimate gas if not provided
Returns:
Dict[str, Any]: Transaction response from Privy API
Expand All @@ -214,29 +178,31 @@ async def sign_transaction_with_privy(
url = f"https://api.privy.io/v1/wallets/{self._wallet_id}/rpc"

privy_app_id = os.getenv("PRIVY_APP_ID")

privy_app_secret = os.getenv("PRIVY_APP_SECRET")

if not privy_app_id or not privy_app_secret:
raise ValueError(
"PRIVY_APP_ID and PRIVY_APP_SECRET environment variables must be set"
)

# Create basic auth header from app_id:app_secret
auth_string = f"{privy_app_id}:{privy_app_secret}"
basic_auth = base64.b64encode(auth_string.encode()).decode()

# Ensure chain_id is included in transaction
# Add chain_id if not present
if "chain_id" not in transaction:
transaction["chain_id"] = self._chain_id

# Estimate gas if needed and not provided
if gas_estimate and "gas" not in transaction:
gas = await self._web3.eth.estimate_gas(transaction)
transaction["gas"] = hex(gas)

body = {
"method": "eth_sendTransaction",
"caip2": f"eip155:{self._chain_id}",
"params": {"transaction": transaction},
}

# Get Privy authorization headers
headers = self._privy_signer.get_auth_headers(url=url, body=body, method="POST")
headers.update(
{
Expand All @@ -249,6 +215,6 @@ async def sign_transaction_with_privy(
async with session.post(url, json=body, headers=headers) as response:
if response.status != 200:
error_text = await response.text()
raise WalletError(f"Privy API request failed: {error_text}")
raise WalletError(f"API request failed: {error_text}")

return await response.json()
8 changes: 7 additions & 1 deletion src/wallet/wallet_types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TypeVar
from typing import TypeVar, Dict, Any
from eth_account.signers.local import LocalAccount
from web3 import AsyncWeb3

Expand All @@ -9,6 +9,12 @@
class WalletInstance:
_web3: Web3Type
_account: AccountType
_wallet_id: str
_chain_id: int

async def sign_transaction(self, transaction: Dict[str, Any]) -> Dict[str, Any]:
"""Sign and send transaction"""
pass


WalletType = TypeVar("WalletType", bound=WalletInstance)

0 comments on commit fa4b49e

Please sign in to comment.