From 6cbc82089b84971975b6ce7bcdf32e4397bdc7e1 Mon Sep 17 00:00:00 2001 From: adel Date: Fri, 1 Nov 2024 04:50:25 +0100 Subject: [PATCH] feat: Updated the Pragma Caller contract (#1506) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Time spent on this PR: ## Pull request type Please check the type of change your PR introduces: - [ ] Bugfix - [x] Feature - [ ] Code style update (formatting, renaming) - [ ] Refactoring (no functional changes, no api changes) - [ ] Build related changes - [ ] Documentation content changes - [ ] Other (please describe): ## What is the current behavior? Resolves #NA ## What is the new behavior? - `PragmaCaller` now calls `get_data` with `AggregationMode` - Added the `SummaryStats` contract calls for Pragma computational feeds - Updated tests with `get_data` & `MockPragmaSummaryStats` contract - - - This change is [Reviewable](https://reviewable.io/reviews/kkrt-labs/kakarot/1506) --------- Co-authored-by: Oba Co-authored-by: enitrat --- cairo/mock_pragma/Scarb.lock | 47 +-- cairo/mock_pragma/Scarb.toml | 2 +- cairo/mock_pragma/src/lib.cairo | 1 + .../mock_pragma/src/mock_pragma_oracle.cairo | 12 +- .../src/mock_pragma_summary_stats.cairo | 89 +++++ kakarot_scripts/constants.py | 2 + .../deployment/starknet_deployments.py | 4 + solidity_contracts/scripts/PragmaCaller.s.sol | 17 + .../src/CairoPrecompiles/PragmaCaller.sol | 209 +++++++++++- .../Pragma/test_pragma_precompile.py | 323 ++++++++++++++++-- 10 files changed, 610 insertions(+), 96 deletions(-) create mode 100644 cairo/mock_pragma/src/mock_pragma_summary_stats.cairo create mode 100644 solidity_contracts/scripts/PragmaCaller.s.sol diff --git a/cairo/mock_pragma/Scarb.lock b/cairo/mock_pragma/Scarb.lock index cb99d8159..2c1df4669 100644 --- a/cairo/mock_pragma/Scarb.lock +++ b/cairo/mock_pragma/Scarb.lock @@ -1,55 +1,14 @@ # Code generated by scarb DO NOT EDIT. version = 1 -[[package]] -name = "alexandria_data_structures" -version = "0.1.0" -source = "git+https://github.com/keep-starknet-strange/alexandria.git?rev=46c8d8ab9e3bfb68b70a29b3246f809cd8bf70e4#46c8d8ab9e3bfb68b70a29b3246f809cd8bf70e4" - -[[package]] -name = "alexandria_math" -version = "0.2.0" -source = "git+https://github.com/keep-starknet-strange/alexandria.git?rev=46c8d8ab9e3bfb68b70a29b3246f809cd8bf70e4#46c8d8ab9e3bfb68b70a29b3246f809cd8bf70e4" -dependencies = [ - "alexandria_data_structures", -] - -[[package]] -name = "alexandria_sorting" -version = "0.1.0" -source = "git+https://github.com/keep-starknet-strange/alexandria.git?rev=46c8d8ab9e3bfb68b70a29b3246f809cd8bf70e4#46c8d8ab9e3bfb68b70a29b3246f809cd8bf70e4" - -[[package]] -name = "alexandria_storage" -version = "0.2.0" -source = "git+https://github.com/keep-starknet-strange/alexandria.git?rev=92c3c1b4ac35a4a56c14abe992814581aee875a8#92c3c1b4ac35a4a56c14abe992814581aee875a8" - -[[package]] -name = "cubit" -version = "1.2.0" -source = "git+https://github.com/influenceth/cubit?rev=2ccb2536dffa3f15ebd38b755c1be65fde1eab0c#2ccb2536dffa3f15ebd38b755c1be65fde1eab0c" - [[package]] name = "mock_pragma" version = "0.1.0" dependencies = [ - "pragma", + "pragma_lib", ] [[package]] -name = "openzeppelin" -version = "0.7.0" -source = "git+https://github.com/OpenZeppelin/cairo-contracts.git?tag=v0.7.0#bb8c56817577b66cea9f18a241fe59726db42dd5" - -[[package]] -name = "pragma" +name = "pragma_lib" version = "1.0.0" -source = "git+https://github.com/astraly-labs/pragma-oracle?tag=v1.0.5#71e762dcb725b95a4a4966190219bca8380bc823" -dependencies = [ - "alexandria_data_structures", - "alexandria_math", - "alexandria_sorting", - "alexandria_storage", - "cubit", - "openzeppelin", -] +source = "git+https://github.com/astraly-labs/pragma-lib?tag=2.3.1#24bb4da111ae7eb00e7cf40d4f1767c86d6447cd" diff --git a/cairo/mock_pragma/Scarb.toml b/cairo/mock_pragma/Scarb.toml index ca2d702de..9209b64de 100644 --- a/cairo/mock_pragma/Scarb.toml +++ b/cairo/mock_pragma/Scarb.toml @@ -7,7 +7,7 @@ edition = "2023_10" [dependencies] starknet = "2.2.0" -pragma = { git = "https://github.com/astraly-labs/pragma-oracle", tag = "v1.0.5" } +pragma_lib = { git = "https://github.com/astraly-labs/pragma-lib", tag = "2.3.1" } [[target.starknet-contract]] casm = true diff --git a/cairo/mock_pragma/src/lib.cairo b/cairo/mock_pragma/src/lib.cairo index 282a32113..50da2e26e 100644 --- a/cairo/mock_pragma/src/lib.cairo +++ b/cairo/mock_pragma/src/lib.cairo @@ -1 +1,2 @@ mod mock_pragma_oracle; +mod mock_pragma_summary_stats; diff --git a/cairo/mock_pragma/src/mock_pragma_oracle.cairo b/cairo/mock_pragma/src/mock_pragma_oracle.cairo index 5c2118aa6..96c9baef6 100644 --- a/cairo/mock_pragma/src/mock_pragma_oracle.cairo +++ b/cairo/mock_pragma/src/mock_pragma_oracle.cairo @@ -1,8 +1,10 @@ -use pragma::entry::structs::{DataType, PragmaPricesResponse}; +use pragma_lib::types::{DataType, AggregationMode, PragmaPricesResponse}; #[starknet::interface] trait IOracle { - fn get_data_median(self: @TContractState, data_type: DataType) -> PragmaPricesResponse; + fn get_data( + self: @TContractState, data_type: DataType, aggregation_mode: AggregationMode + ) -> PragmaPricesResponse; } #[starknet::interface] @@ -20,7 +22,7 @@ trait IMockPragmaOracle { #[starknet::contract] mod MockPragmaOracle { use starknet::ContractAddress; - use pragma::entry::structs::{DataType, PragmaPricesResponse}; + use pragma_lib::types::{DataType, AggregationMode, PragmaPricesResponse}; use super::{IOracle, IMockPragmaOracle}; @@ -36,7 +38,9 @@ mod MockPragmaOracle { //! Must be compatible with Cairo 2.2.0 #[external(v0)] impl IPragmaOracleImpl of IOracle { - fn get_data_median(self: @ContractState, data_type: DataType) -> PragmaPricesResponse { + fn get_data( + self: @ContractState, data_type: DataType, aggregation_mode: AggregationMode + ) -> PragmaPricesResponse { match data_type { DataType::SpotEntry => { PragmaPricesResponse { diff --git a/cairo/mock_pragma/src/mock_pragma_summary_stats.cairo b/cairo/mock_pragma/src/mock_pragma_summary_stats.cairo new file mode 100644 index 000000000..200bcd28c --- /dev/null +++ b/cairo/mock_pragma/src/mock_pragma_summary_stats.cairo @@ -0,0 +1,89 @@ +use pragma_lib::types::{DataType, AggregationMode}; + +#[starknet::interface] +trait ISummaryStats { + fn calculate_mean( + self: @TContractState, + data_type: DataType, + start: u64, + stop: u64, + aggregation_mode: AggregationMode + ) -> (u128, u32); + + fn calculate_volatility( + self: @TContractState, + data_type: DataType, + start_tick: u64, + end_tick: u64, + num_samples: u64, + aggregation_mode: AggregationMode + ) -> (u128, u32); + + fn calculate_twap( + self: @TContractState, + data_type: DataType, + aggregation_mode: AggregationMode, + time: u64, + start_time: u64, + ) -> (u128, u32); +} + +#[starknet::contract] +mod MockPragmaSummaryStats { + use core::zeroable::Zeroable; + use starknet::ContractAddress; + use pragma_lib::types::{DataType, AggregationMode}; + use pragma_lib::abi::{IPragmaABIDispatcher, IPragmaABIDispatcherTrait}; + + use super::ISummaryStats; + + #[storage] + struct Storage { + pragma_oracle: IPragmaABIDispatcher, + } + + #[constructor] + fn constructor(ref self: ContractState, pragma_oracle_address: ContractAddress) { + assert(!pragma_oracle_address.is_zero(), 'Pragma Oracle cannot be 0'); + let pragma_oracle = IPragmaABIDispatcher { contract_address: pragma_oracle_address }; + self.pragma_oracle.write(pragma_oracle); + } + + //! Must be compatible with Cairo 2.2.0 + #[external(v0)] + impl ISummaryStatsImpl of ISummaryStats { + fn calculate_mean( + self: @ContractState, + data_type: DataType, + start: u64, + stop: u64, + aggregation_mode: AggregationMode + ) -> (u128, u32) { + let data = self.pragma_oracle.read().get_data(data_type, aggregation_mode); + (data.price, data.decimals) + } + + fn calculate_volatility( + self: @ContractState, + data_type: DataType, + start_tick: u64, + end_tick: u64, + num_samples: u64, + aggregation_mode: AggregationMode + ) -> (u128, u32) { + let data = self.pragma_oracle.read().get_data(data_type, aggregation_mode); + (data.price, data.decimals) + } + + fn calculate_twap( + self: @ContractState, + data_type: DataType, + aggregation_mode: AggregationMode, + time: u64, + start_time: u64, + ) -> (u128, u32) { + let data = self.pragma_oracle.read().get_data(data_type, aggregation_mode); + (data.price, data.decimals) + } + } +} diff --git a/kakarot_scripts/constants.py b/kakarot_scripts/constants.py index c93af189f..a7c085cce 100644 --- a/kakarot_scripts/constants.py +++ b/kakarot_scripts/constants.py @@ -264,6 +264,7 @@ class ChainId(IntEnum): {"contract_name": "EVM", "is_account_contract": False}, {"contract_name": "kakarot", "is_account_contract": False}, {"contract_name": "MockPragmaOracle", "is_account_contract": False}, + {"contract_name": "MockPragmaSummaryStats", "is_account_contract": False}, {"contract_name": "OpenzeppelinAccount", "is_account_contract": True}, {"contract_name": "replace_class", "is_account_contract": False}, {"contract_name": "StarknetToken", "is_account_contract": False}, @@ -283,6 +284,7 @@ class ChainId(IntEnum): "EVM", "kakarot", "MockPragmaOracle", + "MockPragmaSummaryStats", "OpenzeppelinAccount", "replace_class", "StarknetToken", diff --git a/kakarot_scripts/deployment/starknet_deployments.py b/kakarot_scripts/deployment/starknet_deployments.py index 2a2dae5f8..e64f9bfc3 100644 --- a/kakarot_scripts/deployment/starknet_deployments.py +++ b/kakarot_scripts/deployment/starknet_deployments.py @@ -64,6 +64,10 @@ async def deploy_starknet_contracts(account): ) starknet_deployments["Counter"] = await deploy_starknet("Counter") starknet_deployments["MockPragmaOracle"] = await deploy_starknet("MockPragmaOracle") + starknet_deployments["MockPragmaSummaryStats"] = await deploy_starknet( + "MockPragmaSummaryStats", + starknet_deployments["MockPragmaOracle"], + ) starknet_deployments["UniversalLibraryCaller"] = await deploy_starknet( "UniversalLibraryCaller" ) diff --git a/solidity_contracts/scripts/PragmaCaller.s.sol b/solidity_contracts/scripts/PragmaCaller.s.sol new file mode 100644 index 000000000..b5d9754a3 --- /dev/null +++ b/solidity_contracts/scripts/PragmaCaller.s.sol @@ -0,0 +1,17 @@ +pragma solidity ^0.8.13; + +import "forge-std/Script.sol"; +import {PragmaCaller} from "../src/CairoPrecompiles/PragmaCaller.sol"; + +contract PragmaCallerScript is Script { + function run() external { + uint256 deployerPrivateKey = vm.envUint("PRIVATE_KEY"); + uint256 pragmaOracleAddress = vm.envUint("PRAGMA_ORACLE_ADDRESS"); + uint256 pragmaSummaryStatsAddress = vm.envUint("PRAGMA_SUMMARY_STATS_ADDRESS"); + vm.startBroadcast(deployerPrivateKey); + + PragmaCaller pragmaCaller = new PragmaCaller(pragmaOracleAddress, pragmaSummaryStatsAddress); + + vm.stopBroadcast(); + } +} diff --git a/solidity_contracts/src/CairoPrecompiles/PragmaCaller.sol b/solidity_contracts/src/CairoPrecompiles/PragmaCaller.sol index 3f6a9a002..69093ce9e 100644 --- a/solidity_contracts/src/CairoPrecompiles/PragmaCaller.sol +++ b/solidity_contracts/src/CairoPrecompiles/PragmaCaller.sol @@ -1,13 +1,37 @@ -// SPDX-License-Identifier: MIT -pragma solidity >=0.8.0 <0.9.0; +pragma solidity 0.8.27; import {CairoLib} from "kakarot-lib/CairoLib.sol"; using CairoLib for uint256; +/// @notice Contract for interacting with Pragma's Oracle on Starknet. This include the main contract +/// and the summary stats contract. contract PragmaCaller { /// @dev The starknet address of the pragma oracle - uint256 pragmaOracle; + uint256 private immutable pragmaOracle; + + /// @dev The starknet address of the pragma summary stats + uint256 private immutable pragmaSummaryStats; + + /// @dev The aggregation mode used by the Oracle + enum AggregationMode { + Median, + Mean + } + + /// @dev The request data type + enum DataType { + SpotEntry, + FuturesEntry, + GenericEntry + } + + struct PragmaPricesRequest { + AggregationMode aggregationMode; + DataType dataType; + uint256 pairId; + uint256 expirationTimestamp; + } struct PragmaPricesResponse { uint256 price; @@ -17,33 +41,69 @@ contract PragmaCaller { uint256 maybe_expiration_timestamp; } - enum DataType { - SpotEntry, - FuturesEntry, - GenericEntry + struct PragmaCalculateMeanRequest { + DataType dataType; + uint256 pairId; + uint256 expirationTimestamp; + uint64 startTimestamp; + uint64 endTimestamp; + AggregationMode aggregationMode; } - struct DataRequest { + struct PragmaCalculateVolatilityRequest { DataType dataType; uint256 pairId; uint256 expirationTimestamp; + uint64 startTimestamp; + uint64 endTimestamp; + uint64 numSamples; + AggregationMode aggregationMode; } - constructor(uint256 pragmaOracleAddress) { + struct PragmaCalculateTwapRequest { + DataType dataType; + uint256 pairId; + uint256 expirationTimestamp; + AggregationMode aggregationMode; + uint64 startTimestamp; + uint64 durationInSeconds; + } + + struct PragmaSummaryStatsResponse { + uint256 price; + uint256 decimals; + } + + /// @dev Constructor sets the oracle & summary stats addresses. + constructor(uint256 pragmaOracleAddress, uint256 pragmaSummaryStatsAddress) { + require(pragmaOracleAddress != 0, "Invalid Pragma Oracle address"); + require(pragmaSummaryStatsAddress != 0, "Invalid Pragma Summary Stats address"); pragmaOracle = pragmaOracleAddress; + pragmaSummaryStats = pragmaSummaryStatsAddress; } - function getDataMedianSpot(DataRequest memory request) public view returns (PragmaPricesResponse memory response) { - // Serialize the data request into a format compatible with the expected Pragma inputs - [enumIndex, [variantValues...] - // expirationTimestamp is only used for FuturesEntry requests - skip it for SpotEntry requests and GenericEntry requests - uint256[] memory data = new uint256[](request.dataType == DataType.FuturesEntry ? 3 : 2); + /// @notice Calls the `get_data` function from the Pragma's Oracle contract on Starknet. + /// @param request The request parameters to fetch Pragma's Prices. See `PragmaPricesRequest`. + /// @return response The pragma prices response of the specified request. + function getData(PragmaPricesRequest memory request) public view returns (PragmaPricesResponse memory response) { + bool isFuturesData = request.dataType == DataType.FuturesEntry; + + // Serialize the data request into a format compatible with the expected Pragma inputs + uint256[] memory data = new uint256[](isFuturesData ? 4 : 3); data[0] = uint256(request.dataType); data[1] = request.pairId; - if (request.dataType == DataType.FuturesEntry) { + if (isFuturesData) { data[2] = request.expirationTimestamp; + data[3] = uint256(request.aggregationMode); + } else { + data[2] = uint256(request.aggregationMode); } - bytes memory returnData = pragmaOracle.staticcallCairo("get_data_median", data); + bytes memory returnData = pragmaOracle.staticcallCairo("get_data", data); + + // 160 = 5 felts for Spot/Generic data ; 192 = 6 felts for Futures. + uint256 expectedLength = isFuturesData ? 192 : 160; + require(returnData.length == expectedLength, "Invalid return data length."); assembly { // Load the values from the return data @@ -68,4 +128,123 @@ contract PragmaCaller { } return response; } + + /// @notice Calls the `calculate_mean` function from the Pragma's Summary Stats contract on Starknet. + /// @param request The request parameters of `calculate_mean`. See `PragmaCalculateMeanRequest`. + /// @return response The return of the mean calculation, i.e the price and the decimals. See `PragmaSummaryStatsResponse`. + function calculateMean(PragmaCalculateMeanRequest memory request) + public + view + returns (PragmaSummaryStatsResponse memory response) + { + // Serialize the data request into a format compatible with the expected Pragma inputs + uint256[] memory data = new uint256[](request.dataType == DataType.FuturesEntry ? 6 : 5); + data[0] = uint256(request.dataType); + data[1] = request.pairId; + if (request.dataType == DataType.FuturesEntry) { + data[2] = request.expirationTimestamp; + data[3] = request.startTimestamp; + data[4] = request.endTimestamp; + data[5] = uint256(request.aggregationMode); + } else { + data[2] = request.startTimestamp; + data[3] = request.endTimestamp; + data[4] = uint256(request.aggregationMode); + } + + bytes memory returnData = pragmaSummaryStats.staticcallCairo("calculate_mean", data); + require(returnData.length == 64, "Invalid return data length."); // 64 = 2 felts. + + assembly { + // Load the values from the return data + // returnData[0:32] is the length of the return data + let price := mload(add(returnData, 0x20)) + let decimals := mload(add(returnData, 0x40)) + + // Store the values in the response struct + mstore(response, price) + mstore(add(response, 0x20), decimals) + } + return response; + } + + /// @notice Calls the `calculate_volatility` function from the Pragma's Summary Stats contract on Starknet. + /// @param request The request parameters of `calculate_volatility`. See `PragmaCalculateVolatilityRequest`. + /// @return response The return of the volatility calculation, i.e the price and the decimals. See `PragmaSummaryStatsResponse`. + function calculateVolatility(PragmaCalculateVolatilityRequest memory request) + public + view + returns (PragmaSummaryStatsResponse memory response) + { + // Serialize the data request into a format compatible with the expected Pragma inputs + uint256[] memory data = new uint256[](request.dataType == DataType.FuturesEntry ? 7 : 6); + data[0] = uint256(request.dataType); + data[1] = request.pairId; + if (request.dataType == DataType.FuturesEntry) { + data[2] = request.expirationTimestamp; + data[3] = request.startTimestamp; + data[4] = request.endTimestamp; + data[5] = request.numSamples; + data[6] = uint256(request.aggregationMode); + } else { + data[2] = request.startTimestamp; + data[3] = request.endTimestamp; + data[4] = request.numSamples; + data[5] = uint256(request.aggregationMode); + } + + bytes memory returnData = pragmaSummaryStats.staticcallCairo("calculate_volatility", data); + require(returnData.length == 64, "Invalid return data length."); // 64 = 2 felts. + + assembly { + // Load the values from the return data + // returnData[0:32] is the length of the return data + let price := mload(add(returnData, 0x20)) + let decimals := mload(add(returnData, 0x40)) + + // Store the values in the response struct + mstore(response, price) + mstore(add(response, 0x20), decimals) + } + return response; + } + + /// @notice Calls the `calculate_twap` function from the Pragma's Summary Stats contract on Starknet. + /// @param request The request parameters of `calculate_twap`. See `PragmaCalculateTwapRequest`. + /// @return response The return of the twap calculation, i.e the price and the decimals. See `PragmaSummaryStatsResponse`. + function calculateTwap(PragmaCalculateTwapRequest memory request) + public + view + returns (PragmaSummaryStatsResponse memory response) + { + // Serialize the data request into a format compatible with the expected Pragma inputs + uint256[] memory data = new uint256[](request.dataType == DataType.FuturesEntry ? 6 : 5); + data[0] = uint256(request.dataType); + data[1] = request.pairId; + if (request.dataType == DataType.FuturesEntry) { + data[2] = request.expirationTimestamp; + data[3] = uint256(request.aggregationMode); + data[4] = request.startTimestamp; + data[5] = request.durationInSeconds; + } else { + data[2] = uint256(request.aggregationMode); + data[3] = request.startTimestamp; + data[4] = request.durationInSeconds; + } + + bytes memory returnData = pragmaSummaryStats.staticcallCairo("calculate_twap", data); + require(returnData.length == 64, "Invalid return data length."); // 64 = 2 felts. + + assembly { + // Load the values from the return data + // returnData[0:32] is the length of the return data + let price := mload(add(returnData, 0x20)) + let decimals := mload(add(returnData, 0x40)) + + // Store the values in the response struct + mstore(response, price) + mstore(add(response, 0x20), decimals) + } + return response; + } } diff --git a/tests/end_to_end/CairoPrecompiles/Pragma/test_pragma_precompile.py b/tests/end_to_end/CairoPrecompiles/Pragma/test_pragma_precompile.py index af330703e..fc12fc45a 100644 --- a/tests/end_to_end/CairoPrecompiles/Pragma/test_pragma_precompile.py +++ b/tests/end_to_end/CairoPrecompiles/Pragma/test_pragma_precompile.py @@ -1,4 +1,6 @@ -from typing import OrderedDict, Tuple +from dataclasses import dataclass +from enum import Enum +from typing import Optional, OrderedDict, Tuple import pytest import pytest_asyncio @@ -6,53 +8,91 @@ from kakarot_scripts.utils.kakarot import deploy from kakarot_scripts.utils.starknet import get_contract, get_deployments, invoke -ENTRY_TYPE_INDEX = {"SpotEntry": 0, "FutureEntry": 1, "GenericEntry": 2} +@dataclass(frozen=True) +class Entry: + key: int + expiration_timestamp: Optional[int] = None + is_generic: bool = False -def serialize_cairo_response(cairo_dict: OrderedDict) -> Tuple: + @property + def entry_type(self) -> int: + if self.is_generic: + return 2 + return 0 if self.expiration_timestamp is None else 1 + + def to_dict(self) -> dict: + if self.expiration_timestamp is None: + return {"SpotEntry": self.key} + if self.is_generic: + return {"GenericEntry": self.key} + return {"FutureEntry": (self.key, self.expiration_timestamp)} + + def serialize(self) -> Tuple[int, int, int]: + return (self.entry_type, self.key, self.expiration_timestamp or 0) + + +class AggregationMode(Enum): + MEDIAN = "Median" + MEAN = "Mean" + + def to_tuple(self) -> Tuple[str, None]: + return (self.value, None) + + def serialize(self) -> int: + return list(AggregationMode).index(self) + + +def serialize_cairo_response(cairo_response: OrderedDict) -> Tuple: """ Serialize the return data of a Cairo call to a tuple with the same format as the one returned by the Solidity contract. """ # A None value in the Cairo response is equivalent to a value 0 in the Solidity response. - return tuple(value if value is not None else 0 for value in cairo_dict.values()) + return tuple(value if value is not None else 0 for value in cairo_response.values()) -def serialize_data_type(data_type: dict) -> Tuple: +def serialize_cairo_inputs(*args) -> Tuple[int, ...]: """ - Serialize the data type to a tuple - with the same format as the one expected by the Solidity contract. - - In solidity, the serialized data type is a tuple with the following format: - (entry_type, pair_id, expiration_timestamp) - - SpotEntry and GenericEntry take one argument pair_id - - FutureEntry takes two arguments pair_id and expiration_timestamp - - The `expiration_timestamp` is set to 0 for SpotEntry and GenericEntry. + Serialize the provided arguments to the same format as the one expected by + the Solidity contract. + Each arguments must be either: + * an `Entry`, + * an `AggregationMode`, + * a `int`. """ - entry_type, query_args = next(iter(data_type.items())) - serialized_entry_type = ENTRY_TYPE_INDEX[entry_type] - - if isinstance(query_args, tuple): - pair_id, expiration_timestamp = query_args - return (serialized_entry_type, pair_id, expiration_timestamp) - else: - return (serialized_entry_type, query_args, 0) + serialized_inputs = [] + for arg in args: + if isinstance(arg, (AggregationMode, Entry)): + serialized = arg.serialize() + if isinstance(serialized, tuple): + serialized_inputs.extend(serialized) + else: + serialized_inputs.append(serialized) + elif isinstance(arg, int): + serialized_inputs.append(arg) + else: + raise TypeError( + f"Unsupported type: {type(arg)}. Must be AggregationMode, Entry, or int" + ) + return tuple(serialized_inputs) @pytest_asyncio.fixture(scope="module") async def pragma_caller(owner): + pragma_summary_stats_address = get_deployments()["MockPragmaSummaryStats"] pragma_oracle_address = get_deployments()["MockPragmaOracle"] return await deploy( "CairoPrecompiles", "PragmaCaller", pragma_oracle_address, + pragma_summary_stats_address, caller_eoa=owner.starknet_contract, ) @pytest_asyncio.fixture() -async def cairo_pragma(mocked_values, pragma_caller): +async def cairo_pragma_oracle(mocked_values, pragma_caller): await invoke("MockPragmaOracle", "set_price", *mocked_values) await invoke( "kakarot", @@ -63,15 +103,21 @@ async def cairo_pragma(mocked_values, pragma_caller): return get_contract("MockPragmaOracle") +@pytest_asyncio.fixture() +async def cairo_pragma_summary_stats(): + return get_contract("MockPragmaSummaryStats") + + @pytest.mark.asyncio(scope="module") @pytest.mark.CairoPrecompiles class TestPragmaPrecompile: @pytest.mark.parametrize( - "data_type, mocked_values", + "data_type, aggregation_mode, mocked_values", [ ( - {"SpotEntry": int.from_bytes(b"BTC/USD", byteorder="big")}, + Entry(key=int.from_bytes(b"BTC/USD", byteorder="big")), + AggregationMode.MEDIAN, ( int.from_bytes(b"BTC/USD", byteorder="big"), 70000, @@ -81,7 +127,11 @@ class TestPragmaPrecompile: ), ), ( - {"FutureEntry": (int.from_bytes(b"ETH/USD", byteorder="big"), 0)}, + Entry( + key=int.from_bytes(b"ETH/USD", byteorder="big"), + expiration_timestamp=0, + ), + AggregationMode.MEAN, ( int.from_bytes(b"ETH/USD", byteorder="big"), 4000, @@ -91,7 +141,8 @@ class TestPragmaPrecompile: ), ), ( - {"GenericEntry": int.from_bytes(b"SOL/USD", byteorder="big")}, + Entry(key=int.from_bytes(b"SOL/USD", byteorder="big"), is_generic=True), + AggregationMode.MEDIAN, ( int.from_bytes(b"SOL/USD", byteorder="big"), 180, @@ -103,11 +154,20 @@ class TestPragmaPrecompile: ], ) async def test_should_return_data_median_for_query( - self, cairo_pragma, pragma_caller, data_type, mocked_values, max_fee + self, + cairo_pragma_oracle, + pragma_caller, + data_type, + aggregation_mode, + mocked_values, + max_fee, ): - (cairo_res,) = await cairo_pragma.functions["get_data_median"].call(data_type) - solidity_input = serialize_data_type(data_type) - sol_res = await pragma_caller.getDataMedianSpot(solidity_input) + (cairo_res,) = await cairo_pragma_oracle.functions["get_data"].call( + data_type.to_dict(), + aggregation_mode.to_tuple(), + ) + solidity_input = serialize_cairo_inputs(aggregation_mode, data_type) + sol_res = await pragma_caller.getData(solidity_input) serialized_cairo_res = serialize_cairo_response(cairo_res) assert serialized_cairo_res == sol_res @@ -133,6 +193,205 @@ async def test_should_return_data_median_for_query( assert res_maybe_expiration_timestamp == ( # behavior coded inside the mock mocked_last_updated_timestamp + 1000 - if data_type.get("FutureEntry") + if data_type.expiration_timestamp is not None else 0 ) + + @pytest.mark.parametrize( + "data_type, aggregation_mode, mocked_values", + [ + ( + Entry(key=int.from_bytes(b"BTC/USD", byteorder="big")), + AggregationMode.MEDIAN, + ( + int.from_bytes(b"BTC/USD", byteorder="big"), + 70000, + 18, + 1717143838, + 1, + ), + ), + ( + Entry( + key=int.from_bytes(b"ETH/USD", byteorder="big"), + expiration_timestamp=0, + ), + AggregationMode.MEAN, + ( + int.from_bytes(b"ETH/USD", byteorder="big"), + 4000, + 18, + 1717143838, + 1, + ), + ), + ], + ) + async def test_should_get_mean_for_query( + self, + cairo_pragma_oracle, + cairo_pragma_summary_stats, + pragma_caller, + data_type, + aggregation_mode, + mocked_values, + max_fee, + ): + (cairo_res,) = await cairo_pragma_summary_stats.functions[ + "calculate_mean" + ].call( + data_type.to_dict(), + 0, + 0, + aggregation_mode.to_tuple(), + ) + solidity_input = serialize_cairo_inputs(data_type, 0, 0, aggregation_mode) + sol_res = await pragma_caller.calculateMean(solidity_input) + assert cairo_res == sol_res + + ( + res_price, + res_decimals, + ) = sol_res + ( + _, + mocked_price, + mocked_decimals, + _, + _, + ) = mocked_values + assert res_price == mocked_price + assert res_decimals == mocked_decimals + + @pytest.mark.parametrize( + "data_type, aggregation_mode, mocked_values", + [ + ( + Entry(key=int.from_bytes(b"BTC/USD", byteorder="big")), + AggregationMode.MEDIAN, + ( + int.from_bytes(b"BTC/USD", byteorder="big"), + 70000, + 18, + 1717143838, + 1, + ), + ), + ( + Entry( + key=int.from_bytes(b"ETH/USD", byteorder="big"), + expiration_timestamp=0, + ), + AggregationMode.MEAN, + ( + int.from_bytes(b"ETH/USD", byteorder="big"), + 4000, + 18, + 1717143838, + 1, + ), + ), + ], + ) + async def test_should_get_volatility_for_query( + self, + cairo_pragma_oracle, + cairo_pragma_summary_stats, + pragma_caller, + data_type, + aggregation_mode, + mocked_values, + max_fee, + ): + (cairo_res,) = await cairo_pragma_summary_stats.functions[ + "calculate_volatility" + ].call( + data_type.to_dict(), + 0, + 0, + 0, + aggregation_mode.to_tuple(), + ) + solidity_input = serialize_cairo_inputs(data_type, 0, 0, 0, aggregation_mode) + sol_res = await pragma_caller.calculateVolatility(solidity_input) + assert cairo_res == sol_res + + ( + res_price, + res_decimals, + ) = sol_res + ( + _, + mocked_price, + mocked_decimals, + _, + _, + ) = mocked_values + assert res_price == mocked_price + assert res_decimals == mocked_decimals + + @pytest.mark.parametrize( + "data_type, aggregation_mode, mocked_values", + [ + ( + Entry(key=int.from_bytes(b"BTC/USD", byteorder="big")), + AggregationMode.MEDIAN, + ( + int.from_bytes(b"BTC/USD", byteorder="big"), + 70000, + 18, + 1717143838, + 1, + ), + ), + ( + Entry( + key=int.from_bytes(b"ETH/USD", byteorder="big"), + expiration_timestamp=0, + ), + AggregationMode.MEAN, + ( + int.from_bytes(b"ETH/USD", byteorder="big"), + 4000, + 18, + 1717143838, + 1, + ), + ), + ], + ) + async def test_should_get_twap_for_query( + self, + cairo_pragma_oracle, + cairo_pragma_summary_stats, + pragma_caller, + data_type, + aggregation_mode, + mocked_values, + max_fee, + ): + (cairo_res,) = await cairo_pragma_summary_stats.functions[ + "calculate_twap" + ].call( + data_type.to_dict(), + aggregation_mode.to_tuple(), + 0, + 0, + ) + solidity_input = serialize_cairo_inputs(data_type, aggregation_mode, 0, 0) + sol_res = await pragma_caller.calculateTwap(solidity_input) + assert cairo_res == sol_res + + ( + res_price, + res_decimals, + ) = sol_res + ( + _, + mocked_price, + mocked_decimals, + _, + _, + ) = mocked_values + assert res_price == mocked_price + assert res_decimals == mocked_decimals