-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
17 changed files
with
3,272 additions
and
159 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,170 +1,65 @@ | ||
import os | ||
import logging | ||
import pprint | ||
|
||
import numpy as np | ||
from addresses import ADDRESSES | ||
from dotenv import find_dotenv, load_dotenv | ||
from lp_tools import get_tick_range | ||
from mint_position import close_position, get_all_user_positions, get_mint_params | ||
from prefect import get_run_logger | ||
|
||
from giza.agents.action import action | ||
from giza.agents import AgentResult, GizaAgent | ||
from giza.agents.task import task | ||
from giza.agents import GizaAgent | ||
|
||
load_dotenv(find_dotenv()) | ||
|
||
# Here we load a custom sepolia rpc url from the environment | ||
sepolia_rpc_url = os.environ.get("SEPOLIA_RPC_URL") | ||
|
||
MODEL_ID = ... # Update with your model ID | ||
VERSION_ID = ... # Update with your version ID | ||
|
||
|
||
@task | ||
def process_data(realized_vol, dec_price_change): | ||
pct_change_sq = (100 * dec_price_change) ** 2 | ||
X = np.array([[realized_vol, pct_change_sq]]) | ||
return X | ||
|
||
|
||
# Get image | ||
@task | ||
def get_data(): | ||
# TODO: implement fetching onchain or from some other source | ||
realized_vol = 4.20 | ||
dec_price_change = 0.1 | ||
return realized_vol, dec_price_change | ||
|
||
|
||
@task | ||
def create_agent( | ||
model_id: int, version_id: int, chain: str, contracts: dict, account: str | ||
): | ||
""" | ||
Create a Giza agent for the volatility prediction model | ||
""" | ||
def transmission(): | ||
logger = logging.getLogger(__name__) | ||
id = ... | ||
version = ... | ||
account = ... | ||
realized_vol, dec_price_change = get_data() | ||
input_data = process_data(realized_vol, dec_price_change) | ||
|
||
agent = GizaAgent( | ||
contracts=contracts, | ||
id=model_id, | ||
version_id=version_id, | ||
chain=chain, | ||
integrations=["UniswapV3"], | ||
id=id, | ||
chain="ethereum:sepolia:https://sepolia.infura.io/v3/765888cfa824440c8c0feb5b492abedd", | ||
version_id=version, | ||
account=account, | ||
) | ||
return agent | ||
|
||
|
||
@task | ||
def predict(agent: GizaAgent, X: np.ndarray): | ||
""" | ||
Predict the digit in an image. | ||
Args: | ||
image (np.ndarray): Image to predict. | ||
Returns: | ||
int: Predicted digit. | ||
""" | ||
prediction = agent.predict(input_feed={"val": X}, verifiable=True, job_size="XL") | ||
return prediction | ||
|
||
|
||
@task | ||
def get_pred_val(prediction: AgentResult): | ||
""" | ||
Get the value from the prediction. | ||
Args: | ||
prediction (dict): Prediction from the model. | ||
|
||
Returns: | ||
int: Predicted value. | ||
""" | ||
# This will block the executon until the prediction has generated the proof and the proof has been verified | ||
return prediction.value[0][0] | ||
|
||
|
||
# Create Action | ||
@action | ||
def transmission( | ||
pred_model_id, | ||
pred_version_id, | ||
account="dev", | ||
chain=f"ethereum:sepolia:{sepolia_rpc_url}", | ||
): | ||
logger = get_run_logger() | ||
|
||
nft_manager_address = ADDRESSES["NonfungiblePositionManager"][11155111] | ||
tokenA_address = ADDRESSES["UNI"][11155111] | ||
tokenB_address = ADDRESSES["WETH"][11155111] | ||
pool_address = "0x287B0e934ed0439E2a7b1d5F0FC25eA2c24b64f7" | ||
user_address = "0xCBB090699E0664f0F6A4EFbC616f402233718152" | ||
|
||
pool_fee = 3000 | ||
tokenA_amount = 1000 | ||
tokenB_amount = 1000 | ||
|
||
logger.info("Fetching input data") | ||
realized_vol, dec_price_change = get_data() | ||
|
||
logger.info(f"Input data: {realized_vol}, {dec_price_change}") | ||
X = process_data(realized_vol, dec_price_change) | ||
|
||
nft_manager_abi_path = "nft_manager_abi.json" | ||
contracts = { | ||
"nft_manager": [nft_manager_address, nft_manager_abi_path], | ||
"tokenA": [tokenA_address], | ||
"tokenB": tokenB_address, | ||
"pool": pool_address, | ||
} | ||
agent = create_agent( | ||
model_id=pred_model_id, | ||
version_id=pred_version_id, | ||
chain=chain, | ||
contracts=contracts, | ||
account=account, | ||
result = agent.predict( | ||
input_feed={"val": input_data}, verifiable=True, dry_run=True | ||
) | ||
result = predict(agent, X) | ||
predicted_value = get_pred_val(result) | ||
logger.info(f"Result: {result}") | ||
|
||
logger.info(f"Result: {result}") | ||
with agent.execute() as contracts: | ||
logger.info("Executing contract") | ||
# TODO: fix below | ||
positions = get_all_user_positions(contracts.nft_manager, user_address) | ||
logger.info(f"Found the following positions: {positions}") | ||
# step 4: close all positions | ||
logger.info("Closing all open positions...") | ||
for nft_id in positions: | ||
close_position(user_address, contracts.nft_manager, nft_id) | ||
# step 4: calculate mint params | ||
logger.info("Calculating mint params...") | ||
_, curr_tick, _, _, _, _, _ = contracts.pool.slot0() | ||
tokenA_decimals = contracts.tokenA.decimals() | ||
tokenB_decimals = contracts.tokenB.decimals() | ||
# TODO: confirm input should be result and not result * 100 | ||
lower_tick, upper_tick = get_tick_range( | ||
curr_tick, predicted_value, tokenA_decimals, tokenB_decimals, pool_fee | ||
UNI_address = "0x1f9840a85d5aF5bf1D1762F925BDADdC4201F984" | ||
WETH_address = "0xfFf9976782d46CC05630D1f6eBAb18b2324d6B14" | ||
uni = contracts.UniswapV3 | ||
volatility_prediction = result.value[0] | ||
pool = uni.get_pool(UNI_address, WETH_address, fee=500) | ||
curr_price = pool.get_pool_price() | ||
lower_price = curr_price * (1 - volatility_prediction) | ||
upper_price = curr_price * (1 + volatility_prediction) | ||
amount0 = 100 | ||
amount1 = 100 | ||
agent_result = uni.mint_position( | ||
pool, lower_price, upper_price, amount0, amount1 | ||
) | ||
mint_params = get_mint_params( | ||
tokenA_address, | ||
tokenB_address, | ||
user_address, | ||
tokenA_amount, | ||
tokenB_amount, | ||
pool_fee, | ||
lower_tick, | ||
upper_tick, | ||
logger.info( | ||
f"Current price: {curr_price}, new bounds: {lower_price}, {upper_price}" | ||
) | ||
# step 5: mint new position | ||
logger.info("Minting new position...") | ||
contract_result = contracts.nft_manager.mint(mint_params) | ||
logger.info("SUCCESSFULLY MINTED A POSITION") | ||
logger.info("Contract executed") | ||
logger.info(f"Minted position: {agent_result}") | ||
|
||
logger.info(f"Contract result: {contract_result}") | ||
pprint.pprint(contract_result.__dict__) | ||
logger.info(f"Contract result: {agent_result}") | ||
logger.info("Finished") | ||
|
||
|
||
transmission(MODEL_ID, VERSION_ID) | ||
transmission() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from ape.api import AccountAPI | ||
|
||
import giza.agents.integrations.uniswap.uniswap as uniswap_module | ||
|
||
|
||
class IntegrationFactory: | ||
@staticmethod | ||
def from_name(name: str, sender: AccountAPI) -> uniswap_module.Uniswap: | ||
match name: | ||
case "UniswapV3": | ||
return uniswap_module.Uniswap(sender, version=3) | ||
case _: | ||
raise ValueError(f"Integration {name} not found") |
Oops, something went wrong.