Skip to content

Commit

Permalink
Implement compute_swap_step function and add benchmarks
Browse files Browse the repository at this point in the history
Added the compute_swap_step function in src/utils/swap_math.rs that calculates result of swapping given parameters. Also added new benchmark tests for this function under benches/swap_math.rs and updated the version in Cargo.toml to 0.3.0. The compute_swap_step function helps in computing swap steps more efficiently in the codebase.
  • Loading branch information
shuhuiluo committed Dec 31, 2023
1 parent 24a1f11 commit 4e92c9c
Show file tree
Hide file tree
Showing 5 changed files with 248 additions and 2 deletions.
3 changes: 2 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "uniswap-v3-sdk-rs"
version = "0.2.0"
version = "0.3.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand All @@ -20,6 +20,7 @@ uniswap_v3_math = "0.4.1"

[dev-dependencies]
criterion = "0.5.1"
ethers-core = "2.0.11"

[[bench]]
name = "bit_math"
Expand All @@ -29,6 +30,10 @@ harness = false
name = "sqrt_price_math"
harness = false

[[bench]]
name = "swap_math"
harness = false

[[bench]]
name = "tick_math"
harness = false
97 changes: 97 additions & 0 deletions benches/swap_math.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
use alloy_primitives::{keccak256, I256, U256};
use alloy_sol_types::SolValue;
use criterion::{criterion_group, criterion_main, Criterion};
use ethers_core;
use uniswap_v3_math::{swap_math, utils::ruint_to_u256};
use uniswap_v3_sdk_rs::utils::compute_swap_step;

fn pseudo_random(seed: u64) -> U256 {
keccak256(seed.abi_encode()).into()
}

fn pseudo_random_128(seed: u64) -> u128 {
let s: U256 = keccak256(seed.abi_encode()).into();
u128::from_be_bytes(s.to_be_bytes::<32>()[..16].try_into().unwrap())
}

fn generate_inputs() -> Vec<(U256, U256, u128, I256, u32)> {
(0u64..100)
.map(|i| {
(
pseudo_random(i),
pseudo_random(i.pow(2)),
pseudo_random_128(i.pow(3)),
I256::from_raw(pseudo_random(i.pow(4))),
i as u32,
)
})
.collect()
}

fn compute_swap_step_benchmark(c: &mut Criterion) {
let inputs = generate_inputs();
c.bench_function("compute_swap_step", |b| {
b.iter(|| {
for (
sqrt_ratio_current_x96,
sqrt_ratio_target_x96,
liquidity,
amount_remaining,
fee_pips,
) in &inputs
{
let _ = compute_swap_step(
*sqrt_ratio_current_x96,
*sqrt_ratio_target_x96,
*liquidity,
*amount_remaining,
*fee_pips,
);
}
})
});
}

fn compute_swap_step_benchmark_ref(c: &mut Criterion) {
use ethers_core::types::{I256, U256};

let inputs: Vec<(U256, U256, u128, I256, u32)> = generate_inputs()
.into_iter()
.map(|i| {
(
ruint_to_u256(i.0),
ruint_to_u256(i.1),
i.2,
I256::from_raw(ruint_to_u256(i.3.into_raw())),
i.4,
)
})
.collect();
c.bench_function("compute_swap_step_ref", |b| {
b.iter(|| {
for (
sqrt_ratio_current_x96,
sqrt_ratio_target_x96,
liquidity,
amount_remaining,
fee_pips,
) in &inputs
{
let _ = swap_math::compute_swap_step(
*sqrt_ratio_current_x96,
*sqrt_ratio_target_x96,
*liquidity,
*amount_remaining,
*fee_pips,
);
}
})
});
}

criterion_group!(
benches,
compute_swap_step_benchmark,
compute_swap_step_benchmark_ref,
);
criterion_main!(benches);
2 changes: 2 additions & 0 deletions src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ mod nearest_usable_tick;
mod position;
mod price_tick_conversions;
mod sqrt_price_math;
mod swap_math;
mod tick_math;

pub use bit_math::*;
Expand All @@ -20,6 +21,7 @@ pub use nearest_usable_tick::nearest_usable_tick;
pub use position::get_tokens_owed;
pub use price_tick_conversions::*;
pub use sqrt_price_math::*;
pub use swap_math::compute_swap_step;
pub use tick_math::*;

use alloy_primitives::U256;
Expand Down
141 changes: 141 additions & 0 deletions src/utils/swap_math.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
use super::{
get_amount_0_delta, get_amount_1_delta, get_next_sqrt_price_from_input,
get_next_sqrt_price_from_output, mul_div, mul_div_rounding_up,
};
use alloy_primitives::{I256, U256};
use uniswap_v3_math::error::UniswapV3MathError;

