From 08d71eb560ba0109cb9e2d52c533f0ffcd6f275c Mon Sep 17 00:00:00 2001 From: Shuhui Luo <107524008+shuhuiluo@users.noreply.github.com> Date: Tue, 9 Jan 2024 05:45:56 -0800 Subject: [PATCH] Implement pool swap functions and add relevant tests The commit introduces logic for `get_output_amount` and `get_input_amount` in the `Pool` struct for handling token swaps. It also includes a helper function, `_swap`, used to facilitate these transactions. Test cases have been added to ensure the correct functionality and reliability of these new features. --- src/entities/pool.rs | 544 +++++++++++++++++++++++++++++++++++++++---- src/utils/mod.rs | 21 +- 2 files changed, 512 insertions(+), 53 deletions(-) diff --git a/src/entities/pool.rs b/src/entities/pool.rs index f14ef48..240a9d2 100644 --- a/src/entities/pool.rs +++ b/src/entities/pool.rs @@ -1,7 +1,8 @@ use crate::prelude::*; -use alloy_primitives::{Address, B256, U256}; +use alloy_primitives::{Address, B256, I256, U256}; use num_bigint::BigUint; use once_cell::sync::Lazy; +use std::ops::Neg; use std::sync::Arc; use uniswap_sdk_core::prelude::*; @@ -21,6 +22,24 @@ pub struct Pool { _token1_price: Option>, } +struct SwapState { + amount_specified_remaining: I256, + amount_calculated: I256, + sqrt_price_x96: U256, + tick: i32, + liquidity: u128, +} + +struct StepComputations { + sqrt_price_start_x96: U256, + tick_next: i32, + initialized: bool, + sqrt_price_next_x96: U256, + amount_in: U256, + amount_out: U256, + fee_amount: U256, +} + impl Pool { /// Compute the pool address pub fn get_address( @@ -143,29 +162,216 @@ impl Pool { } } - pub async fn get_output_amount( + /// Given an input amount of a token, return the computed output amount, and a pool with state updated after the trade + /// + /// # Arguments + /// + /// * `input_amount`: The input amount for which to quote the output amount + /// * `sqrt_price_limit_x96`: The Q64.96 sqrt price limit + /// + /// returns: The output amount and the pool with updated state + /// + pub fn get_output_amount( &self, - _input_amount: CurrencyAmount, - _sqrt_price_limit_x96: Option, + input_amount: CurrencyAmount, + sqrt_price_limit_x96: Option, ) -> (CurrencyAmount, Self) { - todo!("get_output_amount") + assert!(self.involves_token(&input_amount.meta.currency), "TOKEN"); + + let zero_for_one = input_amount.meta.currency.equals(&self.token0); + + let (output_amount, sqrt_ratio_x96, liquidity, _) = self._swap( + zero_for_one, + big_int_to_i256(input_amount.quotient()), + sqrt_price_limit_x96, + ); + let output_token = if zero_for_one { + self.token1.clone() + } else { + self.token0.clone() + }; + ( + CurrencyAmount::from_raw_amount(output_token, i256_to_big_int(output_amount.neg())), + Pool::new( + self.token0.clone(), + self.token1.clone(), + self.fee, + sqrt_ratio_x96, + liquidity, + Some(self.tick_data_provider.clone()), + ), + ) } - pub async fn get_input_amount( + /// Given a desired output amount of a token, return the computed input amount and a pool with state updated after the trade + /// + /// # Arguments + /// + /// * `output_amount`: the output amount for which to quote the input amount + /// * `sqrt_price_limit_x96`: The Q64.96 sqrt price limit. If zero for one, the price cannot be less than this value + /// after the swap. If one for zero, the price cannot be greater than this value after the swap + /// + /// returns: The input amount and the pool with updated state + /// + pub fn get_input_amount( &self, - _output_amount: CurrencyAmount, - _sqrt_price_limit_x96: Option, + output_amount: CurrencyAmount, + sqrt_price_limit_x96: Option, ) -> (CurrencyAmount, Self) { - todo!("get_input_amount") + assert!(self.involves_token(&output_amount.meta.currency), "TOKEN"); + + let zero_for_one = output_amount.meta.currency.equals(&self.token1); + + let (input_amount, sqrt_ratio_x96, liquidity, _) = self._swap( + zero_for_one, + big_int_to_i256(output_amount.quotient()).neg(), + sqrt_price_limit_x96, + ); + let input_token = if zero_for_one { + self.token0.clone() + } else { + self.token1.clone() + }; + ( + CurrencyAmount::from_raw_amount(input_token, i256_to_big_int(input_amount)), + Pool::new( + self.token0.clone(), + self.token1.clone(), + self.fee, + sqrt_ratio_x96, + liquidity, + Some(self.tick_data_provider.clone()), + ), + ) } - async fn _swap( + fn _swap( &self, - _zero_for_one: bool, - _amount_specified: U256, - _sqrt_price_limit_x96: Option, - ) -> (U256, U256, u128, i32) { - todo!("swap") + zero_for_one: bool, + amount_specified: I256, + sqrt_price_limit_x96: Option, + ) -> (I256, U256, u128, i32) { + const ONE: U256 = U256::from_limbs([1, 0, 0, 0]); + let sqrt_price_limit_x96 = sqrt_price_limit_x96.unwrap_or_else(|| { + if zero_for_one { + MIN_SQRT_RATIO + ONE + } else { + MAX_SQRT_RATIO - ONE + } + }); + + if zero_for_one { + assert!(sqrt_price_limit_x96 > MIN_SQRT_RATIO, "RATIO_MIN"); + assert!(sqrt_price_limit_x96 < self.sqrt_ratio_x96, "RATIO_CURRENT"); + } else { + assert!(sqrt_price_limit_x96 < MAX_SQRT_RATIO, "RATIO_MAX"); + assert!(sqrt_price_limit_x96 > self.sqrt_ratio_x96, "RATIO_CURRENT"); + } + + let exact_input = amount_specified >= I256::ZERO; + + // keep track of swap state + let mut state = SwapState { + amount_specified_remaining: amount_specified, + amount_calculated: I256::ZERO, + sqrt_price_x96: self.sqrt_ratio_x96, + tick: self.tick_current, + liquidity: self.liquidity, + }; + + // start swap while loop + while !state.amount_specified_remaining.is_zero() + && state.sqrt_price_x96 != sqrt_price_limit_x96 + { + let mut step = StepComputations { + sqrt_price_start_x96: state.sqrt_price_x96, + tick_next: 0, + initialized: false, + sqrt_price_next_x96: U256::ZERO, + amount_in: U256::ZERO, + amount_out: U256::ZERO, + fee_amount: U256::ZERO, + }; + + step.sqrt_price_start_x96 = state.sqrt_price_x96; + // because each iteration of the while loop rounds, we can't optimize this code (relative to the smart contract) + // by simply traversing to the next available tick, we instead need to exactly replicate + (step.tick_next, step.initialized) = self + .tick_data_provider + .next_initialized_tick_within_one_word( + state.tick, + zero_for_one, + self.tick_spacing(), + ) + .unwrap(); + + if step.tick_next < MIN_TICK { + step.tick_next = MIN_TICK; + } else if step.tick_next > MAX_TICK { + step.tick_next = MAX_TICK; + } + + step.sqrt_price_next_x96 = get_sqrt_ratio_at_tick(step.tick_next).unwrap(); + ( + state.sqrt_price_x96, + step.amount_in, + step.amount_out, + step.fee_amount, + ) = compute_swap_step( + state.sqrt_price_x96, + if zero_for_one { + step.sqrt_price_next_x96.max(sqrt_price_limit_x96) + } else { + step.sqrt_price_next_x96.min(sqrt_price_limit_x96) + }, + state.liquidity, + state.amount_specified_remaining, + self.fee as u32, + ) + .unwrap(); + + if exact_input { + state.amount_specified_remaining = I256::from_raw( + state.amount_specified_remaining.into_raw() - step.amount_in - step.fee_amount, + ); + state.amount_calculated = + I256::from_raw(state.amount_calculated.into_raw() - step.amount_out); + } else { + state.amount_specified_remaining = + I256::from_raw(state.amount_specified_remaining.into_raw() + step.amount_out); + state.amount_calculated = I256::from_raw( + state.amount_calculated.into_raw() + step.amount_in + step.fee_amount, + ); + } + + if state.sqrt_price_x96 == step.sqrt_price_next_x96 { + // if the tick is initialized, run the tick transition + if step.initialized { + let mut liquidity_net = self + .tick_data_provider + .get_tick(step.tick_next) + .unwrap() + .liquidity_net; + // if we're moving leftward, we interpret liquidityNet as the opposite sign + // safe because liquidityNet cannot be type(int128).min + if zero_for_one { + liquidity_net = liquidity_net.neg(); + } + state.liquidity = add_delta(state.liquidity, liquidity_net).unwrap(); + } + state.tick = step.tick_next - zero_for_one as i32; + } else { + // recompute unless we're on a lower tick boundary (i.e. already transitioned ticks), and haven't moved + state.tick = get_tick_at_sqrt_ratio(state.sqrt_price_x96).unwrap(); + } + } + + ( + state.amount_calculated, + state.sqrt_price_x96, + state.liquidity, + state.tick, + ) } } @@ -185,7 +391,7 @@ mod tests { "USD Coin" ) }); - static _DAI: Lazy = Lazy::new(|| { + static DAI: Lazy = Lazy::new(|| { token!( 1, "0x6B175474E89094C44Da98b954EedeAC495271d0F", @@ -195,82 +401,318 @@ mod tests { ) }); + mod constructor { + use super::*; + + #[test] + #[should_panic(expected = "CHAIN_IDS")] + fn cannot_be_used_for_tokens_on_different_chains() { + let weth9 = WETH9::default().get(3).unwrap().clone(); + Pool::new( + USDC.clone(), + weth9.clone(), + FeeAmount::MEDIUM, + ONE_ETHER, + 0, + None, + ); + } + + #[test] + #[should_panic(expected = "ADDRESSES")] + fn cannot_be_given_two_of_the_same_token() { + Pool::new( + USDC.clone(), + USDC.clone(), + FeeAmount::MEDIUM, + ONE_ETHER, + 0, + None, + ); + } + + #[test] + fn works_with_valid_arguments_for_empty_pool_medium_fee() { + let weth9 = WETH9::default().get(1).unwrap().clone(); + Pool::new( + USDC.clone(), + weth9.clone(), + FeeAmount::MEDIUM, + ONE_ETHER, + 0, + None, + ); + } + + #[test] + fn works_with_valid_arguments_for_empty_pool_low_fee() { + let weth9 = WETH9::default().get(1).unwrap().clone(); + Pool::new( + USDC.clone(), + weth9.clone(), + FeeAmount::LOW, + ONE_ETHER, + 0, + None, + ); + } + + #[test] + fn works_with_valid_arguments_for_empty_pool_lowest_fee() { + let weth9 = WETH9::default().get(1).unwrap().clone(); + Pool::new( + USDC.clone(), + weth9.clone(), + FeeAmount::LOWEST, + ONE_ETHER, + 0, + None, + ); + } + + #[test] + fn works_with_valid_arguments_for_empty_pool_high_fee() { + let weth9 = WETH9::default().get(1).unwrap().clone(); + Pool::new( + USDC.clone(), + weth9.clone(), + FeeAmount::HIGH, + ONE_ETHER, + 0, + None, + ); + } + } + #[test] - #[should_panic(expected = "CHAIN_IDS")] - fn test_constructor_cannot_be_used_for_tokens_on_different_chains() { - let weth9 = WETH9::default().get(3).unwrap().clone(); - Pool::new( + fn get_address_matches_an_example() { + let result = Pool::get_address(&USDC, &DAI, FeeAmount::LOW, None, None); + assert_eq!(result, address!("6c6Bc977E13Df9b0de53b251522280BB72383700")); + } + + #[test] + fn token0_always_is_the_token_that_sorts_before() { + let pool = Pool::new( + USDC.clone(), + DAI.clone(), + FeeAmount::LOW, + encode_sqrt_ratio_x96(1, 1), + 0, + None, + ); + assert!(pool.token0.equals(&DAI.clone())); + let pool = Pool::new( + DAI.clone(), + USDC.clone(), + FeeAmount::LOW, + encode_sqrt_ratio_x96(1, 1), + 0, + None, + ); + assert!(pool.token0.equals(&DAI.clone())); + } + + #[test] + fn token1_always_is_the_token_that_sorts_after() { + let pool = Pool::new( + USDC.clone(), + DAI.clone(), + FeeAmount::LOW, + encode_sqrt_ratio_x96(1, 1), + 0, + None, + ); + assert!(pool.token1.equals(&USDC.clone())); + let pool = Pool::new( + DAI.clone(), + USDC.clone(), + FeeAmount::LOW, + encode_sqrt_ratio_x96(1, 1), + 0, + None, + ); + assert!(pool.token1.equals(&USDC.clone())); + } + + #[test] + fn token0_price_returns_price_of_token0_in_terms_of_token1() { + let mut pool = Pool::new( + USDC.clone(), + DAI.clone(), + FeeAmount::LOW, + encode_sqrt_ratio_x96(101e6 as u128, 100e18 as u128), + 0, + None, + ); + assert_eq!( + pool.token0_price().to_significant(5, Rounding::RoundHalfUp), + "1.01" + ); + let mut pool = Pool::new( + DAI.clone(), USDC.clone(), - weth9.clone(), - FeeAmount::MEDIUM, - ONE_ETHER, + FeeAmount::LOW, + encode_sqrt_ratio_x96(101e6 as u128, 100e18 as u128), 0, None, ); + assert_eq!( + pool.token0_price().to_significant(5, Rounding::RoundHalfUp), + "1.01" + ); } #[test] - #[should_panic(expected = "ADDRESSES")] - fn test_constructor_cannot_be_given_two_of_the_same_token() { - Pool::new( + fn token1_price_returns_price_of_token1_in_terms_of_token0() { + let mut pool = Pool::new( USDC.clone(), + DAI.clone(), + FeeAmount::LOW, + encode_sqrt_ratio_x96(101e6 as u128, 100e18 as u128), + 0, + None, + ); + assert_eq!( + pool.token1_price().to_significant(5, Rounding::RoundHalfUp), + "0.9901" + ); + let mut pool = Pool::new( + DAI.clone(), USDC.clone(), - FeeAmount::MEDIUM, - ONE_ETHER, + FeeAmount::LOW, + encode_sqrt_ratio_x96(101e6 as u128, 100e18 as u128), 0, None, ); + assert_eq!( + pool.token1_price().to_significant(5, Rounding::RoundHalfUp), + "0.9901" + ); } #[test] - fn test_constructor_works_with_valid_arguments_for_empty_pool_medium_fee() { - let weth9 = WETH9::default().get(1).unwrap().clone(); - Pool::new( + fn price_of_returns_price_of_token_in_terms_of_other_token() { + let mut pool = Pool::new( USDC.clone(), - weth9.clone(), - FeeAmount::MEDIUM, - ONE_ETHER, + DAI.clone(), + FeeAmount::LOW, + encode_sqrt_ratio_x96(1, 1), 0, None, ); + assert!(pool.price_of(&DAI.clone()).equal_to(&pool.token0_price())); + assert!(pool.price_of(&USDC.clone()).equal_to(&pool.token1_price())); } #[test] - fn test_constructor_works_with_valid_arguments_for_empty_pool_low_fee() { - let weth9 = WETH9::default().get(1).unwrap().clone(); - Pool::new( + #[should_panic(expected = "TOKEN")] + fn price_of_throws_if_invalid_token() { + let mut pool = Pool::new( USDC.clone(), - weth9.clone(), + DAI.clone(), FeeAmount::LOW, - ONE_ETHER, + encode_sqrt_ratio_x96(1, 1), 0, None, ); + pool.price_of(&WETH9::default().get(1).unwrap().clone()); } #[test] - fn test_constructor_works_with_valid_arguments_for_empty_pool_lowest_fee() { - let weth9 = WETH9::default().get(1).unwrap().clone(); - Pool::new( + fn chain_id_returns_token0_chain_id() { + let pool = Pool::new( + USDC.clone(), + DAI.clone(), + FeeAmount::LOW, + encode_sqrt_ratio_x96(1, 1), + 0, + None, + ); + assert_eq!(pool.chain_id(), 1); + let pool = Pool::new( + DAI.clone(), USDC.clone(), - weth9.clone(), - FeeAmount::LOWEST, - ONE_ETHER, + FeeAmount::LOW, + encode_sqrt_ratio_x96(1, 1), 0, None, ); + assert_eq!(pool.chain_id(), 1); } #[test] - fn test_constructor_works_with_valid_arguments_for_empty_pool_high_fee() { - let weth9 = WETH9::default().get(1).unwrap().clone(); - Pool::new( + fn involves_token() { + let pool = Pool::new( USDC.clone(), - weth9.clone(), - FeeAmount::HIGH, - ONE_ETHER, + DAI.clone(), + FeeAmount::LOW, + encode_sqrt_ratio_x96(1, 1), 0, None, ); + assert!(pool.involves_token(&USDC.clone())); + assert!(pool.involves_token(&DAI.clone())); + assert!(!pool.involves_token(&WETH9::default().get(1).unwrap().clone())); + } + + mod swaps { + use super::*; + + fn pool() -> Pool { + Pool::new( + USDC.clone(), + DAI.clone(), + FeeAmount::LOW, + encode_sqrt_ratio_x96(1, 1), + ONE_ETHER.into_limbs()[0] as u128, + Some(Arc::new(TickListDataProvider::new( + vec![ + Tick::new( + nearest_usable_tick(MIN_TICK, FeeAmount::LOW.tick_spacing()), + ONE_ETHER.into_limbs()[0] as u128, + ONE_ETHER.into_limbs()[0] as i128, + ), + Tick::new( + nearest_usable_tick(MAX_TICK, FeeAmount::LOW.tick_spacing()), + ONE_ETHER.into_limbs()[0] as u128, + -(ONE_ETHER.into_limbs()[0] as i128), + ), + ], + FeeAmount::LOW.tick_spacing(), + ))), + ) + } + + #[test] + fn get_output_amount_usdc_to_dai() { + let (output_amount, _) = + pool().get_output_amount(CurrencyAmount::from_raw_amount(USDC.clone(), 100), None); + assert!(output_amount.meta.currency.equals(&DAI.clone())); + assert_eq!(output_amount.quotient(), 98.into()); + } + + #[test] + fn get_output_amount_dai_to_usdc() { + let (output_amount, _) = + pool().get_output_amount(CurrencyAmount::from_raw_amount(DAI.clone(), 100), None); + assert!(output_amount.meta.currency.equals(&USDC.clone())); + assert_eq!(output_amount.quotient(), 98.into()); + } + + #[test] + fn get_input_amount_usdc_to_dai() { + let (input_amount, _) = + pool().get_input_amount(CurrencyAmount::from_raw_amount(DAI.clone(), 98), None); + assert!(input_amount.meta.currency.equals(&USDC.clone())); + assert_eq!(input_amount.quotient(), 100.into()); + } + + #[test] + fn get_input_amount_dai_to_usdc() { + let (input_amount, _) = + pool().get_input_amount(CurrencyAmount::from_raw_amount(USDC.clone(), 98), None); + assert!(input_amount.meta.currency.equals(&DAI.clone())); + assert_eq!(input_amount.quotient(), 100.into()); + } } } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index fd63243..66ed9f7 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -28,9 +28,10 @@ pub use swap_math::compute_swap_step; pub use tick_list::TickList; pub use tick_math::*; -use alloy_primitives::U256; +use alloy_primitives::{I256, U256}; use num_bigint::{BigInt, BigUint, Sign}; -use num_traits::ToBytes; +use num_traits::{Signed, ToBytes}; +use std::ops::Neg; pub const Q96: U256 = U256::from_limbs([0, 4294967296, 0, 0]); pub const Q128: U256 = U256::from_limbs([0, 0, 1, 0]); @@ -44,6 +45,14 @@ pub fn u256_to_big_int(x: U256) -> BigInt { BigInt::from_bytes_be(Sign::Plus, &x.to_be_bytes::<32>()) } +pub fn i256_to_big_int(x: I256) -> BigInt { + if x.is_positive() { + u256_to_big_int(x.into_raw()) + } else { + u256_to_big_int(x.neg().into_raw()).neg() + } +} + pub fn big_uint_to_u256(x: BigUint) -> U256 { U256::from_be_slice(&x.to_be_bytes()) } @@ -51,3 +60,11 @@ pub fn big_uint_to_u256(x: BigUint) -> U256 { pub fn big_int_to_u256(x: BigInt) -> U256 { U256::from_be_slice(&x.to_be_bytes()) } + +pub fn big_int_to_i256(x: BigInt) -> I256 { + if x.is_positive() { + I256::from_raw(big_int_to_u256(x)) + } else { + I256::from_raw(big_int_to_u256(x.neg())).neg() + } +}