Skip to content

Commit

Permalink
integrations compatible with agent
Browse files Browse the repository at this point in the history
  • Loading branch information
jwarmuz99 committed Jun 7, 2024
1 parent 8e5fb6a commit 09b0316
Show file tree
Hide file tree
Showing 17 changed files with 3,272 additions and 159 deletions.
173 changes: 34 additions & 139 deletions examples/uni_v3_lp/action_agent.py
Original file line number Diff line number Diff line change
@@ -1,170 +1,65 @@
import os
import logging
import pprint

import numpy as np
from addresses import ADDRESSES
from dotenv import find_dotenv, load_dotenv
from lp_tools import get_tick_range
from mint_position import close_position, get_all_user_positions, get_mint_params
from prefect import get_run_logger

from giza.agents.action import action
from giza.agents import AgentResult, GizaAgent
from giza.agents.task import task
from giza.agents import GizaAgent

load_dotenv(find_dotenv())

# Here we load a custom sepolia rpc url from the environment
sepolia_rpc_url = os.environ.get("SEPOLIA_RPC_URL")

MODEL_ID = ... # Update with your model ID
VERSION_ID = ... # Update with your version ID


@task
def process_data(realized_vol, dec_price_change):
pct_change_sq = (100 * dec_price_change) ** 2
X = np.array([[realized_vol, pct_change_sq]])
return X


# Get image
@task
def get_data():
# TODO: implement fetching onchain or from some other source
realized_vol = 4.20
dec_price_change = 0.1
return realized_vol, dec_price_change


@task
def create_agent(
model_id: int, version_id: int, chain: str, contracts: dict, account: str
):
"""
Create a Giza agent for the volatility prediction model
"""
def transmission():
logger = logging.getLogger(__name__)
id = ...
version = ...
account = ...
realized_vol, dec_price_change = get_data()
input_data = process_data(realized_vol, dec_price_change)

agent = GizaAgent(
contracts=contracts,
id=model_id,
version_id=version_id,
chain=chain,
integrations=["UniswapV3"],
id=id,
chain="ethereum:sepolia:https://sepolia.infura.io/v3/765888cfa824440c8c0feb5b492abedd",
version_id=version,
account=account,
)
return agent


@task
def predict(agent: GizaAgent, X: np.ndarray):
"""
Predict the digit in an image.
Args:
image (np.ndarray): Image to predict.
Returns:
int: Predicted digit.
"""
prediction = agent.predict(input_feed={"val": X}, verifiable=True, job_size="XL")
return prediction


@task
def get_pred_val(prediction: AgentResult):
"""
Get the value from the prediction.
Args:
prediction (dict): Prediction from the model.

Returns:
int: Predicted value.
"""
# This will block the executon until the prediction has generated the proof and the proof has been verified
return prediction.value[0][0]


# Create Action
@action
def transmission(
pred_model_id,
pred_version_id,
account="dev",
chain=f"ethereum:sepolia:{sepolia_rpc_url}",
):
logger = get_run_logger()

nft_manager_address = ADDRESSES["NonfungiblePositionManager"][11155111]
tokenA_address = ADDRESSES["UNI"][11155111]
tokenB_address = ADDRESSES["WETH"][11155111]
pool_address = "0x287B0e934ed0439E2a7b1d5F0FC25eA2c24b64f7"
user_address = "0xCBB090699E0664f0F6A4EFbC616f402233718152"

pool_fee = 3000
tokenA_amount = 1000
tokenB_amount = 1000

logger.info("Fetching input data")
realized_vol, dec_price_change = get_data()

logger.info(f"Input data: {realized_vol}, {dec_price_change}")
X = process_data(realized_vol, dec_price_change)