/// Computes the result of swapping some amount in, or amount out, given the parameters of the swap
///
/// The fee, plus the amount in, will never exceed the amount remaining if the swap's `amountSpecified` is positive
///
/// # Arguments
///
/// * `sqrt_ratio_current_x96`: The current sqrt price of the pool
/// * `sqrt_ratio_target_x96`: The price that cannot be exceeded, from which the direction of the swap is inferred
/// * `liquidity`: The usable liquidity
/// * `amount_remaining`: How much input or output amount is remaining to be swapped in/out
/// * `fee_pips`: The fee taken from the input amount, expressed in hundredths of a bip
///
/// # Returns
///
/// * `sqrt_ratio_next_x96`: The price after swapping the amount in/out, not to exceed the price target
/// * `amount_in`: The amount to be swapped in, of either token0 or token1, based on the direction of the swap
/// * `amount_out`: The amount to be received, of either token0 or token1, based on the direction of the swap
/// * `fee_amount`: The amount of input that will be taken as a fee
///
pub fn compute_swap_step(
sqrt_ratio_current_x96: U256,
sqrt_ratio_target_x96: U256,
liquidity: u128,
amount_remaining: I256,
fee_pips: u32,
) -> Result<(U256, U256, U256, U256), UniswapV3MathError> {
const MAX_FEE: U256 = U256::from_limbs([1000000, 0, 0, 0]);
let fee_pips = U256::from_limbs([fee_pips as u64, 0, 0, 0]);
let fee_complement = MAX_FEE - fee_pips;
let zero_for_one = sqrt_ratio_current_x96 >= sqrt_ratio_target_x96;
let exact_in = amount_remaining >= I256::ZERO;

let sqrt_ratio_next_x96: U256;
let mut amount_in: U256;
let mut amount_out: U256;
let fee_amount: U256;
if exact_in {
let amount_remaining_abs = amount_remaining.into_raw();
let amount_remaining_less_fee = mul_div(amount_remaining_abs, fee_complement, MAX_FEE)?;

amount_in = if zero_for_one {
get_amount_0_delta(
sqrt_ratio_target_x96,
sqrt_ratio_current_x96,
liquidity,
true,
)?
} else {
get_amount_1_delta(
sqrt_ratio_current_x96,
sqrt_ratio_target_x96,
liquidity,
true,
)?
};

if amount_remaining_less_fee >= amount_in {
sqrt_ratio_next_x96 = sqrt_ratio_target_x96;
fee_amount = mul_div_rounding_up(amount_in, fee_pips, fee_complement)?;
} else {
amount_in = amount_remaining_less_fee;
sqrt_ratio_next_x96 = get_next_sqrt_price_from_input(
sqrt_ratio_current_x96,
liquidity,
amount_in,
zero_for_one,
)?;
fee_amount = amount_remaining_abs - amount_in;
}

amount_out = if zero_for_one {
get_amount_1_delta(
sqrt_ratio_next_x96,
sqrt_ratio_current_x96,
liquidity,
false,
)?
} else {
get_amount_0_delta(
sqrt_ratio_current_x96,
sqrt_ratio_next_x96,
liquidity,
false,
)?
};
} else {
let amount_remaining_abs = (-amount_remaining).into_raw();

amount_out = if zero_for_one {
get_amount_1_delta(
sqrt_ratio_target_x96,
sqrt_ratio_current_x96,
liquidity,
false,
)?
} else {
get_amount_0_delta(
sqrt_ratio_current_x96,
sqrt_ratio_target_x96,
liquidity,
false,
)?
};

if amount_remaining_abs >= amount_out {
sqrt_ratio_next_x96 = sqrt_ratio_target_x96;
} else {
amount_out = amount_remaining_abs;
sqrt_ratio_next_x96 = get_next_sqrt_price_from_output(
sqrt_ratio_current_x96,
liquidity,
amount_out,
zero_for_one,
)?;
}

amount_in = if zero_for_one {
get_amount_0_delta(sqrt_ratio_next_x96, sqrt_ratio_current_x96, liquidity, true)?
} else {
get_amount_1_delta(sqrt_ratio_current_x96, sqrt_ratio_next_x96, liquidity, true)?
};
fee_amount = mul_div_rounding_up(amount_in, fee_pips, fee_complement)?;
}

Ok((sqrt_ratio_next_x96, amount_in, amount_out, fee_amount))
}

#[cfg(test)]
mod tests {
#[test]
fn test_compute_swap_step() {
// TODO: Add more tests
}
}

0 comments on commit 4e92c9c

Please sign in to comment.