nft_manager_abi_path = "nft_manager_abi.json"
contracts = {
"nft_manager": [nft_manager_address, nft_manager_abi_path],
"tokenA": [tokenA_address],
"tokenB": tokenB_address,
"pool": pool_address,
}
agent = create_agent(
model_id=pred_model_id,
version_id=pred_version_id,
chain=chain,
contracts=contracts,
account=account,
result = agent.predict(
input_feed={"val": input_data}, verifiable=True, dry_run=True
)
result = predict(agent, X)
predicted_value = get_pred_val(result)
logger.info(f"Result: {result}")

logger.info(f"Result: {result}")
with agent.execute() as contracts:
logger.info("Executing contract")
# TODO: fix below
positions = get_all_user_positions(contracts.nft_manager, user_address)
logger.info(f"Found the following positions: {positions}")
# step 4: close all positions
logger.info("Closing all open positions...")
for nft_id in positions:
close_position(user_address, contracts.nft_manager, nft_id)
# step 4: calculate mint params
logger.info("Calculating mint params...")
_, curr_tick, _, _, _, _, _ = contracts.pool.slot0()
tokenA_decimals = contracts.tokenA.decimals()
tokenB_decimals = contracts.tokenB.decimals()
# TODO: confirm input should be result and not result * 100
lower_tick, upper_tick = get_tick_range(
curr_tick, predicted_value, tokenA_decimals, tokenB_decimals, pool_fee
UNI_address = "0x1f9840a85d5aF5bf1D1762F925BDADdC4201F984"
WETH_address = "0xfFf9976782d46CC05630D1f6eBAb18b2324d6B14"
uni = contracts.UniswapV3
volatility_prediction = result.value[0]
pool = uni.get_pool(UNI_address, WETH_address, fee=500)
curr_price = pool.get_pool_price()
lower_price = curr_price * (1 - volatility_prediction)
upper_price = curr_price * (1 + volatility_prediction)
amount0 = 100
amount1 = 100
agent_result = uni.mint_position(
pool, lower_price, upper_price, amount0, amount1
)
mint_params = get_mint_params(
tokenA_address,
tokenB_address,
user_address,
tokenA_amount,
tokenB_amount,
pool_fee,
lower_tick,
upper_tick,
logger.info(
f"Current price: {curr_price}, new bounds: {lower_price}, {upper_price}"
)
# step 5: mint new position
logger.info("Minting new position...")
contract_result = contracts.nft_manager.mint(mint_params)
logger.info("SUCCESSFULLY MINTED A POSITION")
logger.info("Contract executed")
logger.info(f"Minted position: {agent_result}")

logger.info(f"Contract result: {contract_result}")
pprint.pprint(contract_result.__dict__)
logger.info(f"Contract result: {agent_result}")
logger.info("Finished")


transmission(MODEL_ID, VERSION_ID)
transmission()
79 changes: 59 additions & 20 deletions giza/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Any, Callable, Dict, List, Optional, Self, Tuple, Union

from ape import Contract, accounts, networks
from ape.api import AccountAPI
from ape.contracts import ContractInstance
from ape.exceptions import NetworkError
from ape_accounts.accounts import InvalidPasswordError
Expand All @@ -18,6 +19,7 @@
from giza.cli.utils.enums import JobKind, JobStatus
from requests import HTTPError

from giza.agents.integration import IntegrationFactory
from giza.agents.model import GizaModel
from giza.agents.utils import read_json

Expand All @@ -34,7 +36,8 @@ def __init__(
self,
id: int,
version_id: int,
contracts: Dict[str, Union[str, List[str]]],
contracts: Optional[Dict[str, Union[str, List[str]]]] = None,
integrations: Optional[List[str]] = None,
chain: Optional[str] = None,
account: Optional[str] = None,
**kwargs: Any,
Expand All @@ -44,6 +47,7 @@ def __init__(
model_id (int): The ID of the model.
version_id (int): The version of the model.
contracts (Dict[str, str]): The contracts to handle, must be a dictionary with the contract name as the key and the contract address as the value.
integrations (List[str]): The integrations to use.
chain_id (int): The ID of the blockchain network.
**kwargs: Additional keyword arguments.
"""
Expand All @@ -63,11 +67,11 @@ def __init__(
logger.error("Agent is missing required parameters")
raise ValueError(f"Agent is missing required parameters: {e}")

self.contract_handler = ContractHandler(contracts)
self.chain = chain
self.account = account
self._check_passphrase_in_env()
self._check_or_create_account()
self.contract_handler = ContractHandler(contracts, integrations)

# Useful for testing
network_parser: Callable = kwargs.get(
Expand Down Expand Up @@ -240,8 +244,8 @@ def execute(self) -> Any:
f"Invalid passphrase for account {self.account}. Could not decrypt account."
) from e
logger.debug("Autosign enabled")
with accounts.use_sender(self._account):
yield self.contract_handler.handle()
with accounts.use_sender(self._account) as sender:
yield self.contract_handler.handle(account=sender)

def predict(
self,
Expand Down Expand Up @@ -452,15 +456,35 @@ class ContractHandler:
which means that it should be done insede the GizaAgent's execute context.
"""

def __init__(self, contracts: Dict[str, Union[str, List[str]]]) -> None:
def __init__(
self,
contracts: Optional[Dict[str, Union[str, List[str]]]] = None,
integrations: Optional[List[str]] = None,
) -> None:
if contracts is None:
contracts = {}
if integrations is None:
integrations = []
contract_names = list(contracts.keys())
duplicates = set(contract_names) & set(integrations)
if duplicates:
duplicate_names = ", ".join(duplicates)
raise ValueError(
f"Integrations of these names already exist: {duplicate_names}. Choose different contract names."
)
self._contracts = contracts
self._integrations = integrations
self._contracts_instances: Dict[str, ContractInstance] = {}
self._integrations_instances: Dict[str, IntegrationFactory] = {}

def __getattr__(self, name: str) -> ContractInstance:
def __getattr__(self, name: str) -> Union[ContractInstance, IntegrationFactory]:
"""
Get the contract by name.
"""
return self._contracts_instances[name]
if name in self._contracts_instances.keys():
return self._contracts_instances[name]
if name in self._integrations_instances.keys():
return self._integrations_instances[name]

def _initiate_contract(
self, address: str, abi: Optional[str] = None
Expand All @@ -472,26 +496,41 @@ def _initiate_contract(
return Contract(address=address)
return Contract(address=address, abi=abi)

def handle(self) -> Self:
def _initiate_integration(
self, name: str, account: AccountAPI
) -> IntegrationFactory:
"""
Initiate the integration.
"""
return IntegrationFactory.from_name(name, sender=account)

def handle(self, account) -> Self:
"""
Handle the contracts.
"""
try:
for name, contract_data in self._contracts.items():
if isinstance(contract_data, str):
address = contract_data
self._contracts_instances[name] = self._initiate_contract(address)
elif isinstance(contract_data, list):
if len(contract_data) == 1:
address = contract_data[0]
if self._contracts:
for name, contract_data in self._contracts.items():
if isinstance(contract_data, str):
address = contract_data
self._contracts_instances[name] = self._initiate_contract(
address
)
else:
address, abi = contract_data
self._contracts_instances[name] = self._initiate_contract(
address, abi
)
elif isinstance(contract_data, list):
if len(contract_data) == 1:
address = contract_data[0]
self._contracts_instances[name] = self._initiate_contract(
address
)
else:
address, abi = contract_data
self._contracts_instances[name] = self._initiate_contract(
address, abi
)
for name in self._integrations:
self._integrations_instances[name] = self._initiate_integration(
name, account
)
except NetworkError as e:
logger.error(f"Failed to initiate contract: {e}")
raise ValueError(
Expand Down
13 changes: 13 additions & 0 deletions giza/agents/integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from ape.api import AccountAPI

import giza.agents.integrations.uniswap.uniswap as uniswap_module


class IntegrationFactory:
@staticmethod
def from_name(name: str, sender: AccountAPI) -> uniswap_module.Uniswap:
match name:
case "UniswapV3":
return uniswap_module.Uniswap(sender, version=3)
case _:
raise ValueError(f"Integration {name} not found")
Loading

0 comments on commit 09b0316

Please sign in to comment.