diff --git a/Cargo.toml b/Cargo.toml index 66e7e806c..90489e427 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,6 @@ default = [ "evm", "print_txn_corpus", "full_trace", - # "flashloan_v2", ] evm = [] cmp = [] diff --git a/integration_test.py b/integration_test.py index 63764a4fc..1b144ac1f 100644 --- a/integration_test.py +++ b/integration_test.py @@ -56,7 +56,7 @@ def test_one(path): start_time = time.time() cmd = [ TIMEOUT_BIN, - "3m", + "5s", "./target/release/ityfuzz", "evm", "-t", @@ -64,12 +64,17 @@ def test_one(path): "-f", "--panic-on-bug", ] - print(" ".join(cmd)) # exit(0) if "concolic" in path: cmd.append("--concolic --concolic-caller") + if "taint" in path: + cmd.append("--sha3-bypass") + + + print(" ".join(cmd)) + p = subprocess.run( " ".join(cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True ) diff --git a/src/evm/concolic/concolic_host.rs b/src/evm/concolic/concolic_host.rs index d4d4d7ee1..f2a08177f 100644 --- a/src/evm/concolic/concolic_host.rs +++ b/src/evm/concolic/concolic_host.rs @@ -2,7 +2,6 @@ use std::{ borrow::Borrow, collections::{HashMap, HashSet}, fmt::{Debug, Display}, - marker::PhantomData, ops::{Add, Div, Mul, Not, Sub}, rc::Rc, sync::{Arc, Mutex, RwLock}, @@ -11,11 +10,7 @@ use std::{ use bytes::Bytes; use itertools::Itertools; use lazy_static::lazy_static; -use libafl::{ - prelude::{HasMetadata, Input}, - schedulers::Scheduler, - state::{HasCorpus, State}, -}; +use libafl::{prelude::HasMetadata, schedulers::Scheduler}; use revm_interpreter::Interpreter; use serde::{Deserialize, Serialize}; use tracing::{debug, error}; @@ -34,14 +29,12 @@ use crate::{ concolic::expr::{simplify, ConcolicOp, Expr}, corpus_initializer::SourceMapMap, host::FuzzHost, - input::{ConciseEVMInput, EVMInput, EVMInputT}, + input::EVMInput, middlewares::middleware::{Middleware, MiddlewareType, MiddlewareType::Concolic}, srcmap::parser::SourceMapLocation, - types::{as_u64, is_zero, EVMAddress, EVMU256}, + types::{as_u64, is_zero, EVMAddress, EVMFuzzState, EVMU256}, }, - generic_vm::vm_state::VMStateT, input::VMInputT, - state::{HasCaller, HasCurrentInputIdx, HasItyState}, }; lazy_static! { @@ -566,7 +559,7 @@ pub struct ConcolicCallCtx { } #[derive(Debug, Serialize, Deserialize)] -pub struct ConcolicHost { +pub struct ConcolicHost { pub symbolic_stack: Vec>>, pub symbolic_memory: SymbolicMemory, pub symbolic_state: HashMap>>, @@ -576,14 +569,12 @@ pub struct ConcolicHost { pub ctxs: Vec, // For current PC, the number of times it has been visited - pub phantom: PhantomData<(I, VS)>, - pub num_threads: usize, pub call_depth: usize, } #[allow(clippy::vec_box)] -impl ConcolicHost { +impl ConcolicHost { pub fn new(testcase_ref: Rc, num_threads: usize) -> Self { Self { symbolic_stack: Vec::new(), @@ -592,7 +583,6 @@ impl ConcolicHost { input_bytes: Self::construct_input_from_abi(testcase_ref.get_data_abi().expect("data abi not found")), constraints: vec![], testcase_ref, - phantom: Default::default(), ctxs: vec![], num_threads, call_depth: 0, @@ -710,21 +700,11 @@ impl ConcolicHost { } } -impl Middleware for ConcolicHost +impl Middleware for ConcolicHost where - I: Input + VMInputT + EVMInputT + 'static, - VS: VMStateT, - S: State - + HasCaller - + HasCorpus - + HasItyState - + HasMetadata - + HasCurrentInputIdx - + Debug - + Clone, - SC: Scheduler + Clone, + SC: Scheduler + Clone, { - unsafe fn on_step(&mut self, interp: &mut Interpreter, _host: &mut FuzzHost, state: &mut S) { + unsafe fn on_step(&mut self, interp: &mut Interpreter, _host: &mut FuzzHost, state: &mut EVMFuzzState) { macro_rules! fast_peek { ($idx:expr) => { interp.stack.data()[interp.stack.len() - 1 - $idx] @@ -1389,8 +1369,8 @@ where unsafe fn on_return( &mut self, _interp: &mut Interpreter, - _host: &mut FuzzHost, - _state: &mut S, + _host: &mut FuzzHost, + _state: &mut EVMFuzzState, _by: &Bytes, ) { self.pop_ctx(); diff --git a/src/evm/contract_utils.rs b/src/evm/contract_utils.rs index 4972ebfdf..4c9314642 100644 --- a/src/evm/contract_utils.rs +++ b/src/evm/contract_utils.rs @@ -37,7 +37,7 @@ use self::crypto::{digest::Digest, sha3::Sha3}; use super::{ blaz::{is_bytecode_similar_lax, is_bytecode_similar_strict_ranking}, host::FuzzHost, - input::{ConciseEVMInput, EVMInput}, + input::ConciseEVMInput, middlewares::cheatcode::{Cheatcode, CHEATCODE_ADDRESS}, types::ProjectSourceMapTy, vm::{EVMExecutor, EVMState}, @@ -748,7 +748,7 @@ impl ContractLoader { fn get_vm_with_cheatcode( deployer: EVMAddress, ) -> ( - EVMExecutor>, + EVMExecutor>, EVMFuzzState, ) { let mut state: EVMFuzzState = FuzzState::new(0); diff --git a/src/evm/corpus_initializer.rs b/src/evm/corpus_initializer.rs index 3978c825d..226f726f8 100644 --- a/src/evm/corpus_initializer.rs +++ b/src/evm/corpus_initializer.rs @@ -66,12 +66,12 @@ where SC: ABIScheduler + Clone, ISC: Scheduler, { - executor: &'a mut EVMExecutor, + executor: &'a mut EVMExecutor, scheduler: SC, infant_scheduler: ISC, state: &'a mut EVMFuzzState, #[cfg(feature = "use_presets")] - presets: Vec<&'a dyn Preset>, + presets: Vec<&'a dyn Preset>, work_dir: String, } @@ -178,7 +178,7 @@ where ISC: Scheduler, { pub fn new( - executor: &'a mut EVMExecutor, + executor: &'a mut EVMExecutor, scheduler: SC, infant_scheduler: ISC, state: &'a mut EVMFuzzState, @@ -196,7 +196,7 @@ where } #[cfg(feature = "use_presets")] - pub fn register_preset(&mut self, preset: &'a dyn Preset) { + pub fn register_preset(&mut self, preset: &'a dyn Preset) { self.presets.push(preset); } @@ -328,7 +328,6 @@ where .insert(contract.deployed_address, build_artifact.clone()); } - #[cfg(feature = "flashloan_v2")] { handle_contract_insertion!( self.state, @@ -454,9 +453,7 @@ where step: false, env: artifacts.initial_env.clone(), access_pattern: Rc::new(RefCell::new(AccessPattern::new())), - #[cfg(feature = "flashloan_v2")] liquidation_percent: 0, - #[cfg(feature = "flashloan_v2")] input_type: EVMInputTy::ABI, direct_data: Default::default(), randomness: vec![0], diff --git a/src/evm/feedbacks.rs b/src/evm/feedbacks.rs index 873156221..450155201 100644 --- a/src/evm/feedbacks.rs +++ b/src/evm/feedbacks.rs @@ -10,73 +10,51 @@ use libafl::{ executors::ExitKind, feedbacks::Feedback, observers::ObserversTuple, - prelude::{HasCorpus, HasMetadata, HasRand, State, Testcase, UsesInput}, + prelude::Testcase, schedulers::Scheduler, - state::HasClientPerfMonitor, Error, }; use libafl_bolts::Named; +use super::{input::EVMInput, types::EVMFuzzState}; use crate::{ - evm::{ - input::{ConciseEVMInput, EVMInputT}, - middlewares::sha3_bypass::Sha3TaintAnalysis, - types::EVMAddress, - vm::EVMExecutor, - }, + evm::{input::ConciseEVMInput, middlewares::sha3_bypass::Sha3TaintAnalysis, vm::EVMExecutor}, generic_vm::vm_state::VMStateT, input::VMInputT, - state::{HasCaller, HasCurrentInputIdx, HasItyState}, }; /// A wrapper around a feedback that also performs sha3 taint analysis /// when the feedback is interesting. #[allow(clippy::type_complexity)] -pub struct Sha3WrappedFeedback +pub struct Sha3WrappedFeedback where - S: State + HasCorpus + HasCaller + Debug + Clone + HasClientPerfMonitor + 'static, - I: VMInputT + EVMInputT, VS: VMStateT, - F: Feedback, - SC: Scheduler + Clone, + F: Feedback, + SC: Scheduler + Clone, { pub inner_feedback: Box, pub sha3_taints: Rc>, - pub evm_executor: Rc>>, + pub evm_executor: Rc>>, pub enabled: bool, } -impl Feedback for Sha3WrappedFeedback +impl Feedback for Sha3WrappedFeedback where - S: State - + HasRand - + HasCorpus - + HasItyState - + HasMetadata - + HasCaller - + HasCurrentInputIdx - + HasClientPerfMonitor - + Default - + Clone - + Debug - + UsesInput - + 'static, - I: VMInputT + EVMInputT + 'static, VS: VMStateT + 'static, - F: Feedback, - SC: Scheduler + Clone + 'static, + F: Feedback, + SC: Scheduler + Clone + 'static, { fn is_interesting( &mut self, - state: &mut S, + state: &mut EVMFuzzState, manager: &mut EM, - input: &S::Input, + input: &EVMInput, observers: &OT, exit_kind: &ExitKind, ) -> Result where - EM: EventFirer, - OT: ObserversTuple, + EM: EventFirer, + OT: ObserversTuple, { // checks if the inner feedback is interesting if self.enabled { @@ -110,30 +88,28 @@ where #[allow(unused_variables)] fn append_metadata( &mut self, - state: &mut S, + state: &mut EVMFuzzState, observers: &OT, - testcase: &mut Testcase, + testcase: &mut Testcase, ) -> Result<(), Error> where - OT: ObserversTuple, + OT: ObserversTuple, { self.inner_feedback.as_mut().append_metadata(state, observers, testcase) } } -impl Sha3WrappedFeedback +impl Sha3WrappedFeedback where - S: State + HasCorpus + HasCaller + Debug + Clone + HasClientPerfMonitor + 'static, - I: VMInputT + EVMInputT, VS: VMStateT, - F: Feedback, - SC: Scheduler + Clone, + F: Feedback, + SC: Scheduler + Clone, { #[allow(clippy::type_complexity)] pub(crate) fn new( inner_feedback: F, sha3_taints: Rc>, - evm_executor: Rc>>, + evm_executor: Rc>>, enabled: bool, ) -> Self { Self { @@ -145,26 +121,22 @@ where } } -impl Named for Sha3WrappedFeedback +impl Named for Sha3WrappedFeedback where - S: State + HasCorpus + HasCaller + Debug + Clone + HasClientPerfMonitor + 'static, - I: VMInputT + EVMInputT, VS: VMStateT, - F: Feedback, - SC: Scheduler + Clone, + F: Feedback, + SC: Scheduler + Clone, { fn name(&self) -> &str { todo!() } } -impl Debug for Sha3WrappedFeedback +impl Debug for Sha3WrappedFeedback where - S: State + HasCorpus + HasCaller + Debug + Clone + HasClientPerfMonitor + 'static, - I: VMInputT + EVMInputT, VS: VMStateT, - F: Feedback, - SC: Scheduler + Clone, + F: Feedback, + SC: Scheduler + Clone, { fn fmt(&self, _f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { todo!() diff --git a/src/evm/host.rs b/src/evm/host.rs index 28f6d1f2c..008783d4a 100644 --- a/src/evm/host.rs +++ b/src/evm/host.rs @@ -16,10 +16,7 @@ use alloy_dyn_abi::DynSolType; use alloy_sol_types::SolValue; use bytes::Bytes; use itertools::Itertools; -use libafl::{ - prelude::{HasCorpus, HasMetadata, HasRand, Scheduler, UsesInput}, - state::State, -}; +use libafl::prelude::{HasMetadata, Scheduler}; use revm::precompile::{Precompile, Precompiles}; use revm_interpreter::{ analysis::to_analysed, @@ -68,6 +65,7 @@ use super::{ ERROR_PREFIX, REVERT_PREFIX, }, + types::EVMFuzzState, vm::{IS_FAST_CALL, MEM_LIMIT, SETCODE_ONLY}, }; use crate::{ @@ -75,7 +73,7 @@ use crate::{ abi::{get_abi_type_boxed, register_abi_instance}, contract_utils::extract_sig_from_contract, corpus_initializer::ABIMap, - input::{ConciseEVMInput, EVMInput, EVMInputT, EVMInputTy}, + input::{EVMInput, EVMInputTy}, middlewares::middleware::{add_corpus, CallMiddlewareReturn, Middleware, MiddlewareType}, mutator::AccessPattern, onchain::{ @@ -85,11 +83,10 @@ use crate::{ types::{as_u64, generate_random_address, is_zero, EVMAddress, EVMU256}, vm::{is_reverted_or_control_leak, EVMState, SinglePostExecution, IN_DEPLOY, IS_FAST_CALL_STATIC}, }, - generic_vm::{vm_executor::MAP_SIZE, vm_state::VMStateT}, + generic_vm::vm_executor::MAP_SIZE, handle_contract_insertion, - input::VMInputT, invoke_middlewares, - state::{HasCaller, HasHashToAddress, HasItyState}, + state::{HasCaller, HasHashToAddress}, state_input::StagedVMState, }; @@ -162,12 +159,9 @@ pub fn is_precompile(address: EVMAddress, num_of_precompiles: usize) -> bool { } #[allow(clippy::type_complexity)] -pub struct FuzzHost +pub struct FuzzHost where - S: State + HasCorpus + HasCaller + Debug + Clone + 'static, - I: VMInputT + EVMInputT, - VS: VMStateT, - SC: Scheduler + Clone, + SC: Scheduler + Clone, { pub evmstate: EVMState, // these are internal to the host @@ -180,11 +174,11 @@ where pub pc_to_create: HashMap<(EVMAddress, usize), usize>, pub pc_to_call_hash: HashMap<(EVMAddress, usize, usize), HashSet>>, pub middlewares_enabled: bool, - pub middlewares: Rc>>>>>, + pub middlewares: Rc>>>>>, pub coverage_changed: bool, - pub flashloan_middleware: Option>>>, + pub flashloan_middleware: Option>>, pub middlewares_latent_call_actions: Vec, @@ -244,12 +238,9 @@ where pub expected_calls: ExpectedCallTracker, } -impl Debug for FuzzHost +impl Debug for FuzzHost where - S: State + HasCorpus + HasCaller + Debug + Clone + 'static, - I: VMInputT + EVMInputT, - VS: VMStateT, - SC: Scheduler + Clone, + SC: Scheduler + Clone, { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("FuzzHost") @@ -268,12 +259,9 @@ where } // all clones would not include middlewares and states -impl Clone for FuzzHost +impl Clone for FuzzHost where - S: State + HasCorpus + HasCaller + Debug + Clone + 'static, - I: VMInputT + EVMInputT, - VS: VMStateT, - SC: Scheduler + Clone, + SC: Scheduler + Clone, { fn clone(&self) -> Self { Self { @@ -331,21 +319,9 @@ const UNBOUND_CALL_THRESHOLD: usize = 3; // unbounded const CONTROL_LEAK_THRESHOLD: usize = 10; -impl FuzzHost +impl FuzzHost where - S: State - + HasRand - + HasCaller - + Debug - + Clone - + HasCorpus - + HasMetadata - + HasItyState - + UsesInput - + 'static, - I: VMInputT + EVMInputT + 'static, - VS: VMStateT, - SC: Scheduler + Clone, + SC: Scheduler + Clone, { pub fn new(scheduler: SC, workdir: String) -> Self { Self { @@ -398,25 +374,27 @@ where } /// custom spec id run_inspect - pub fn run_inspect(&mut self, interp: &mut Interpreter, state: &mut S) -> InstructionResult { + pub fn run_inspect(&mut self, interp: &mut Interpreter, state: &mut EVMFuzzState) -> InstructionResult { match self.spec_id { - SpecId::LATEST => interp.run_inspect::, LatestSpec>(self, state), - SpecId::FRONTIER => interp.run_inspect::, FrontierSpec>(self, state), - SpecId::HOMESTEAD => interp.run_inspect::, HomesteadSpec>(self, state), - SpecId::TANGERINE => interp.run_inspect::, TangerineSpec>(self, state), - SpecId::SPURIOUS_DRAGON => interp.run_inspect::, SpuriousDragonSpec>(self, state), - SpecId::BYZANTIUM => interp.run_inspect::, ByzantiumSpec>(self, state), + SpecId::LATEST => interp.run_inspect::, LatestSpec>(self, state), + SpecId::FRONTIER => interp.run_inspect::, FrontierSpec>(self, state), + SpecId::HOMESTEAD => interp.run_inspect::, HomesteadSpec>(self, state), + SpecId::TANGERINE => interp.run_inspect::, TangerineSpec>(self, state), + SpecId::SPURIOUS_DRAGON => { + interp.run_inspect::, SpuriousDragonSpec>(self, state) + } + SpecId::BYZANTIUM => interp.run_inspect::, ByzantiumSpec>(self, state), SpecId::CONSTANTINOPLE | SpecId::PETERSBURG => { - interp.run_inspect::, PetersburgSpec>(self, state) + interp.run_inspect::, PetersburgSpec>(self, state) } - SpecId::ISTANBUL => interp.run_inspect::, IstanbulSpec>(self, state), + SpecId::ISTANBUL => interp.run_inspect::, IstanbulSpec>(self, state), SpecId::MUIR_GLACIER | SpecId::BERLIN => { - interp.run_inspect::, BerlinSpec>(self, state) + interp.run_inspect::, BerlinSpec>(self, state) } - SpecId::LONDON => interp.run_inspect::, LondonSpec>(self, state), - SpecId::MERGE => interp.run_inspect::, MergeSpec>(self, state), - SpecId::SHANGHAI => interp.run_inspect::, ShanghaiSpec>(self, state), - _ => interp.run_inspect::, LatestSpec>(self, state), + SpecId::LONDON => interp.run_inspect::, LondonSpec>(self, state), + SpecId::MERGE => interp.run_inspect::, MergeSpec>(self, state), + SpecId::SHANGHAI => interp.run_inspect::, ShanghaiSpec>(self, state), + _ => interp.run_inspect::, LatestSpec>(self, state), } } @@ -425,13 +403,13 @@ where self.middlewares.deref().borrow_mut().clear(); } - pub fn add_middlewares(&mut self, middlewares: Rc>>) { + pub fn add_middlewares(&mut self, middlewares: Rc>>) { self.middlewares_enabled = true; // let ty = middlewares.deref().borrow().get_type(); self.middlewares.deref().borrow_mut().push(middlewares); } - pub fn remove_middlewares(&mut self, middlewares: Rc>>) { + pub fn remove_middlewares(&mut self, middlewares: Rc>>) { let ty = middlewares.deref().borrow().get_type(); self.middlewares .deref() @@ -446,14 +424,11 @@ where .retain(|x| x.deref().borrow().get_type() != *ty); } - pub fn add_flashloan_middleware(&mut self, middlware: Flashloan) { + pub fn add_flashloan_middleware(&mut self, middlware: Flashloan) { self.flashloan_middleware = Some(Rc::new(RefCell::new(middlware))); } - pub fn initialize(&mut self, state: &S) - where - S: HasHashToAddress, - { + pub fn initialize(&mut self, state: &EVMFuzzState) { self.hash_to_address = state.get_hash_to_address().clone(); for key in self.hash_to_address.keys() { let addresses = self.hash_to_address.get(key).unwrap(); @@ -514,7 +489,7 @@ where self.setcode_data.clear(); } - pub fn set_code(&mut self, address: EVMAddress, mut code: Bytecode, state: &mut S) { + pub fn set_code(&mut self, address: EVMAddress, mut code: Bytecode, state: &mut EVMFuzzState) { unsafe { invoke_middlewares!(self, None, state, on_insert, &mut code, address); } @@ -522,7 +497,12 @@ where .insert(address, Arc::new(BytecodeLocked::try_from(to_analysed(code)).unwrap())); } - pub fn find_static_call_read_slot(&self, _address: EVMAddress, _data: Bytes, _state: &mut S) -> Vec { + pub fn find_static_call_read_slot( + &self, + _address: EVMAddress, + _data: Bytes, + _state: &mut EVMFuzzState, + ) -> Vec { vec![] // let call = Contract::new_with_context_not_cloned::( // data, @@ -574,7 +554,7 @@ where input: &mut CallInputs, interp: &mut Interpreter, (out_offset, out_len): (usize, usize), - state: &mut S, + state: &mut EVMFuzzState, ) -> (InstructionResult, Gas, Bytes) { macro_rules! push_interp { () => {{ @@ -737,7 +717,11 @@ where res } - fn call_forbid_control_leak(&mut self, input: &mut CallInputs, state: &mut S) -> (InstructionResult, Gas, Bytes) { + fn call_forbid_control_leak( + &mut self, + input: &mut CallInputs, + state: &mut EVMFuzzState, + ) -> (InstructionResult, Gas, Bytes) { let mut hash = input.input.to_vec(); hash.resize(4, 0); // if there is code, then call the code @@ -760,7 +744,11 @@ where (Revert, Gas::new(0), Bytes::new()) } - fn call_precompile(&mut self, input: &mut CallInputs, _state: &mut S) -> (InstructionResult, Gas, Bytes) { + fn call_precompile( + &mut self, + input: &mut CallInputs, + _state: &mut EVMFuzzState, + ) -> (InstructionResult, Gas, Bytes) { let precompile = self .precompiles .get(&input.contract) @@ -980,23 +968,11 @@ macro_rules! invoke_middlewares { }; } -impl Host for FuzzHost +impl Host for FuzzHost where - S: State - + HasRand - + HasCaller - + Debug - + Clone - + HasCorpus - + HasMetadata - + HasItyState - + UsesInput - + 'static, - I: VMInputT + EVMInputT + 'static, - VS: VMStateT, - SC: Scheduler + Clone, + SC: Scheduler + Clone, { - fn step(&mut self, interp: &mut Interpreter, state: &mut S) -> InstructionResult { + fn step(&mut self, interp: &mut Interpreter, state: &mut EVMFuzzState) -> InstructionResult { unsafe { // debug!("pc: {}", interp.program_counter()); // debug!("{:?}", *interp.instruction_pointer); @@ -1061,7 +1037,7 @@ where WRITE_MAP[process_rw_key!(key)] = compressed_value; let res = - as Host>::sload(self, interp.contract.address, fast_peek!(0)); + as Host>::sload(self, interp.contract.address, fast_peek!(0)); let value_changed = res.expect("sload failed").0 != value; let idx = interp.program_counter() % MAP_SIZE; @@ -1160,7 +1136,12 @@ where Continue } - fn step_end(&mut self, _interp: &mut Interpreter, _ret: InstructionResult, _: &mut S) -> InstructionResult { + fn step_end( + &mut self, + _interp: &mut Interpreter, + _ret: InstructionResult, + _: &mut EVMFuzzState, + ) -> InstructionResult { Continue } @@ -1295,7 +1276,7 @@ where fn create( &mut self, inputs: &mut CreateInputs, - state: &mut S, + state: &mut EVMFuzzState, ) -> (InstructionResult, Option, Gas, Bytes) { if unsafe { IN_DEPLOY } { // todo: use nonce + hash instead @@ -1353,10 +1334,7 @@ where parsed_abi = abis; } // notify flashloan and blacklisting flashloan addresses - #[cfg(feature = "flashloan_v2")] - { - handle_contract_insertion!(state, self, r_addr, parsed_abi); - } + handle_contract_insertion!(state, self, r_addr, parsed_abi); parsed_abi.iter().filter(|v| !v.is_constructor).for_each(|abi| { #[cfg(not(feature = "fuzz_static"))] @@ -1379,9 +1357,7 @@ where env: Default::default(), access_pattern: Rc::new(RefCell::new(AccessPattern::new())), - #[cfg(feature = "flashloan_v2")] liquidation_percent: 0, - #[cfg(feature = "flashloan_v2")] input_type: EVMInputTy::ABI, direct_data: Default::default(), randomness: vec![0], @@ -1404,7 +1380,7 @@ where input: &mut CallInputs, interp: &mut Interpreter, output_info: (usize, usize), - state: &mut S, + state: &mut EVMFuzzState, ) -> (InstructionResult, Gas, Bytes) { self.apply_prank(&interp.contract().caller, input); self.call_depth += 1; diff --git a/src/evm/input.rs b/src/evm/input.rs index d28b476e7..dfd705f38 100644 --- a/src/evm/input.rs +++ b/src/evm/input.rs @@ -68,7 +68,6 @@ pub trait EVMInputT { fn set_txn_value(&mut self, v: EVMU256); /// Get input type - #[cfg(feature = "flashloan_v2")] fn get_input_type(&self) -> EVMInputTy; /// Get additional random bytes for mutator @@ -79,12 +78,10 @@ pub trait EVMInputT { /// Get the percentage of the token amount in all callers' account to /// liquidate - #[cfg(feature = "flashloan_v2")] fn get_liquidation_percent(&self) -> u8; /// Set the percentage of the token amount in all callers' account to /// liquidate - #[cfg(feature = "flashloan_v2")] fn set_liquidation_percent(&mut self, v: u8); fn get_repeat(&self) -> usize; @@ -94,7 +91,6 @@ pub trait EVMInputT { #[derive(Serialize, Deserialize, Clone)] pub struct EVMInput { /// Input type - #[cfg(feature = "flashloan_v2")] pub input_type: EVMInputTy, /// Caller address @@ -128,7 +124,6 @@ pub struct EVMInput { pub access_pattern: Rc>, /// Percentage of the token amount in all callers' account to liquidate - #[cfg(feature = "flashloan_v2")] pub liquidation_percent: u8, /// If ABI is empty, use direct data, which is the raw input data @@ -145,7 +140,6 @@ pub struct EVMInput { #[derive(Serialize, Deserialize, Clone, Debug, Default)] pub struct ConciseEVMInput { /// Input type - #[cfg(feature = "flashloan_v2")] pub input_type: EVMInputTy, /// Caller address @@ -170,7 +164,6 @@ pub struct ConciseEVMInput { pub env: Env, /// Percentage of the token amount in all callers' account to liquidate - #[cfg(feature = "flashloan_v2")] pub liquidation_percent: u8, /// Additional random bytes for mutator @@ -204,7 +197,6 @@ impl ConciseEVMInput { }; Self { - #[cfg(feature = "flashloan_v2")] input_type: input.get_input_type(), caller: input.get_caller(), contract: input.get_contract(), @@ -218,7 +210,6 @@ impl ConciseEVMInput { txn_value: input.get_txn_value(), step: input.is_step(), env: input.get_vm_env().clone(), - #[cfg(feature = "flashloan_v2")] liquidation_percent: input.get_liquidation_percent(), randomness: input.get_randomness(), repeat: input.get_repeat(), @@ -236,7 +227,6 @@ impl ConciseEVMInput { I: VMInputT + EVMInputT, { Self { - #[cfg(feature = "flashloan_v2")] input_type: input.get_input_type(), caller: input.get_caller(), contract: input.get_contract(), @@ -250,7 +240,6 @@ impl ConciseEVMInput { txn_value: input.get_txn_value(), step: input.is_step(), env: input.get_vm_env().clone(), - #[cfg(feature = "flashloan_v2")] liquidation_percent: input.get_liquidation_percent(), randomness: input.get_randomness(), repeat: input.get_repeat(), @@ -263,7 +252,6 @@ impl ConciseEVMInput { pub fn to_input(&self, sstate: EVMStagedVMState) -> (EVMInput, u32) { ( EVMInput { - #[cfg(feature = "flashloan_v2")] input_type: self.input_type.clone(), caller: self.caller, contract: self.contract, @@ -277,7 +265,6 @@ impl ConciseEVMInput { step: self.step, env: self.env.clone(), access_pattern: Rc::new(RefCell::new(AccessPattern::new())), - #[cfg(feature = "flashloan_v2")] liquidation_percent: self.liquidation_percent, #[cfg(not(feature = "debug"))] direct_data: Bytes::new(), @@ -292,7 +279,6 @@ impl ConciseEVMInput { // Variable `liq` is used when `debug` feature is disabled #[allow(unused_variables)] - #[cfg(feature = "flashloan_v2")] fn pretty_txn(&self) -> Option { #[cfg(not(feature = "debug"))] match self.data { @@ -308,14 +294,6 @@ impl ConciseEVMInput { self.as_transfer() } - #[cfg(not(feature = "flashloan_v2"))] - fn pretty_txn(&self) -> Option { - match self.data { - Some(ref d) => self.as_abi_call(d.to_colored_string()), - None => self.as_transfer(), - } - } - #[allow(dead_code)] #[inline] fn as_abi_call(&self, call_str: String) -> Option { @@ -386,7 +364,6 @@ impl ConciseEVMInput { )) } - #[cfg(feature = "flashloan_v2")] #[inline] fn append_liquidation(&self, indent: String, call: String) -> String { if self.liquidation_percent == 0 { @@ -406,12 +383,6 @@ impl ConciseEVMInput { [call, liq].join("\n") } - #[cfg(not(feature = "flashloan_v2"))] - #[inline] - fn append_liquidation(&self, _indent: String, call: String) -> String { - call - } - #[inline] fn colored_value(&self) -> String { let value = self.txn_value.unwrap_or_default(); @@ -489,12 +460,10 @@ impl SolutionTx for ConciseEVMInput { self.txn_value.unwrap_or_default().to_string() } - #[cfg(feature = "flashloan_v2")] fn is_borrow(&self) -> bool { self.input_type == EVMInputTy::Borrow } - #[cfg(feature = "flashloan_v2")] fn liq_percent(&self) -> u8 { self.liquidation_percent } @@ -562,7 +531,6 @@ impl EVMInputT for EVMInput { self.txn_value = Some(v); } - #[cfg(feature = "flashloan_v2")] fn get_input_type(&self) -> EVMInputTy { self.input_type.clone() } @@ -575,12 +543,10 @@ impl EVMInputT for EVMInput { self.randomness = v; } - #[cfg(feature = "flashloan_v2")] fn get_liquidation_percent(&self) -> u8 { self.liquidation_percent } - #[cfg(feature = "flashloan_v2")] fn set_liquidation_percent(&mut self, v: u8) { self.liquidation_percent = v; } diff --git a/src/evm/middlewares/call_printer.rs b/src/evm/middlewares/call_printer.rs index c63840a86..33a8a0845 100644 --- a/src/evm/middlewares/call_printer.rs +++ b/src/evm/middlewares/call_printer.rs @@ -2,28 +2,18 @@ use std::{collections::HashMap, fmt::Debug, fs::OpenOptions, io::Write}; use bytes::Bytes; use itertools::Itertools; -use libafl::{ - inputs::Input, - prelude::{HasCorpus, HasMetadata, State}, - schedulers::Scheduler, -}; +use libafl::{prelude::HasMetadata, schedulers::Scheduler}; use revm_interpreter::Interpreter; use serde::{Deserialize, Serialize}; use serde_json; use tracing::debug; -use crate::{ - evm::{ - blaz::builder::ArtifactInfoMetadata, - host::FuzzHost, - input::{ConciseEVMInput, EVMInputT}, - middlewares::middleware::{Middleware, MiddlewareType}, - srcmap::parser::SourceMapLocation, - types::{as_u64, convert_u256_to_h160, EVMAddress, ProjectSourceMapTy, EVMU256}, - }, - generic_vm::vm_state::VMStateT, - input::VMInputT, - state::{HasCaller, HasCurrentInputIdx, HasItyState}, +use crate::evm::{ + blaz::builder::ArtifactInfoMetadata, + host::FuzzHost, + middlewares::middleware::{Middleware, MiddlewareType}, + srcmap::parser::SourceMapLocation, + types::{as_u64, convert_u256_to_h160, EVMAddress, EVMFuzzState, ProjectSourceMapTy, EVMU256}, }; #[derive(Clone, Debug, Serialize, Deserialize, Default)] @@ -130,21 +120,11 @@ impl CallPrinter { } } -impl Middleware for CallPrinter +impl Middleware for CallPrinter where - I: Input + VMInputT + EVMInputT + 'static, - VS: VMStateT, - S: State - + HasCaller - + HasCorpus - + HasItyState - + HasMetadata - + HasCurrentInputIdx - + Debug - + Clone, - SC: Scheduler + Clone, + SC: Scheduler + Clone, { - unsafe fn on_step(&mut self, interp: &mut Interpreter, host: &mut FuzzHost, state: &mut S) { + unsafe fn on_step(&mut self, interp: &mut Interpreter, host: &mut FuzzHost, state: &mut EVMFuzzState) { if self.entry { self.entry = false; let code_address = interp.contract.address; @@ -305,8 +285,8 @@ where unsafe fn on_return( &mut self, _interp: &mut Interpreter, - _host: &mut FuzzHost, - _state: &mut S, + _host: &mut FuzzHost, + _state: &mut EVMFuzzState, by: &Bytes, ) { self.offsets += 1; diff --git a/src/evm/middlewares/cheatcode.rs b/src/evm/middlewares/cheatcode.rs index 64c333937..697e0c1b1 100644 --- a/src/evm/middlewares/cheatcode.rs +++ b/src/evm/middlewares/cheatcode.rs @@ -12,26 +12,16 @@ use alloy_primitives::{Address, Bytes as AlloyBytes, Log as RawLog, B256}; use alloy_sol_types::{SolInterface, SolValue}; use bytes::Bytes; use foundry_cheatcodes::Vm::{self, CallerMode, VmCalls}; -use libafl::{ - prelude::Input, - schedulers::Scheduler, - state::{HasCorpus, HasMetadata, HasRand, State}, -}; +use libafl::{schedulers::Scheduler, state::HasMetadata}; use revm_interpreter::{analysis::to_analysed, opcode, BytecodeLocked, InstructionResult, Interpreter}; use revm_primitives::{Bytecode, Env, SpecId, B160, U256}; use tracing::{debug, error, warn}; use super::middleware::{Middleware, MiddlewareType}; -use crate::{ - evm::{ - host::FuzzHost, - input::{ConciseEVMInput, EVMInputT}, - types::EVMAddress, - vm::EVMState, - }, - generic_vm::vm_state::VMStateT, - input::VMInputT, - state::{HasCaller, HasItyState}, +use crate::evm::{ + host::FuzzHost, + types::{EVMAddress, EVMFuzzState}, + vm::EVMState, }; /// 0x7109709ECfa91a80626fF3989D68f67F5b1DD12D @@ -65,13 +55,13 @@ pub const ERROR_PREFIX: [u8; 4] = [11, 196, 69, 3]; pub type ExpectedCallTracker = HashMap, (ExpectedCallData, u64)>>; #[derive(Clone, Debug, Default)] -pub struct Cheatcode { +pub struct Cheatcode { /// Recorded storage reads and writes accesses: Option, /// Recorded logs recorded_logs: Option>, - _phantom: PhantomData<(I, VS, S, SC)>, + _phantom: PhantomData, } /// Prank information. @@ -187,21 +177,11 @@ macro_rules! cheat_call_error { }}; } -impl Middleware for Cheatcode +impl Middleware for Cheatcode where - I: Input + VMInputT + EVMInputT + 'static, - S: State - + HasCorpus - + HasCaller - + HasItyState - + HasMetadata - + HasRand - + Clone - + Debug, - VS: VMStateT, - SC: Scheduler + Clone + Debug, + SC: Scheduler + Clone + Debug, { - unsafe fn on_step(&mut self, interp: &mut Interpreter, host: &mut FuzzHost, _state: &mut S) { + unsafe fn on_step(&mut self, interp: &mut Interpreter, host: &mut FuzzHost, _state: &mut EVMFuzzState) { let op = interp.current_opcode(); match get_opcode_type(op, interp) { OpcodeType::CheatCall => self.cheat_call(interp, host), @@ -217,19 +197,9 @@ where } } -impl Cheatcode +impl Cheatcode where - I: Input + VMInputT + EVMInputT + 'static, - S: State - + HasCorpus - + HasCaller - + HasItyState - + HasMetadata - + HasRand - + Clone - + Debug, - VS: VMStateT, - SC: Scheduler + Clone, + SC: Scheduler + Clone, { pub fn new() -> Self { Self { @@ -240,7 +210,7 @@ where } /// Call cheatcode address - pub fn cheat_call(&mut self, interp: &mut Interpreter, host: &mut FuzzHost) { + pub fn cheat_call(&mut self, interp: &mut Interpreter, host: &mut FuzzHost) { let op = interp.current_opcode(); let calldata = unsafe { pop_cheatcall_stack(interp, op) }; if let Err(err) = calldata { @@ -418,19 +388,9 @@ where } /// Cheat VmCalls -impl Cheatcode +impl Cheatcode where - I: Input + VMInputT + EVMInputT + 'static, - S: State - + HasCorpus - + HasCaller - + HasItyState - + HasMetadata - + HasRand - + Clone - + Debug, - VS: VMStateT, - SC: Scheduler + Clone, + SC: Scheduler + Clone, { /// Sets `block.timestamp`. #[inline] @@ -520,7 +480,7 @@ where /// Sets an address' code. #[inline] - fn etch(&self, host: &mut FuzzHost, args: Vm::etchCall) -> Option> { + fn etch(&self, host: &mut FuzzHost, args: Vm::etchCall) -> Option> { let Vm::etchCall { target, newRuntimeBytecode, @@ -612,12 +572,7 @@ where /// Sets the *next* call's `msg.sender` to be the input address. #[inline] - fn prank0( - &mut self, - host: &mut FuzzHost, - old_caller: &EVMAddress, - args: Vm::prank_0Call, - ) -> Option> { + fn prank0(&mut self, host: &mut FuzzHost, old_caller: &EVMAddress, args: Vm::prank_0Call) -> Option> { let Vm::prank_0Call { msgSender } = args; host.prank = Some(Prank::new( *old_caller, @@ -636,7 +591,7 @@ where #[inline] fn prank1( &mut self, - host: &mut FuzzHost, + host: &mut FuzzHost, old_caller: &EVMAddress, old_origin: &EVMAddress, args: Vm::prank_1Call, @@ -659,7 +614,7 @@ where #[inline] fn start_prank0( &mut self, - host: &mut FuzzHost, + host: &mut FuzzHost, old_caller: &EVMAddress, args: Vm::startPrank_0Call, ) -> Option> { @@ -681,7 +636,7 @@ where #[inline] fn start_prank1( &mut self, - host: &mut FuzzHost, + host: &mut FuzzHost, old_caller: &EVMAddress, old_origin: &EVMAddress, args: Vm::startPrank_1Call, @@ -701,14 +656,14 @@ where /// Resets subsequent calls' `msg.sender` to be `address(this)`. #[inline] - fn stop_prank(&mut self, host: &mut FuzzHost) -> Option> { + fn stop_prank(&mut self, host: &mut FuzzHost) -> Option> { let _ = host.prank.take(); None } /// Expects an error on next call with any revert data. #[inline] - fn expect_revert0(&mut self, host: &mut FuzzHost) -> Option> { + fn expect_revert0(&mut self, host: &mut FuzzHost) -> Option> { host.expected_revert = Some(ExpectedRevert { reason: None, depth: host.call_depth, @@ -718,7 +673,7 @@ where /// Expects an error on next call that starts with the revert data. #[inline] - fn expect_revert1(&mut self, host: &mut FuzzHost, args: Vm::expectRevert_1Call) -> Option> { + fn expect_revert1(&mut self, host: &mut FuzzHost, args: Vm::expectRevert_1Call) -> Option> { let Vm::expectRevert_1Call { revertData } = args; let reason = Some(Bytes::from(revertData.0.to_vec())); host.expected_revert = Some(ExpectedRevert { @@ -730,7 +685,7 @@ where /// Expects an error on next call that exactly matches the revert data. #[inline] - fn expect_revert2(&mut self, host: &mut FuzzHost, args: Vm::expectRevert_2Call) -> Option> { + fn expect_revert2(&mut self, host: &mut FuzzHost, args: Vm::expectRevert_2Call) -> Option> { let Vm::expectRevert_2Call { revertData } = args; let reason = Some(Bytes::from(revertData)); host.expected_revert = Some(ExpectedRevert { @@ -746,7 +701,7 @@ where /// logs were emitted in the expected order with the expected topics and /// data (as specified by the booleans). #[inline] - fn expect_emit0(&mut self, host: &mut FuzzHost, args: Vm::expectEmit_0Call) -> Option> { + fn expect_emit0(&mut self, host: &mut FuzzHost, args: Vm::expectEmit_0Call) -> Option> { let Vm::expectEmit_0Call { checkTopic1, checkTopic2, @@ -765,7 +720,7 @@ where /// Same as the previous method, but also checks supplied address against /// emitting contract. #[inline] - fn expect_emit1(&mut self, host: &mut FuzzHost, args: Vm::expectEmit_1Call) -> Option> { + fn expect_emit1(&mut self, host: &mut FuzzHost, args: Vm::expectEmit_1Call) -> Option> { let Vm::expectEmit_1Call { checkTopic1, checkTopic2, @@ -788,7 +743,7 @@ where /// after the call, we check if logs were emitted in the expected order /// with the expected topics and data. #[inline] - fn expect_emit2(&mut self, host: &mut FuzzHost) -> Option> { + fn expect_emit2(&mut self, host: &mut FuzzHost) -> Option> { let expected = ExpectedEmit { depth: host.call_depth, checks: [true, true, true, true], @@ -801,7 +756,7 @@ where /// Same as the previous method, but also checks supplied address against /// emitting contract. #[inline] - fn expect_emit3(&mut self, host: &mut FuzzHost, args: Vm::expectEmit_3Call) -> Option> { + fn expect_emit3(&mut self, host: &mut FuzzHost, args: Vm::expectEmit_3Call) -> Option> { let Vm::expectEmit_3Call { emitter } = args; let expected = ExpectedEmit { depth: host.call_depth, @@ -1261,13 +1216,8 @@ mod tests { &mut state, ); - let mut evm_executor: EVMExecutor< - EVMInput, - EVMFuzzState, - EVMState, - ConciseEVMInput, - StdScheduler, - > = EVMExecutor::new(fuzz_host, generate_random_address(&mut state)); + let mut evm_executor: EVMExecutor> = + EVMExecutor::new(fuzz_host, generate_random_address(&mut state)); let mut deploy_state = FuzzState::new(0); // Deploy Reverter @@ -1300,7 +1250,6 @@ mod tests { step: false, env: Default::default(), access_pattern: Rc::new(RefCell::new(AccessPattern::new())), - #[cfg(feature = "flashloan_v2")] liquidation_percent: 0, direct_data: Bytes::from( [ @@ -1309,7 +1258,6 @@ mod tests { ] .concat(), ), - #[cfg(feature = "flashloan_v2")] input_type: EVMInputTy::ABI, randomness: vec![], repeat: 1, diff --git a/src/evm/middlewares/coverage.rs b/src/evm/middlewares/coverage.rs index 507d3dd71..ee532af84 100644 --- a/src/evm/middlewares/coverage.rs +++ b/src/evm/middlewares/coverage.rs @@ -9,11 +9,7 @@ use std::{ }; use itertools::Itertools; -use libafl::{ - inputs::Input, - prelude::{HasCorpus, HasMetadata, State}, - schedulers::Scheduler, -}; +use libafl::{prelude::HasMetadata, schedulers::Scheduler}; use revm_interpreter::{ opcode::{INVALID, JUMPDEST, JUMPI, STOP}, Interpreter, @@ -23,25 +19,19 @@ use serde::Serialize; use serde_json; use tracing::info; -use crate::{ - evm::{ - blaz::builder::ArtifactInfoMetadata, - bytecode_iterator::all_bytecode, - host::FuzzHost, - input::{ConciseEVMInput, EVMInputT}, - middlewares::middleware::{Middleware, MiddlewareType}, - srcmap::parser::{ - pretty_print_source_map, - pretty_print_source_map_single, - SourceMapAvailability, - SourceMapWithCode, - }, - types::{is_zero, EVMAddress, ProjectSourceMapTy}, - vm::IN_DEPLOY, +use crate::evm::{ + blaz::builder::ArtifactInfoMetadata, + bytecode_iterator::all_bytecode, + host::FuzzHost, + middlewares::middleware::{Middleware, MiddlewareType}, + srcmap::parser::{ + pretty_print_source_map, + pretty_print_source_map_single, + SourceMapAvailability, + SourceMapWithCode, }, - generic_vm::vm_state::VMStateT, - input::VMInputT, - state::{HasCaller, HasCurrentInputIdx, HasItyState}, + types::{is_zero, EVMAddress, EVMFuzzState, ProjectSourceMapTy}, + vm::IN_DEPLOY, }; pub static mut EVAL_COVERAGE: bool = false; @@ -337,21 +327,11 @@ impl Coverage { } } -impl Middleware for Coverage +impl Middleware for Coverage where - I: Input + VMInputT + EVMInputT + 'static, - VS: VMStateT, - S: State - + HasCaller - + HasCorpus - + HasItyState - + HasMetadata - + HasCurrentInputIdx - + Debug - + Clone, - SC: Scheduler + Clone, + SC: Scheduler + Clone, { - unsafe fn on_step(&mut self, interp: &mut Interpreter, _host: &mut FuzzHost, _state: &mut S) { + unsafe fn on_step(&mut self, interp: &mut Interpreter, _host: &mut FuzzHost, _state: &mut EVMFuzzState) { if IN_DEPLOY || !EVAL_COVERAGE { return; } @@ -368,8 +348,8 @@ where unsafe fn on_insert( &mut self, _: Option<&mut Interpreter>, - _host: &mut FuzzHost, - state: &mut S, + _host: &mut FuzzHost, + state: &mut EVMFuzzState, bytecode: &mut Bytecode, address: EVMAddress, ) { diff --git a/src/evm/middlewares/math_calculate.rs b/src/evm/middlewares/math_calculate.rs index 816f2a624..aef348a7e 100644 --- a/src/evm/middlewares/math_calculate.rs +++ b/src/evm/middlewares/math_calculate.rs @@ -1,27 +1,17 @@ use std::{collections::HashSet, fmt::Debug, str::FromStr}; -use libafl::{ - inputs::Input, - prelude::{HasCorpus, HasMetadata, State}, - schedulers::Scheduler, -}; +use libafl::{prelude::HasMetadata, schedulers::Scheduler}; use revm_interpreter::Interpreter; use revm_primitives::{keccak256, B256}; use serde::Serialize; use tracing::info; -use crate::{ - evm::{ - host::FuzzHost, - input::{ConciseEVMInput, EVMInputT}, - middlewares::middleware::{Middleware, MiddlewareType}, - onchain::endpoints::{Chain, OnChainConfig}, - types::EVMAddress, - uniswap::{get_uniswap_info, UniswapProvider}, - }, - generic_vm::vm_state::VMStateT, - input::VMInputT, - state::{HasCaller, HasCurrentInputIdx, HasItyState}, +use crate::evm::{ + host::FuzzHost, + middlewares::middleware::{Middleware, MiddlewareType}, + onchain::endpoints::{Chain, OnChainConfig}, + types::{EVMAddress, EVMFuzzState}, + uniswap::{get_uniswap_info, UniswapProvider}, }; #[derive(Serialize, Debug, Clone, Default)] @@ -49,21 +39,11 @@ impl MathCalculateMiddleware { } } -impl Middleware for MathCalculateMiddleware +impl Middleware for MathCalculateMiddleware where - I: Input + VMInputT + EVMInputT + 'static, - VS: VMStateT, - S: State - + HasCaller - + HasCorpus - + HasItyState - + HasMetadata - + HasCurrentInputIdx - + Debug - + Clone, - SC: Scheduler + Clone, + SC: Scheduler + Clone, { - unsafe fn on_step(&mut self, interp: &mut Interpreter, host: &mut FuzzHost, _state: &mut S) { + unsafe fn on_step(&mut self, interp: &mut Interpreter, host: &mut FuzzHost, _state: &mut EVMFuzzState) { let addr = interp.contract.code_address; let pc = interp.program_counter(); macro_rules! check { diff --git a/src/evm/middlewares/middleware.rs b/src/evm/middlewares/middleware.rs index 765ba52a5..37550cf2a 100644 --- a/src/evm/middlewares/middleware.rs +++ b/src/evm/middlewares/middleware.rs @@ -3,26 +3,19 @@ use std::{clone::Clone, fmt::Debug, time::Duration}; use bytes::Bytes; use libafl::{ corpus::{Corpus, Testcase}, - inputs::Input, - prelude::UsesInput, schedulers::Scheduler, - state::{HasCorpus, HasMetadata, State}, + state::HasCorpus, }; use primitive_types::U512; use revm_interpreter::Interpreter; use revm_primitives::Bytecode; use serde::{Deserialize, Serialize}; -use crate::{ - evm::{ - host::FuzzHost, - input::{ConciseEVMInput, EVMInput, EVMInputT}, - types::{EVMAddress, EVMU256}, - vm::EVMState, - }, - generic_vm::vm_state::VMStateT, - input::VMInputT, - state::{HasCaller, HasItyState}, +use crate::evm::{ + host::FuzzHost, + input::EVMInput, + types::{EVMAddress, EVMFuzzState, EVMU256}, + vm::EVMState, }; #[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize, Copy)] @@ -68,22 +61,11 @@ pub enum MiddlewareOp { MakeSubsequentCallSuccess(Bytes), } -pub fn add_corpus(host: &mut FuzzHost, state: &mut S, input: &EVMInput) +pub fn add_corpus(host: &mut FuzzHost, state: &mut EVMFuzzState, input: &EVMInput) where - I: Input + VMInputT + EVMInputT + 'static, - S: State - + HasCorpus - + HasItyState - + HasMetadata - + HasCaller - + Clone - + Debug - + UsesInput - + 'static, - VS: VMStateT + Default, - SC: Scheduler + Clone, + SC: Scheduler + Clone, { - let mut tc = Testcase::new(input.as_any().downcast_ref::().unwrap().clone()) as Testcase; + let mut tc = Testcase::new(input.clone()) as Testcase; tc.set_exec_time(Duration::from_secs(0)); let idx = state.corpus_mut().add(tc).expect("failed to add"); host.scheduler @@ -91,22 +73,19 @@ where .expect("failed to call scheduler on_add"); } -pub trait Middleware: Debug +pub trait Middleware: Debug where - S: State + HasCorpus + HasCaller + Clone + Debug, - I: VMInputT + EVMInputT, - VS: VMStateT, - SC: Scheduler + Clone, + SC: Scheduler + Clone, { #[allow(clippy::missing_safety_doc)] - unsafe fn on_step(&mut self, interp: &mut Interpreter, host: &mut FuzzHost, state: &mut S); + unsafe fn on_step(&mut self, interp: &mut Interpreter, host: &mut FuzzHost, state: &mut EVMFuzzState); #[allow(clippy::missing_safety_doc)] unsafe fn on_return( &mut self, _interp: &mut Interpreter, - _host: &mut FuzzHost, - _state: &mut S, + _host: &mut FuzzHost, + _state: &mut EVMFuzzState, _ret: &Bytes, ) { } @@ -115,8 +94,8 @@ where unsafe fn before_execute( &mut self, _interp: Option<&mut Interpreter>, - _host: &mut FuzzHost, - _state: &mut S, + _host: &mut FuzzHost, + _state: &mut EVMFuzzState, _is_step: bool, _data: &mut Bytes, _evm_state: &mut EVMState, @@ -127,8 +106,8 @@ where unsafe fn on_insert( &mut self, _interp: Option<&mut Interpreter>, - _host: &mut FuzzHost, - _state: &mut S, + _host: &mut FuzzHost, + _state: &mut EVMFuzzState, _bytecode: &mut Bytecode, _address: EVMAddress, ) { diff --git a/src/evm/middlewares/reentrancy.rs b/src/evm/middlewares/reentrancy.rs index 2ba0eae40..10c67c9cd 100644 --- a/src/evm/middlewares/reentrancy.rs +++ b/src/evm/middlewares/reentrancy.rs @@ -4,25 +4,15 @@ use std::{ }; use bytes::Bytes; -use libafl::{ - inputs::Input, - prelude::{HasCorpus, HasMetadata, State}, - schedulers::Scheduler, -}; +use libafl::schedulers::Scheduler; use revm_interpreter::Interpreter; use serde::{Deserialize, Serialize}; -use crate::{ - evm::{ - host::FuzzHost, - input::{ConciseEVMInput, EVMInputT}, - middlewares::middleware::{Middleware, MiddlewareType}, - types::{EVMAddress, EVMU256}, - vm::EVMState, - }, - generic_vm::vm_state::VMStateT, - input::VMInputT, - state::{HasCaller, HasCurrentInputIdx, HasItyState}, +use crate::evm::{ + host::FuzzHost, + middlewares::middleware::{Middleware, MiddlewareType}, + types::{EVMAddress, EVMFuzzState, EVMU256}, + vm::EVMState, }; #[derive(Serialize, Debug, Clone, Default)] @@ -89,21 +79,11 @@ fn merge_sorted_vec_dedup(dst: &mut Vec, another_one: &Vec) { } // Reentrancy: Read, Read, Write -impl Middleware for ReentrancyTracer +impl Middleware for ReentrancyTracer where - I: Input + VMInputT + EVMInputT + 'static, - VS: VMStateT, - S: State - + HasCaller - + HasCorpus - + HasItyState - + HasMetadata - + HasCurrentInputIdx - + Debug - + Clone, - SC: Scheduler + Clone, + SC: Scheduler + Clone, { - unsafe fn on_step(&mut self, interp: &mut Interpreter, host: &mut FuzzHost, _state: &mut S) { + unsafe fn on_step(&mut self, interp: &mut Interpreter, host: &mut FuzzHost, _state: &mut EVMFuzzState) { match *interp.instruction_pointer { 0x54 => { let depth = host.evmstate.post_execution.len() as u32; @@ -182,8 +162,8 @@ where unsafe fn before_execute( &mut self, interp: Option<&mut Interpreter>, - host: &mut FuzzHost, - state: &mut S, + host: &mut FuzzHost, + state: &mut EVMFuzzState, is_step: bool, data: &mut Bytes, evm_state: &mut EVMState, diff --git a/src/evm/middlewares/sha3_bypass.rs b/src/evm/middlewares/sha3_bypass.rs index 09d65d0ff..8506ba1d2 100644 --- a/src/evm/middlewares/sha3_bypass.rs +++ b/src/evm/middlewares/sha3_bypass.rs @@ -6,24 +6,14 @@ use std::{ }; use bytes::Bytes; -use libafl::{ - inputs::Input, - prelude::{HasCorpus, HasMetadata, State}, - schedulers::Scheduler, -}; +use libafl::{prelude::HasMetadata, schedulers::Scheduler}; use revm_interpreter::{opcode::JUMPI, Interpreter}; use tracing::debug; -use crate::{ - evm::{ - host::FuzzHost, - input::{ConciseEVMInput, EVMInputT}, - middlewares::middleware::{Middleware, MiddlewareType}, - types::{as_u64, EVMAddress, EVMU256}, - }, - generic_vm::vm_state::VMStateT, - input::VMInputT, - state::{HasCaller, HasCurrentInputIdx, HasItyState}, +use crate::evm::{ + host::FuzzHost, + middlewares::middleware::{Middleware, MiddlewareType}, + types::{as_u64, EVMAddress, EVMFuzzState, EVMU256}, }; const MAX_CALL_DEPTH: u64 = 3; @@ -114,21 +104,11 @@ impl Sha3TaintAnalysis { } } -impl Middleware for Sha3TaintAnalysis +impl Middleware for Sha3TaintAnalysis where - I: Input + VMInputT + EVMInputT + 'static, - VS: VMStateT, - S: State - + HasCaller - + HasCorpus - + HasItyState - + HasMetadata - + HasCurrentInputIdx - + Debug - + Clone, - SC: Scheduler + Clone, + SC: Scheduler + Clone, { - unsafe fn on_step(&mut self, interp: &mut Interpreter, host: &mut FuzzHost, _state: &mut S) { + unsafe fn on_step(&mut self, interp: &mut Interpreter, host: &mut FuzzHost, _state: &mut EVMFuzzState) { // skip taint analysis if call depth is too deep if host.call_depth > MAX_CALL_DEPTH { return; @@ -396,8 +376,8 @@ where unsafe fn on_return( &mut self, _interp: &mut Interpreter, - _host: &mut FuzzHost, - _state: &mut S, + _host: &mut FuzzHost, + _state: &mut EVMFuzzState, _by: &Bytes, ) { self.pop_ctx(); @@ -419,21 +399,11 @@ impl Sha3Bypass { } } -impl Middleware for Sha3Bypass +impl Middleware for Sha3Bypass where - I: Input + VMInputT + EVMInputT + 'static, - VS: VMStateT, - S: State - + HasCaller - + HasCorpus - + HasItyState - + HasMetadata - + HasCurrentInputIdx - + Debug - + Clone, - SC: Scheduler + Clone, + SC: Scheduler + Clone, { - unsafe fn on_step(&mut self, interp: &mut Interpreter, host: &mut FuzzHost, _state: &mut S) { + unsafe fn on_step(&mut self, interp: &mut Interpreter, host: &mut FuzzHost, _state: &mut EVMFuzzState) { if *interp.instruction_pointer == JUMPI { let jumpi = interp.program_counter(); if self @@ -470,7 +440,7 @@ mod tests { use super::*; use crate::{ evm::{ - input::{EVMInput, EVMInputTy}, + input::{ConciseEVMInput, EVMInput, EVMInputTy}, mutator::AccessPattern, types::{generate_random_address, EVMFuzzState}, vm::{EVMExecutor, EVMState}, @@ -486,13 +456,7 @@ mod tests { if !path.exists() { let _ = std::fs::create_dir(path); } - let mut evm_executor: EVMExecutor< - EVMInput, - EVMFuzzState, - EVMState, - ConciseEVMInput, - StdScheduler, - > = EVMExecutor::new( + let mut evm_executor: EVMExecutor> = EVMExecutor::new( FuzzHost::new(StdScheduler::new(), "work_dir".to_string()), generate_random_address(&mut state), ); @@ -516,10 +480,8 @@ mod tests { step: false, env: Default::default(), access_pattern: Rc::new(RefCell::new(AccessPattern::new())), - #[cfg(feature = "flashloan_v2")] liquidation_percent: 0, direct_data: bys, - #[cfg(feature = "flashloan_v2")] input_type: EVMInputTy::ABI, randomness: vec![], repeat: 1, diff --git a/src/evm/mutator.rs b/src/evm/mutator.rs index a91710205..38b55a15b 100644 --- a/src/evm/mutator.rs +++ b/src/evm/mutator.rs @@ -14,11 +14,10 @@ use serde::{de::DeserializeOwned, Deserialize, Serialize}; /// Mutator for EVM inputs use crate::evm::input::EVMInputT; -#[cfg(feature = "flashloan_v2")] -use crate::evm::input::EVMInputTy::Borrow; use crate::{ evm::{ abi::ABIAddressToInstanceMap, + input::EVMInputTy::Borrow, types::{convert_u256_to_h160, EVMAddress, EVMU256}, vm::{Constraint, EVMStateT}, }, @@ -126,19 +125,13 @@ where for constraint in &constraints { match constraint { Constraint::MustStepNow => { - #[cfg(feature = "flashloan_v2")] - { - if input.get_input_type() == Borrow { - return false; - } + if input.get_input_type() == Borrow { + return false; } } Constraint::Contract(_) => { - #[cfg(feature = "flashloan_v2")] - { - if input.get_input_type() == Borrow { - return false; - } + if input.get_input_type() == Borrow { + return false; } } _ => {} @@ -173,10 +166,7 @@ where input.set_contract_and_abi(target, abi); } Constraint::NoLiquidation => { - #[cfg(feature = "flashloan_v2")] - { - input.set_liquidation_percent(0); - } + input.set_liquidation_percent(0); } Constraint::MustStepNow => { input.set_step(true); @@ -227,26 +217,18 @@ where // use exploit template if state.has_preset() && state.rand_mut().below(100) < 20 { // if flashloan_v2, we don't mutate if it's a borrow - #[cfg(feature = "flashloan_v2")] - { - if input.get_input_type() != Borrow { - match state.get_next_call() { - Some((addr, abi)) => { - input.set_contract_and_abi(addr, Some(abi)); - input.mutate(state); - return Ok(MutationResult::Mutated); - } - None => { - // debug!("cannot find next call"); - } + if input.get_input_type() != Borrow { + match state.get_next_call() { + Some((addr, abi)) => { + input.set_contract_and_abi(addr, Some(abi)); + input.mutate(state); + return Ok(MutationResult::Mutated); + } + None => { + // debug!("cannot find next call"); } } } - - #[cfg(not(feature = "flashloan_v2"))] - { - // todo!("set function") - } } // determine whether we should conduct havoc // (a sequence of mutations in batch vs single mutation) @@ -296,15 +278,7 @@ where mutated = true; }; } - #[cfg(feature = "flashloan_v2")] - { - if input.get_input_type() != Borrow { - turn_to_step!(); - } - } - - #[cfg(not(feature = "flashloan_v2"))] - { + if input.get_input_type() != Borrow { turn_to_step!(); } @@ -318,7 +292,6 @@ where // we should not mutate the VM state, but only mutate the bytes if input.is_step() { let res = match state.rand_mut().below(100) { - #[cfg(feature = "flashloan_v2")] 0..=5 => { let prev_percent = input.get_liquidation_percent(); input.set_liquidation_percent(if state.rand_mut().below(100) < 80 { 10 } else { 0 } as u8); @@ -336,7 +309,6 @@ where // if the input is to borrow token, we should mutate the randomness // (use to select the paths to buy token), VM state, and bytes - #[cfg(feature = "flashloan_v2")] if input.get_input_type() == Borrow { let rand_u8 = state.rand_mut().below(255) as u8; return match state.rand_mut().below(3) { @@ -353,7 +325,6 @@ where // mutate the bytes or VM state or liquidation percent (percentage of token to // liquidate) by default match state.rand_mut().below(100) { - #[cfg(feature = "flashloan_v2")] 6..=10 => { let prev_percent = input.get_liquidation_percent(); input.set_liquidation_percent(if state.rand_mut().below(100) < 80 { 10 } else { 0 } as u8); diff --git a/src/evm/onchain/flashloan.rs b/src/evm/onchain/flashloan.rs index 5c3246969..2ee03e5c2 100644 --- a/src/evm/onchain/flashloan.rs +++ b/src/evm/onchain/flashloan.rs @@ -7,7 +7,6 @@ use std::{ cell::RefCell, collections::{HashMap, HashSet}, fmt::Debug, - marker::PhantomData, ops::Deref, rc::Rc, str::FromStr, @@ -20,7 +19,7 @@ use libafl::{ inputs::Input, prelude::{HasCorpus, State, UsesInput}, schedulers::Scheduler, - state::{HasMetadata, HasRand}, + state::HasMetadata, }; // impl_serdeany is used when `flashloan_v2` feature is not enabled #[allow(unused_imports)] @@ -29,6 +28,7 @@ use revm_interpreter::Interpreter; use serde::{Deserialize, Serialize}; use tracing::debug; +use crate::evm::{types::EVMFuzzState, uniswap::TokenContext}; // Some components are used when `flashloan_v2` feature is not enabled #[allow(unused_imports)] use crate::{ @@ -56,37 +56,18 @@ macro_rules! scale { EVMU512::from(1_000_000) }; } -pub struct Flashloan -where - S: State + HasCaller + Debug + Clone + 'static, - I: VMInputT + EVMInputT, - VS: VMStateT, -{ - phantom: PhantomData<(VS, I, S)>, +pub struct Flashloan { oracle: Box, use_contract_value: bool, - #[cfg(feature = "flashloan_v2")] known_addresses: HashSet, - #[cfg(feature = "flashloan_v2")] - endpoint: OnChainConfig, - #[cfg(feature = "flashloan_v2")] + endpoint: Option, erc20_address: HashSet, - #[cfg(feature = "flashloan_v2")] pair_address: HashSet, - #[cfg(feature = "flashloan_v2")] - pub onchain_middlware: Rc>>, - #[cfg(feature = "flashloan_v2")] pub unbound_tracker: HashMap>, // pc -> [address called] - #[cfg(feature = "flashloan_v2")] pub flashloan_oracle: Rc>, } -impl Debug for Flashloan -where - S: State + HasCaller + Debug + Clone + 'static, - I: VMInputT + EVMInputT, - VS: VMStateT, -{ +impl Debug for Flashloan { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Flashloan") .field("oracle", &self.oracle) @@ -122,7 +103,6 @@ where let mut tc = Testcase::new( { EVMInput { - #[cfg(feature = "flashloan_v2")] input_type: EVMInputTy::Borrow, caller: state.get_rand_caller(), contract: token, @@ -133,7 +113,6 @@ where step: false, env: Default::default(), access_pattern: Rc::new(RefCell::new(AccessPattern::new())), - #[cfg(feature = "flashloan_v2")] liquidation_percent: 0, direct_data: Default::default(), randomness: vec![0], @@ -150,47 +129,20 @@ where scheduler.on_add(state, idx).expect("failed to call scheduler on_add"); } -impl Flashloan -where - S: State - + HasRand - + HasCaller - + HasCorpus - + Debug - + Clone - + HasMetadata - + HasItyState - + UsesInput - + 'static, - I: VMInputT + EVMInputT + 'static, - VS: VMStateT, -{ - #[cfg(not(feature = "flashloan_v2"))] - pub fn new(use_contract_value: bool) -> Self { - Self { - phantom: PhantomData, - oracle: Box::new(DummyPriceOracle {}), - use_contract_value, - } - } - - #[cfg(feature = "flashloan_v2")] +impl Flashloan { pub fn new( use_contract_value: bool, - endpoint: OnChainConfig, + endpoint: Option, price_oracle: Box, - onchain_middleware: Rc>>, flashloan_oracle: Rc>, ) -> Self { Self { - phantom: PhantomData, oracle: price_oracle, use_contract_value, known_addresses: Default::default(), endpoint, erc20_address: Default::default(), pair_address: Default::default(), - onchain_middlware: onchain_middleware, unbound_tracker: Default::default(), flashloan_oracle, } @@ -214,8 +166,18 @@ where .map(|price| Self::calculate_usd_value(price, amount)) } - #[cfg(feature = "flashloan_v2")] - pub fn on_contract_insertion(&mut self, addr: &EVMAddress, abi: &[ABIConfig], _state: &mut S) -> (bool, bool) { + fn get_token_context(&mut self, addr: EVMAddress) -> Option { + self.endpoint + .as_mut() + .map(|endpoint| endpoint.fetch_uniswap_path_cached(addr).clone()) + } + + pub fn on_contract_insertion( + &mut self, + addr: &EVMAddress, + abi: &[ABIConfig], + _state: &mut EVMFuzzState, + ) -> (bool, bool) { // should not happen, just sanity check if self.known_addresses.contains(addr) { return (false, false); @@ -237,18 +199,23 @@ where let mut is_pair = false; // check abi_signatures_token is subset of abi.name { - let oracle = self.flashloan_oracle.deref().try_borrow_mut(); - // avoid delegate call on token -> make oracle borrow multiple times - if oracle.is_ok() { - if abi_signatures_token.iter().all(|x| abi_names.contains(x)) { - oracle - .unwrap() - .register_token(*addr, self.endpoint.fetch_uniswap_path_cached(*addr).clone()); - self.erc20_address.insert(*addr); - is_erc20 = true; + if abi_signatures_token.iter().all(|x| abi_names.contains(x)) { + match self.get_token_context(*addr) { + Some(token_ctx) => { + let oracle = self.flashloan_oracle.deref().try_borrow_mut(); + // avoid delegate call on token -> make oracle borrow multiple times + if oracle.is_ok() { + oracle.unwrap().register_token(*addr, token_ctx); + self.erc20_address.insert(*addr); + is_erc20 = true; + } else { + debug!("Unable to liquidate token {:?}", addr); + } + } + None => { + debug!("Unable to liquidate token {:?}", addr); + } } - } else { - debug!("Ignoring token {:?}", addr); } } @@ -262,10 +229,9 @@ where (is_erc20, is_pair) } - #[cfg(feature = "flashloan_v2")] - pub fn on_pair_insertion(&mut self, host: &FuzzHost, state: &mut S, pair: EVMAddress) + pub fn on_pair_insertion(&mut self, host: &FuzzHost, state: &mut EVMFuzzState, pair: EVMAddress) where - SC: Scheduler + Clone, + SC: Scheduler + Clone, { let slots = host.find_static_call_read_slot( pair, @@ -283,14 +249,8 @@ where } } -#[cfg(feature = "flashloan_v2")] -impl Flashloan -where - S: State + HasCaller + Debug + Clone + 'static, - I: VMInputT + EVMInputT, - VS: VMStateT, -{ - pub fn analyze_call(&self, input: &I, flashloan_data: &mut FlashloanData) { +impl Flashloan { + pub fn analyze_call(&self, input: &EVMInput, flashloan_data: &mut FlashloanData) { // if the txn is a transfer op, record it if input.get_txn_value().is_some() { flashloan_data.owed += EVMU512::from(input.get_txn_value().unwrap()) * scale!(); @@ -311,161 +271,11 @@ where } } -impl Middleware for Flashloan +impl Middleware for Flashloan where - S: State - + HasRand - + HasCaller - + HasMetadata - + HasCorpus - + Debug - + Clone - + HasItyState - + UsesInput - + 'static, - I: VMInputT + EVMInputT + 'static, - VS: VMStateT, - SC: Scheduler + Clone, + SC: Scheduler + Clone, { - #[cfg(not(feature = "flashloan_v2"))] - unsafe fn on_step(&mut self, interp: &mut Interpreter, host: &mut FuzzHost, _state: &mut S) { - macro_rules! earned { - ($amount:expr) => { - host.evmstate.flashloan_data.earned += $amount; - }; - () => {}; - } - - macro_rules! owed { - ($amount:expr) => { - host.evmstate.flashloan_data.owed += $amount; - }; - () => {}; - } - - let offset_of_arg_offset: usize = match *interp.instruction_pointer { - 0xf1 | 0xf2 => 3, - 0xf4 | 0xfa => 2, - _ => { - return; - } - }; - - let value_transfer = match *interp.instruction_pointer { - 0xf1 | 0xf2 => interp.stack.peek(2).unwrap(), - _ => EVMU256::ZERO, - }; - - // todo: fix for delegatecall - let call_target: EVMAddress = convert_u256_to_h160(interp.stack.peek(1).unwrap()); - - if value_transfer > EVMU256::ZERO && call_target == interp.contract.caller { - earned!(EVMU512::from(value_transfer) * float_scale_to_u512(1.0, 5)) - } - - let offset = interp.stack.peek(offset_of_arg_offset).unwrap(); - let size = interp.stack.peek(offset_of_arg_offset + 1).unwrap(); - if size < EVMU256::from(4) { - return; - } - let data = interp.memory.get_slice(as_u64(offset) as usize, as_u64(size) as usize); - // debug!("Calling address: {:?} {:?}", hex::encode(call_target), - // hex::encode(data)); - - macro_rules! make_transfer_call_success { - () => { - host.middlewares_latent_call_actions - .push(ReturnSuccess(Bytes::from([vec![0x0; 31], vec![0x1]].concat()))); - }; - } - - macro_rules! make_balance_call_success { - () => { - host.middlewares_latent_call_actions - .push(ReturnSuccess(Bytes::from(vec![0xff; 32]))); - }; - } - macro_rules! handle_contract_contract_transfer { - () => { - if !self.use_contract_value { - make_transfer_call_success!(); - } - }; - } - - macro_rules! handle_dst_is_attacker { - ($amount:expr) => { - if self.use_contract_value { - // if we use contract value, we make attacker earns amount for oracle proc - // we assume the subsequent would revert if no enough balance - earned!($amount); - } else { - earned!($amount); - make_transfer_call_success!(); - } - }; - } - match data[0..4] { - // balanceOf / approval - [0x70, 0xa0, 0x82, 0x31] | [0x09, 0x5e, 0xa7, 0xb3] => { - if !self.use_contract_value { - make_balance_call_success!(); - } - } - // transfer - [0xa9, 0x05, 0x9c, 0xbb] => { - let dst = EVMAddress::from_slice(&data[16..36]); - let amount = EVMU256::try_from_be_slice(&data[36..68]).unwrap(); - // debug!( - // "transfer from {:?} to {:?} amount {:?}", - // interp.contract.address, dst, amount - // ); - - match self.calculate_usd_value_from_addr(call_target, amount) { - Some(value) => { - if dst == interp.contract.caller { - return handle_dst_is_attacker!(value); - } - } - // if no value, we can't borrow it! - // bypass by explicitly returning value for every token - _ => {} - } - handle_contract_contract_transfer!() - } - // transferFrom - [0x23, 0xb8, 0x72, 0xdd] => { - let src = EVMAddress::from_slice(&data[16..36]); - let dst = EVMAddress::from_slice(&data[48..68]); - let amount = EVMU256::try_from_be_slice(&data[68..100]).unwrap(); - let _make_success = - MiddlewareOp::MakeSubsequentCallSuccess(Bytes::from([vec![0x0; 31], vec![0x1]].concat())); - match self.calculate_usd_value_from_addr(call_target, amount) { - Some(value) => { - if src == interp.contract.caller { - make_transfer_call_success!(); - return owed!(value); - } else if dst == interp.contract.caller { - return handle_dst_is_attacker!(value); - } - } - // if no value, we can't borrow it! - // bypass by explicitly returning value for every token - _ => {} - } - if src != interp.contract.caller && dst != interp.contract.caller { - handle_contract_contract_transfer!() - } - } - _ => {} - }; - } - - #[cfg(feature = "flashloan_v2")] - unsafe fn on_step(&mut self, interp: &mut Interpreter, host: &mut FuzzHost, s: &mut S) - where - S: HasCaller, - { + unsafe fn on_step(&mut self, interp: &mut Interpreter, host: &mut FuzzHost, s: &mut EVMFuzzState) { // if simply static call, we dont care // if unsafe { IS_FAST_CALL_STATIC } { // return; @@ -514,26 +324,6 @@ where } } -#[cfg(not(feature = "flashloan_v2"))] -#[derive(Clone, Debug, Serialize, Deserialize, Default)] -pub struct FlashloanData { - pub owed: EVMU512, - pub earned: EVMU512, -} -#[cfg(not(feature = "flashloan_v2"))] -impl FlashloanData { - pub fn new() -> Self { - Self { - owed: EVMU512::from(0), - earned: EVMU512::from(0), - } - } -} - -#[cfg(not(feature = "flashloan_v2"))] -impl_serdeany!(FlashloanData); - -#[cfg(feature = "flashloan_v2")] #[derive(Clone, Debug, Serialize, Deserialize, Default)] pub struct FlashloanData { pub oracle_recheck_reserve: HashSet, @@ -545,7 +335,6 @@ pub struct FlashloanData { pub extra_info: String, } -#[cfg(feature = "flashloan_v2")] impl FlashloanData { pub fn new() -> Self { Self { diff --git a/src/evm/onchain/mod.rs b/src/evm/onchain/mod.rs index f8fe36ce4..da9e8d7c8 100644 --- a/src/evm/onchain/mod.rs +++ b/src/evm/onchain/mod.rs @@ -15,15 +15,12 @@ use std::{ use bytes::Bytes; use crypto::{digest::Digest, sha3::Sha3}; use itertools::Itertools; -use libafl::{ - prelude::{HasCorpus, HasMetadata, Input, UsesInput}, - schedulers::Scheduler, - state::{HasRand, State}, -}; +use libafl::{prelude::HasMetadata, schedulers::Scheduler}; use revm_interpreter::{analysis::to_analysed, Interpreter}; use revm_primitives::Bytecode; use tracing::debug; +use super::types::EVMFuzzState; use crate::{ evm::{ abi::{get_abi_type_boxed, register_abi_instance}, @@ -39,17 +36,15 @@ use crate::{ }, corpus_initializer::{ABIMap, SourceMapMap}, host::FuzzHost, - input::{ConciseEVMInput, EVMInput, EVMInputT, EVMInputTy}, + input::{EVMInput, EVMInputTy}, middlewares::middleware::{add_corpus, Middleware, MiddlewareType}, mutator::AccessPattern, onchain::{abi_decompiler::fetch_abi_heimdall, endpoints::OnChainConfig, flashloan::register_borrow_txn}, types::{convert_u256_to_h160, EVMAddress, EVMU256}, vm::IS_FAST_CALL, }, - generic_vm::vm_state::VMStateT, handle_contract_insertion, - input::VMInputT, - state::{HasCaller, HasItyState}, + state::HasCaller, state_input::StagedVMState, }; @@ -58,12 +53,7 @@ pub static mut WHITELIST_ADDR: Option> = None; const UNBOUND_THRESHOLD: usize = 30; -pub struct OnChain -where - I: Input + VMInputT, - S: State, - VS: VMStateT + Default, -{ +pub struct OnChain { pub loaded_data: HashSet<(EVMAddress, EVMU256)>, pub loaded_code: HashSet, pub loaded_abi: HashSet, @@ -76,15 +66,9 @@ where pub storage_dump: HashMap>>, pub builder: Option, pub address_to_abi: HashMap>, - pub phantom: std::marker::PhantomData<(I, S, VS)>, } -impl Debug for OnChain -where - I: Input + VMInputT, - S: State, - VS: VMStateT + Default, -{ +impl Debug for OnChain { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("OnChain") .field("loaded_data", &self.loaded_data) @@ -94,12 +78,7 @@ where } } -impl OnChain -where - I: Input + VMInputT, - S: State, - VS: VMStateT + Default, -{ +impl OnChain { pub fn new(endpoint: OnChainConfig, storage_fetching: StorageFetchingMode) -> Self { unsafe { BLACKLIST_ADDR = Some(HashSet::from([ @@ -139,7 +118,6 @@ where storage_all: Default::default(), storage_dump: Default::default(), builder: None, - phantom: Default::default(), address_to_abi: Default::default(), storage_fetching, } @@ -169,23 +147,11 @@ pub fn keccak_hex(data: EVMU256) -> String { hex::encode(output) } -impl Middleware for OnChain +impl Middleware for OnChain where - I: Input + VMInputT + EVMInputT + 'static, - S: State - + HasRand - + Debug - + HasCaller - + HasCorpus - + HasItyState - + HasMetadata - + Clone - + UsesInput - + 'static, - VS: VMStateT + Default + 'static, - SC: Scheduler + Clone, + SC: Scheduler + Clone, { - unsafe fn on_step(&mut self, interp: &mut Interpreter, host: &mut FuzzHost, state: &mut S) { + unsafe fn on_step(&mut self, interp: &mut Interpreter, host: &mut FuzzHost, state: &mut EVMFuzzState) { #[cfg(feature = "force_cache")] macro_rules! force_cache { ($ty: expr, $target: expr) => {{ @@ -340,33 +306,19 @@ where } } -impl OnChain -where - I: Input + VMInputT + EVMInputT + 'static, - S: State - + HasRand - + Debug - + HasCaller - + HasCorpus - + HasItyState - + HasMetadata - + Clone - + UsesInput - + 'static, - VS: VMStateT + Default + 'static, -{ +impl OnChain { #[allow(clippy::too_many_arguments)] pub fn load_code( &mut self, address_h160: EVMAddress, - host: &mut FuzzHost, + host: &mut FuzzHost, force_cache: bool, should_setup_abi: bool, is_proxy_call: bool, caller: EVMAddress, - state: &mut S, + state: &mut EVMFuzzState, ) where - SC: Scheduler + Clone, + SC: Scheduler + Clone, { let contract_code = self.endpoint.get_contract_code(address_h160, force_cache); let code = hex::decode(contract_code).unwrap(); @@ -492,7 +444,6 @@ where state.add_address(&target); // notify flashloan and blacklisting flashloan addresses - #[cfg(feature = "flashloan_v2")] { handle_contract_insertion!( state, @@ -537,9 +488,7 @@ where env: Default::default(), access_pattern: Rc::new(RefCell::new(AccessPattern::new())), - #[cfg(feature = "flashloan_v2")] liquidation_percent: 0, - #[cfg(feature = "flashloan_v2")] input_type: EVMInputTy::ABI, direct_data: Default::default(), randomness: vec![0], diff --git a/src/evm/oracles/erc20.rs b/src/evm/oracles/erc20.rs index 9f604c200..ceee5643a 100644 --- a/src/evm/oracles/erc20.rs +++ b/src/evm/oracles/erc20.rs @@ -1,14 +1,9 @@ -#[cfg(feature = "flashloan_v2")] -use std::collections::HashMap; -#[cfg(feature = "flashloan_v2")] -use std::ops::Deref; -use std::{cell::RefCell, rc::Rc}; +use std::{cell::RefCell, collections::HashMap, ops::Deref, rc::Rc}; use bytes::Bytes; use revm_primitives::Bytecode; +use tracing::debug; -#[cfg(feature = "flashloan_v2")] -use crate::evm::uniswap::TokenContext; use crate::{ evm::{ input::{ConciseEVMInput, EVMInput}, @@ -16,37 +11,21 @@ use crate::{ oracles::ERC20_BUG_IDX, producers::erc20::ERC20Producer, types::{EVMAddress, EVMFuzzState, EVMOracleCtx, EVMU256, EVMU512}, + uniswap::{generate_uniswap_router_sell, TokenContext}, vm::EVMState, }, oracle::Oracle, state::HasExecutionResult, }; -#[cfg(not(feature = "flashloan_v2"))] pub struct IERC20OracleFlashloan { pub balance_of: Vec, -} - -#[cfg(feature = "flashloan_v2")] -pub struct IERC20OracleFlashloan { - pub balance_of: Vec, - #[cfg(feature = "flashloan_v2")] pub known_tokens: HashMap, - #[cfg(feature = "flashloan_v2")] pub known_pair_reserve_slot: HashMap, - #[cfg(feature = "flashloan_v2")] pub erc20_producer: Rc>, } impl IERC20OracleFlashloan { - #[cfg(not(feature = "flashloan_v2"))] - pub fn new(_: Rc>) -> Self { - Self { - balance_of: hex::decode("70a08231").unwrap(), - } - } - - #[cfg(feature = "flashloan_v2")] pub fn new(erc20_producer: Rc>) -> Self { Self { balance_of: hex::decode("70a08231").unwrap(), @@ -56,12 +35,10 @@ impl IERC20OracleFlashloan { } } - #[cfg(feature = "flashloan_v2")] pub fn register_token(&mut self, token: EVMAddress, token_ctx: TokenContext) { self.known_tokens.insert(token, token_ctx); } - #[cfg(feature = "flashloan_v2")] pub fn register_pair_reserve_slot(&mut self, pair: EVMAddress, slot: EVMU256) { self.known_pair_reserve_slot.insert(pair, slot); } @@ -75,33 +52,12 @@ impl 0 } - #[cfg(not(feature = "flashloan_v2"))] - fn oracle(&self, ctx: &mut EVMOracleCtx<'_>, _stage: u64) -> Vec { - // has balance increased? - let exec_res = &ctx.fuzz_state.get_execution_result().new_state.state; - if exec_res.flashloan_data.earned > exec_res.flashloan_data.owed { - EVMBugResult::new_simple( - "erc20".to_string(), - ERC20_BUG_IDX, - format!( - "Earned {}wei more than owed {}wei", - exec_res.flashloan_data.earned, exec_res.flashloan_data.owed - ), - ConciseEVMInput::from_input(ctx.input, ctx.fuzz_state.get_execution_result()), - ) - .push_to_output(); - vec![ERC20_BUG_IDX] - } else { - vec![] - } - } - - #[cfg(feature = "flashloan_v2")] fn oracle(&self, ctx: &mut EVMOracleCtx<'_>, _stage: u64) -> Vec { - use crate::evm::{input::EVMInputT, uniswap::generate_uniswap_router_sell}; - + use crate::evm::input::EVMInputT; + // println!("Oracle: {:?}", ctx.input.get_randomness()); let liquidation_percent = ctx.input.get_liquidation_percent(); if liquidation_percent > 0 { + // println!("Liquidation percent: {}", liquidation_percent); let liquidation_percent = EVMU256::from(liquidation_percent); let mut liquidations_earned = Vec::new(); @@ -110,10 +66,7 @@ impl // prev_balance is nonexistent // #[cfg(feature = "flashloan_debug")] - // debug!( - // "Balance: {} -> {} for {:?} @ {:?}", - // prev_balance, new_balance, caller, token - // ); + debug!("Balance: {} for {:?} @ {:?}", new_balance, caller, token); if *new_balance > EVMU256::ZERO { let liq_amount = *new_balance * liquidation_percent / EVMU256::from(10); @@ -121,33 +74,51 @@ impl } } - let path_idx = ctx.input.get_randomness()[0] as usize; + let _path_idx = ctx.input.get_randomness()[0] as usize; let mut liquidation_txs = vec![]; - // debug!("Liquidations earned: {:?}", liquidations_earned); - for (caller, token_info, amount) in liquidations_earned { - let txs = generate_uniswap_router_sell(token_info, path_idx, amount, ctx.fuzz_state.callers_pool[0]); + for (caller, _token_info, _amount) in liquidations_earned { + // let txs = _token_info.borrow().sell( + // ctx.fuzz_state, + // _amount, + // ctx.fuzz_state.callers_pool[0], + // ctx.input.get_randomness().as_slice(), + // ); + + let txs = generate_uniswap_router_sell(_token_info, _path_idx, _amount, ctx.fuzz_state.callers_pool[0]); if txs.is_none() { continue; } + // liquidation_txs.extend( + // txs.iter() + // .map(|(addr, abi, _)| (caller, *addr, Bytes::from(abi.get_bytes()))), + // ); + liquidation_txs.extend( txs.unwrap() .iter() .map(|(abi, _, addr)| (caller, *addr, Bytes::from(abi.get_bytes()))), ); } - // debug!( - // "Liquidation txs: {:?}", - // liquidation_txs - // ); - // debug!("Earned before liquidation: {:?}", - // ctx.fuzz_state.get_execution_result().new_state.state.flashloan_data.earned); + liquidation_txs.iter().for_each(|(caller, target, by)| { + debug!("Liquidation tx: {:?} -> {:?} ({})", caller, target, hex::encode(by)); + }); + + debug!( + "Earned before liquidation: {:?}", + ctx.fuzz_state + .get_execution_result() + .new_state + .state + .flashloan_data + .earned + ); let (_out, state) = ctx.call_post_batch_dyn(&liquidation_txs); - // debug!("results: {:?}", out); - // debug!("result state: {:?}", state.flashloan_data); + debug!("results: {:?}", _out); + debug!("result state: {:?}", state.flashloan_data); ctx.fuzz_state.get_execution_result_mut().new_state.state = state; } diff --git a/src/evm/oracles/v2_pair.rs b/src/evm/oracles/v2_pair.rs index 5adc64bb5..d0f3914aa 100644 --- a/src/evm/oracles/v2_pair.rs +++ b/src/evm/oracles/v2_pair.rs @@ -57,7 +57,6 @@ impl _stage: u64, ) -> Vec { let mut violations = vec![]; - #[cfg(feature = "flashloan_v2")] { let to_check = ctx .fuzz_state @@ -111,10 +110,6 @@ impl } } } - #[cfg(not(feature = "flashloan_v2"))] - { - panic!("Flashloan v2 required to use pair (-p).") - } violations } } diff --git a/src/evm/presets/mod.rs b/src/evm/presets/mod.rs index f4d648fe1..ada377f17 100644 --- a/src/evm/presets/mod.rs +++ b/src/evm/presets/mod.rs @@ -2,9 +2,10 @@ pub mod pair; use std::{fmt::Debug, fs::File}; -use libafl::{prelude::State, schedulers::Scheduler, state::HasCorpus}; +use libafl::schedulers::Scheduler; use serde::{Deserialize, Deserializer}; +use super::types::EVMFuzzState; use crate::{ evm::{ input::{ConciseEVMInput, EVMInput, EVMInputT}, @@ -13,21 +14,19 @@ use crate::{ }, generic_vm::vm_state::VMStateT, input::VMInputT, - state::HasCaller, }; -pub trait Preset +pub trait Preset where - S: State + HasCorpus + HasCaller + Debug + Clone + 'static, I: VMInputT + EVMInputT, VS: VMStateT, - SC: Scheduler + Clone, + SC: Scheduler + Clone, { fn presets( &self, function_sig: [u8; 4], input: &EVMInput, - evm_executor: &EVMExecutor, + evm_executor: &EVMExecutor, ) -> Vec; } diff --git a/src/evm/presets/pair.rs b/src/evm/presets/pair.rs index 044d78fb2..881bd1304 100644 --- a/src/evm/presets/pair.rs +++ b/src/evm/presets/pair.rs @@ -1,37 +1,30 @@ -use std::fmt::Debug; - -use libafl::{ - schedulers::Scheduler, - state::{HasCorpus, State}, -}; +use libafl::schedulers::Scheduler; use crate::{ evm::{ abi::{A256InnerType, BoxedABI, A256}, input::{ConciseEVMInput, EVMInput, EVMInputT}, presets::Preset, - types::EVMAddress, + types::{EVMAddress, EVMFuzzState}, vm::EVMExecutor, }, generic_vm::vm_state::VMStateT, input::VMInputT, - state::HasCaller, }; pub struct PairPreset; -impl Preset for PairPreset +impl Preset for PairPreset where - S: State + HasCorpus + HasCaller + Debug + Clone + 'static, I: VMInputT + EVMInputT, VS: VMStateT, - SC: Scheduler + Clone, + SC: Scheduler + Clone, { fn presets( &self, function_sig: [u8; 4], input: &EVMInput, - _evm_executor: &EVMExecutor, + _evm_executor: &EVMExecutor, ) -> Vec { let mut res = vec![]; if let [0xbc, 0x25, 0xcf, 0x77] = function_sig { diff --git a/src/evm/producers/erc20.rs b/src/evm/producers/erc20.rs index 5e34f2605..71e065b8d 100644 --- a/src/evm/producers/erc20.rs +++ b/src/evm/producers/erc20.rs @@ -63,7 +63,6 @@ impl ConciseEVMInput, >, ) { - #[cfg(feature = "flashloan_v2")] { let tokens = ctx .fuzz_state diff --git a/src/evm/producers/pair.rs b/src/evm/producers/pair.rs index d926aae82..7bbf5fcdf 100644 --- a/src/evm/producers/pair.rs +++ b/src/evm/producers/pair.rs @@ -62,7 +62,6 @@ impl ConciseEVMInput, >, ) { - #[cfg(feature = "flashloan_v2")] { let reserves = ctx .fuzz_state diff --git a/src/evm/types.rs b/src/evm/types.rs index a0581c942..c081642fa 100644 --- a/src/evm/types.rs +++ b/src/evm/types.rs @@ -71,8 +71,7 @@ pub type EVMFuzzExecutor = FuzzExecutor< ConciseEVMInput, >; -pub type EVMQueueExecutor = - EVMExecutor>; +pub type EVMQueueExecutor = EVMExecutor>; /// convert array of 20x u8 to H160 pub fn convert_h160(v: [u8; 20]) -> H160 { diff --git a/src/evm/vm.rs b/src/evm/vm.rs index d1f8100f2..e32719dcf 100644 --- a/src/evm/vm.rs +++ b/src/evm/vm.rs @@ -15,11 +15,7 @@ use std::{ use bytes::Bytes; /// EVM executor implementation use itertools::Itertools; -use libafl::{ - prelude::{HasMetadata, HasRand, UsesInput}, - schedulers::Scheduler, - state::{HasCorpus, State}, -}; +use libafl::{prelude::HasMetadata, schedulers::Scheduler}; use revm_interpreter::{ BytecodeLocked, CallContext, @@ -36,7 +32,8 @@ use revm_primitives::Bytecode; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use tracing::{debug, error}; -use super::middlewares::reentrancy::ReentrancyData; +use super::{input::EVMInput, middlewares::reentrancy::ReentrancyData, types::EVMFuzzState}; +use crate::evm::uniswap::generate_uniswap_router_buy; // Some components are used when `flashloan_v2` feature is disabled #[allow(unused_imports)] use crate::{ @@ -47,7 +44,6 @@ use crate::{ middlewares::middleware::Middleware, onchain::flashloan::FlashloanData, types::{float_scale_to_u512, EVMAddress, EVMU256, EVMU512}, - uniswap::generate_uniswap_router_buy, vm::Constraint::{NoLiquidation, Value}, }, generic_vm::{ @@ -406,21 +402,19 @@ pub static mut IS_FAST_CALL_STATIC: bool = false; /// EVM executor, wrapper of revm #[derive(Debug, Clone)] -pub struct EVMExecutor +pub struct EVMExecutor where - S: State + HasCorpus + HasCaller + Debug + Clone + 'static, - I: VMInputT + EVMInputT, VS: VMStateT, - SC: Scheduler + Clone, + SC: Scheduler + Clone, { /// Host providing the blockchain environment (e.g., writing/reading /// storage), needed by revm - pub host: FuzzHost, + pub host: FuzzHost, /// [Depreciated] Deployer address pub deployer: EVMAddress, /// Known arbitrary (caller,pc) pub _known_arbitrary: HashSet<(EVMAddress, usize)>, - phandom: PhantomData<(I, S, VS, CI)>, + phandom: PhantomData<(EVMInput, VS, CI)>, } pub fn is_reverted_or_control_leak(ret: &InstructionResult) -> bool { @@ -448,27 +442,83 @@ pub struct IntermediateExecutionResult { pub memory: Vec, } -impl EVMExecutor +macro_rules! init_host { + ($host:expr) => { + $host.current_self_destructs = vec![]; + $host.current_arbitrary_calls = vec![]; + $host.call_count = 0; + $host.jumpi_trace = 37; + $host.current_typed_bug = vec![]; + $host.randomness = vec![9]; + // Uncomment the next line if middleware is needed. + // $host.add_middlewares(middleware.clone()); + }; +} + +macro_rules! execute_call_single { + ($ctx:expr, $host:expr, $state:expr, $address: expr, $by: expr) => {{ + let code = $host.code.get($address).expect("no code").clone(); + let call = Contract::new_with_context_analyzed($by.clone(), code, &$ctx); + let mut interp = Interpreter::new_with_memory_limit(call, 1e10 as u64, false, MEM_LIMIT); + let ret = $host.run_inspect(&mut interp, $state); + (interp.return_value().to_vec(), is_call_success!(ret)) + }}; +} + +impl EVMExecutor where - I: VMInputT + EVMInputT + 'static, - S: State - + HasRand - + HasCorpus - + HasItyState - + HasMetadata - + HasCaller - + HasCurrentInputIdx - + Default - + Clone - + Debug - + UsesInput - + 'static, VS: Default + VMStateT + 'static, CI: Serialize + DeserializeOwned + Debug + Clone + ConciseSerde + 'static, - SC: Scheduler + Clone + 'static, + SC: Scheduler + Clone + 'static, { + fn fast_call( + &mut self, + address: EVMAddress, + data: Bytes, + vm_state: &EVMState, + state: &mut EVMFuzzState, + value: EVMU256, + from: EVMAddress, + ) -> IntermediateExecutionResult { + unsafe { + IS_FAST_CALL = true; + } + // debug!("fast call: {:?} {:?} with {}", address, hex::encode(data.to_vec()), + // value); + let call = Contract::new_with_context_analyzed( + data, + self.host + .code + .get(&address) + .unwrap_or_else(|| panic!("no code {:?}", address)) + .clone(), + &CallContext { + address, + caller: from, + code_address: address, + apparent_value: value, + scheme: CallScheme::Call, + }, + ); + unsafe { + self.host.evmstate = vm_state.as_any().downcast_ref_unchecked::().clone(); + } + let mut interp = Interpreter::new_with_memory_limit(call, 1e10 as u64, false, MEM_LIMIT); + let ret = self.host.run_inspect(&mut interp, state); + unsafe { + IS_FAST_CALL = false; + } + IntermediateExecutionResult { + output: interp.return_value(), + new_state: self.host.evmstate.clone(), + pc: interp.program_counter(), + ret, + stack: Default::default(), + memory: Default::default(), + } + } /// Create a new EVM executor given a host and deployer address - pub fn new(fuzz_host: FuzzHost, deployer: EVMAddress) -> Self { + pub fn new(fuzz_host: FuzzHost, deployer: EVMAddress) -> Self { Self { host: fuzz_host, deployer, @@ -495,9 +545,9 @@ where call_ctx: &CallContext, vm_state: &EVMState, data: Bytes, - input: &I, + input: &EVMInput, post_exec: Option, - state: &mut S, + state: &mut EVMFuzzState, cleanup: bool, ) -> IntermediateExecutionResult { // Initial setups @@ -598,72 +648,21 @@ where } // hack to record txn value - #[cfg(feature = "flashloan_v2")] if let Some(ref m) = self.host.flashloan_middleware { m.deref() .borrow_mut() .analyze_call(input, &mut result.new_state.flashloan_data) } - #[cfg(not(feature = "flashloan_v2"))] - { - result.new_state.flashloan_data.owed += - EVMU512::from(call_ctx.apparent_value) * float_scale_to_u512(1.0, 5); - } - result } - /// Conduct a fast call that does not write to the feedback - fn fast_call( - &mut self, - address: EVMAddress, - data: Bytes, - vm_state: &VS, - state: &mut S, - value: EVMU256, - from: EVMAddress, - ) -> IntermediateExecutionResult { - unsafe { - IS_FAST_CALL = true; - } - // debug!("fast call: {:?} {:?} with {}", address, hex::encode(data.to_vec()), - // value); - let call = Contract::new_with_context_analyzed( - data, - self.host - .code - .get(&address) - .unwrap_or_else(|| panic!("no code {:?}", address)) - .clone(), - &CallContext { - address, - caller: from, - code_address: address, - apparent_value: value, - scheme: CallScheme::Call, - }, - ); - unsafe { - self.host.evmstate = vm_state.as_any().downcast_ref_unchecked::().clone(); - } - let mut interp = Interpreter::new_with_memory_limit(call, 1e10 as u64, false, MEM_LIMIT); - let ret = self.host.run_inspect(&mut interp, state); - unsafe { - IS_FAST_CALL = false; - } - IntermediateExecutionResult { - output: interp.return_value(), - new_state: self.host.evmstate.clone(), - pc: interp.program_counter(), - ret, - stack: Default::default(), - memory: Default::default(), - } - } - /// Execute a transaction, wrapper of [`EVMExecutor::execute_from_pc`] - fn execute_abi(&mut self, input: &I, state: &mut S) -> ExecutionResult, CI> { + fn execute_abi( + &mut self, + input: &EVMInput, + state: &mut EVMFuzzState, + ) -> ExecutionResult, CI> { // Get necessary info from input let mut vm_state = unsafe { input.get_state().as_any().downcast_ref_unchecked::().clone() }; @@ -857,38 +856,77 @@ where pub fn reexecute_with_middleware( &mut self, - input: &I, - state: &mut S, - middleware: Rc>>, + input: &EVMInput, + state: &mut EVMFuzzState, + middleware: Rc>>, ) { self.host.add_middlewares(middleware.clone()); self.execute(input, state); self.host.remove_middlewares(middleware); } + + fn fast_call_inner( + &mut self, + data: &[(EVMAddress, EVMAddress, Bytes, EVMU256)], + vm_state: &EVMState, + state: &mut EVMFuzzState, + ) -> (Vec<(Vec, bool)>, EVMState) { + unsafe { + self.host.evmstate = vm_state.clone(); + } + init_host!(self.host); + let res = data + .iter() + .map(|(caller, address, by, value)| { + let ctx = CallContext { + address: *address, + caller: *caller, + code_address: *address, + apparent_value: *value, + scheme: CallScheme::Call, + }; + execute_call_single!(ctx, self.host, state, address, by) + }) + .collect::, bool)>>(); + (res, self.host.evmstate.clone()) + } + + fn fast_call_inner_no_value( + &mut self, + data: &[(EVMAddress, EVMAddress, Bytes)], + vm_state: &EVMState, + state: &mut EVMFuzzState, + ) -> (Vec<(Vec, bool)>, EVMState) { + unsafe { + self.host.evmstate = vm_state.clone(); + } + init_host!(self.host); + let res = data + .iter() + .map(|(caller, address, by)| { + let ctx = CallContext { + address: *address, + caller: *caller, + code_address: *address, + apparent_value: Default::default(), + scheme: CallScheme::Call, + }; + execute_call_single!(ctx, self.host, state, address, by) + }) + .collect::, bool)>>(); + (res, self.host.evmstate.clone()) + } } pub static mut IN_DEPLOY: bool = false; pub static mut SETCODE_ONLY: bool = false; -impl GenericVM, I, S, CI> - for EVMExecutor +impl GenericVM, EVMInput, EVMFuzzState, CI> + for EVMExecutor where - I: VMInputT + EVMInputT + 'static, - S: State - + HasRand - + HasCorpus - + HasItyState - + HasMetadata - + HasCaller - + HasCurrentInputIdx - + Default - + Clone - + Debug - + UsesInput - + 'static, VS: VMStateT + Default + 'static, CI: Serialize + DeserializeOwned + Debug + Clone + ConciseSerde + 'static, - SC: Scheduler + Clone + 'static, + SC: Scheduler + Clone + 'static, { /// Deploy a contract fn deploy( @@ -896,7 +934,7 @@ where code: Bytecode, constructor_args: Option, deployed_address: EVMAddress, - state: &mut S, + state: &mut EVMFuzzState, ) -> Option { debug!("deployer = 0x{} ", hex::encode(self.deployer)); let deployer = Contract::new( @@ -912,7 +950,7 @@ where IN_DEPLOY = true; } let mut interp = Interpreter::new_with_memory_limit(deployer, 1e10 as u64, false, MEM_LIMIT); - let mut dummy_state = S::default(); + let mut dummy_state = EVMFuzzState::default(); let r = self.host.run_inspect(&mut interp, &mut dummy_state); unsafe { IN_DEPLOY = false; @@ -942,17 +980,12 @@ where Some(deployed_address) } - /// Execute an input (transaction) - #[cfg(not(feature = "flashloan_v2"))] - fn execute(&mut self, input: &I, state: &mut S) -> ExecutionResult, CI> { - use super::host::clear_branch_status; - clear_branch_status(); - self.execute_abi(input, state) - } - /// Execute an input (can be transaction or borrow) - #[cfg(feature = "flashloan_v2")] - fn execute(&mut self, input: &I, state: &mut S) -> ExecutionResult, CI> { + fn execute( + &mut self, + input: &EVMInput, + state: &mut EVMFuzzState, + ) -> ExecutionResult, CI> { use super::host::clear_branch_status; clear_branch_status(); match input.get_input_type() { @@ -968,7 +1001,6 @@ where input.get_txn_value().unwrap(), input.get_caller(), ); - // execute the transaction to get the state with the token borrowed match call_info { Some((abi, value, target)) => { let bys = abi.get_bytes(); @@ -980,7 +1012,6 @@ where value, input.get_caller(), ); - #[cfg(feature = "flashloan_v2")] if let Some(ref m) = self.host.flashloan_middleware { m.deref() .borrow_mut() @@ -1001,7 +1032,11 @@ where // we don't have enough liquidity to buy the token output: vec![], reverted: false, - new_state: StagedVMState::new_with_state(input.get_state().clone()), + new_state: StagedVMState::new_with_state(unsafe { + VMStateT::as_any(input.get_state()) + .downcast_ref_unchecked::() + .clone() + }), additional_info: None, }, } @@ -1015,7 +1050,12 @@ where } /// Execute a static call - fn fast_static_call(&mut self, data: &[(EVMAddress, Bytes)], vm_state: &VS, state: &mut S) -> Vec> { + fn fast_static_call( + &mut self, + data: &[(EVMAddress, Bytes)], + vm_state: &VS, + state: &mut EVMFuzzState, + ) -> Vec> { unsafe { IS_FAST_CALL_STATIC = true; self.host.evmstate = vm_state.as_any().downcast_ref_unchecked::().clone(); @@ -1060,18 +1100,13 @@ where &mut self, data: &[(EVMAddress, EVMAddress, Bytes)], vm_state: &VS, - state: &mut S, + state: &mut EVMFuzzState, ) -> (Vec<(Vec, bool)>, VS) { unsafe { // IS_FAST_CALL = true; self.host.evmstate = vm_state.as_any().downcast_ref_unchecked::().clone(); } - self.host.current_self_destructs = vec![]; - self.host.current_arbitrary_calls = vec![]; - self.host.call_count = 0; - self.host.jumpi_trace = 37; - self.host.current_typed_bug = vec![]; - self.host.randomness = vec![9]; + init_host!(self.host); // self.host.add_middlewares(middleware.clone()); @@ -1085,17 +1120,7 @@ where apparent_value: Default::default(), scheme: CallScheme::Call, }; - let code = self.host.code.get(address).expect("no code").clone(); - let call = Contract::new_with_context_analyzed(by.clone(), code.clone(), &ctx); - let mut interp = Interpreter::new_with_memory_limit(call, 1e10 as u64, false, MEM_LIMIT); - let ret = self.host.run_inspect(&mut interp, state); - // debug!("ret: {:?} {} {}", ret,hex::encode(by.clone()), - // hex::encode(interp.return_data_buffer.clone())); - if is_call_success!(ret) { - (interp.return_value().to_vec(), true) - } else { - (vec![], false) - } + execute_call_single!(ctx, self.host, state, address, by) }) .collect::, bool)>>(); @@ -1159,13 +1184,7 @@ mod tests { if !path.exists() { std::fs::create_dir(path).unwrap(); } - let mut evm_executor: EVMExecutor< - EVMInput, - EVMFuzzState, - EVMState, - ConciseEVMInput, - StdScheduler, - > = EVMExecutor::new( + let mut evm_executor: EVMExecutor> = EVMExecutor::new( FuzzHost::new(StdScheduler::new(), "work_dir".to_string()), generate_random_address(&mut state), ); @@ -1204,7 +1223,6 @@ mod tests { step: false, env: Default::default(), access_pattern: Rc::new(RefCell::new(AccessPattern::new())), - #[cfg(feature = "flashloan_v2")] liquidation_percent: 0, direct_data: Bytes::from( [ @@ -1213,7 +1231,6 @@ mod tests { ] .concat(), ), - #[cfg(feature = "flashloan_v2")] input_type: EVMInputTy::ABI, randomness: vec![], repeat: 1, @@ -1243,7 +1260,6 @@ mod tests { step: false, env: Default::default(), access_pattern: Rc::new(RefCell::new(AccessPattern::new())), - #[cfg(feature = "flashloan_v2")] liquidation_percent: 0, direct_data: Bytes::from( [ @@ -1252,7 +1268,6 @@ mod tests { ] .concat(), ), - #[cfg(feature = "flashloan_v2")] input_type: EVMInputTy::ABI, randomness: vec![], repeat: 1, diff --git a/src/fuzzers/evm_fuzzer.rs b/src/fuzzers/evm_fuzzer.rs index 6ecff58cf..71d87fd05 100644 --- a/src/fuzzers/evm_fuzzer.rs +++ b/src/fuzzers/evm_fuzzer.rs @@ -135,7 +135,7 @@ pub fn evm_fuzzer( let onchain_middleware = match config.onchain.clone() { Some(onchain) => { Some({ - let mid = Rc::new(RefCell::new(OnChain::::new( + let mid = Rc::new(RefCell::new(OnChain::new( // scheduler can be cloned because it never uses &mut self onchain, config.onchain_storage_fetching.unwrap(), @@ -188,19 +188,11 @@ pub fn evm_fuzzer( if config.flashloan { // we should use real balance of tokens in the contract instead of providing // flashloan to contract as well for on chain env - #[cfg(not(feature = "flashloan_v2"))] - fuzz_host.add_middlewares(Rc::new(RefCell::new( - Flashloan::::new(config.onchain.is_some()), - ))); - - #[cfg(feature = "flashloan_v2")] { - assert!(onchain_middleware.is_some(), "Flashloan v2 requires onchain env"); - fuzz_host.add_flashloan_middleware(Flashloan::::new( + fuzz_host.add_flashloan_middleware(Flashloan::new( true, - config.onchain.clone().unwrap(), + config.onchain.clone(), config.price_oracle, - onchain_middleware.clone().unwrap(), config.flashloan_oracle, )); } diff --git a/tests/evm/flashloan/token_config.json b/tests/evm/flashloan/token_config.json new file mode 100644 index 000000000..2e201b6d4 --- /dev/null +++ b/tests/evm/flashloan/token_config.json @@ -0,0 +1,9 @@ +{ + "0xA0A2eE912CAF7921eaAbC866c6ef6FEc8f7E90A4": [ + { + "type": "constant", + "price": 1, + "faucet": "0xA0A2eE912CAF7921eaAbC866c6ef6FEc8f7E90A4" + } + ] +} \ No newline at end of file diff --git a/tests/evm/sbe@constant-product-amm-_update/CPAMM.sol b/tests/evm/sbe@constant-product-amm-_update/CPAMM.sol deleted file mode 100644 index 1ce4ce6f6..000000000 --- a/tests/evm/sbe@constant-product-amm-_update/CPAMM.sol +++ /dev/null @@ -1,246 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.13; -import "../../../solidity_utils/lib.sol"; -contract CPAMM { - IERC20 public immutable token0; - IERC20 public immutable token1; - - uint public reserve0; - uint public reserve1; - - uint public totalSupply; - mapping(address => uint) public balanceOf; - - constructor(address _token0, address _token1) { - token0 = IERC20(_token0); - token1 = IERC20(_token1); - } - - function _mint(address _to, uint _amount) private { - balanceOf[_to] += _amount; - totalSupply += _amount; - } - - function _burn(address _from, uint _amount) private { - balanceOf[_from] -= _amount; - totalSupply -= _amount; - } - - function _update(uint _reserve0, uint _reserve1) private { - reserve0 = _reserve0; - reserve1 = _reserve1; - bug(); - } - - function swap(address _tokenIn, uint _amountIn) external returns (uint amountOut) { - require( - _tokenIn == address(token0) || _tokenIn == address(token1), - "invalid token" - ); - require(_amountIn > 0, "amount in = 0"); - - bool isToken0 = _tokenIn == address(token0); - (IERC20 tokenIn, IERC20 tokenOut, uint reserveIn, uint reserveOut) = isToken0 - ? (token0, token1, reserve0, reserve1) - : (token1, token0, reserve1, reserve0); - - tokenIn.transferFrom(msg.sender, address(this), _amountIn); - - /* - How much dy for dx? - - xy = k - (x + dx)(y - dy) = k - y - dy = k / (x + dx) - y - k / (x + dx) = dy - y - xy / (x + dx) = dy - (yx + ydx - xy) / (x + dx) = dy - ydx / (x + dx) = dy - */ - // 0.3% fee - uint amountInWithFee = (_amountIn * 997) / 1000; - amountOut = (reserveOut * amountInWithFee) / (reserveIn + amountInWithFee); - - tokenOut.transfer(msg.sender, amountOut); - - _update(token0.balanceOf(address(this)), token1.balanceOf(address(this))); - } - - function addLiquidity(uint _amount0, uint _amount1) external returns (uint shares) { - token0.transferFrom(msg.sender, address(this), _amount0); - token1.transferFrom(msg.sender, address(this), _amount1); - - /* - How much dx, dy to add? - - xy = k - (x + dx)(y + dy) = k' - - No price change, before and after adding liquidity - x / y = (x + dx) / (y + dy) - - x(y + dy) = y(x + dx) - x * dy = y * dx - - x / y = dx / dy - dy = y / x * dx - */ - if (reserve0 > 0 || reserve1 > 0) { - require(reserve0 * _amount1 == reserve1 * _amount0, "x / y != dx / dy"); - } - - /* - How much shares to mint? - - f(x, y) = value of liquidity - We will define f(x, y) = sqrt(xy) - - L0 = f(x, y) - L1 = f(x + dx, y + dy) - T = total shares - s = shares to mint - - Total shares should increase proportional to increase in liquidity - L1 / L0 = (T + s) / T - - L1 * T = L0 * (T + s) - - (L1 - L0) * T / L0 = s - */ - - /* - Claim - (L1 - L0) / L0 = dx / x = dy / y - - Proof - --- Equation 1 --- - (L1 - L0) / L0 = (sqrt((x + dx)(y + dy)) - sqrt(xy)) / sqrt(xy) - - dx / dy = x / y so replace dy = dx * y / x - - --- Equation 2 --- - Equation 1 = (sqrt(xy + 2ydx + dx^2 * y / x) - sqrt(xy)) / sqrt(xy) - - Multiply by sqrt(x) / sqrt(x) - Equation 2 = (sqrt(x^2y + 2xydx + dx^2 * y) - sqrt(x^2y)) / sqrt(x^2y) - = (sqrt(y)(sqrt(x^2 + 2xdx + dx^2) - sqrt(x^2)) / (sqrt(y)sqrt(x^2)) - - sqrt(y) on top and bottom cancels out - - --- Equation 3 --- - Equation 2 = (sqrt(x^2 + 2xdx + dx^2) - sqrt(x^2)) / (sqrt(x^2) - = (sqrt((x + dx)^2) - sqrt(x^2)) / sqrt(x^2) - = ((x + dx) - x) / x - = dx / x - - Since dx / dy = x / y, - dx / x = dy / y - - Finally - (L1 - L0) / L0 = dx / x = dy / y - */ - if (totalSupply == 0) { - shares = _sqrt(_amount0 * _amount1); - } else { - shares = _min( - (_amount0 * totalSupply) / reserve0, - (_amount1 * totalSupply) / reserve1 - ); - } - require(shares > 0, "shares = 0"); - _mint(msg.sender, shares); - - _update(token0.balanceOf(address(this)), token1.balanceOf(address(this))); - } - - function removeLiquidity( - uint _shares - ) external returns (uint amount0, uint amount1) { - /* - Claim - dx, dy = amount of liquidity to remove - dx = s / T * x - dy = s / T * y - - Proof - Let's find dx, dy such that - v / L = s / T - - where - v = f(dx, dy) = sqrt(dxdy) - L = total liquidity = sqrt(xy) - s = shares - T = total supply - - --- Equation 1 --- - v = s / T * L - sqrt(dxdy) = s / T * sqrt(xy) - - Amount of liquidity to remove must not change price so - dx / dy = x / y - - replace dy = dx * y / x - sqrt(dxdy) = sqrt(dx * dx * y / x) = dx * sqrt(y / x) - - Divide both sides of Equation 1 with sqrt(y / x) - dx = s / T * sqrt(xy) / sqrt(y / x) - = s / T * sqrt(x^2) = s / T * x - - Likewise - dy = s / T * y - */ - - // bal0 >= reserve0 - // bal1 >= reserve1 - uint bal0 = token0.balanceOf(address(this)); - uint bal1 = token1.balanceOf(address(this)); - - amount0 = (_shares * bal0) / totalSupply; - amount1 = (_shares * bal1) / totalSupply; - require(amount0 > 0 && amount1 > 0, "amount0 or amount1 = 0"); - - _burn(msg.sender, _shares); - _update(bal0 - amount0, bal1 - amount1); - - token0.transfer(msg.sender, amount0); - token1.transfer(msg.sender, amount1); - } - - function _sqrt(uint y) private pure returns (uint z) { - if (y > 3) { - z = y; - uint x = y / 2 + 1; - while (x < z) { - z = x; - x = (y / x + x) / 2; - } - } else if (y != 0) { - z = 1; - } - } - - function _min(uint x, uint y) private pure returns (uint) { - return x <= y ? x : y; - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint amount); - event Approval(address indexed owner, address indexed spender, uint amount); -} diff --git a/tests/evm/sbe@constant-product-amm-addLiquidity1/CPAMM.sol b/tests/evm/sbe@constant-product-amm-addLiquidity1/CPAMM.sol deleted file mode 100644 index 3e8c8a72c..000000000 --- a/tests/evm/sbe@constant-product-amm-addLiquidity1/CPAMM.sol +++ /dev/null @@ -1,246 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.13; -import "../../../solidity_utils/lib.sol"; -contract CPAMM { - IERC20 public immutable token0; - IERC20 public immutable token1; - - uint public reserve0; - uint public reserve1; - - uint public totalSupply; - mapping(address => uint) public balanceOf; - - constructor(address _token0, address _token1) { - token0 = IERC20(_token0); - token1 = IERC20(_token1); - } - - function _mint(address _to, uint _amount) private { - balanceOf[_to] += _amount; - totalSupply += _amount; - } - - function _burn(address _from, uint _amount) private { - balanceOf[_from] -= _amount; - totalSupply -= _amount; - } - - function _update(uint _reserve0, uint _reserve1) private { - reserve0 = _reserve0; - reserve1 = _reserve1; - } - - function swap(address _tokenIn, uint _amountIn) external returns (uint amountOut) { - require( - _tokenIn == address(token0) || _tokenIn == address(token1), - "invalid token" - ); - require(_amountIn > 0, "amount in = 0"); - - bool isToken0 = _tokenIn == address(token0); - (IERC20 tokenIn, IERC20 tokenOut, uint reserveIn, uint reserveOut) = isToken0 - ? (token0, token1, reserve0, reserve1) - : (token1, token0, reserve1, reserve0); - - tokenIn.transferFrom(msg.sender, address(this), _amountIn); - - /* - How much dy for dx? - - xy = k - (x + dx)(y - dy) = k - y - dy = k / (x + dx) - y - k / (x + dx) = dy - y - xy / (x + dx) = dy - (yx + ydx - xy) / (x + dx) = dy - ydx / (x + dx) = dy - */ - // 0.3% fee - uint amountInWithFee = (_amountIn * 997) / 1000; - amountOut = (reserveOut * amountInWithFee) / (reserveIn + amountInWithFee); - - tokenOut.transfer(msg.sender, amountOut); - - _update(token0.balanceOf(address(this)), token1.balanceOf(address(this))); - } - - function addLiquidity(uint _amount0, uint _amount1) external returns (uint shares) { - token0.transferFrom(msg.sender, address(this), _amount0); - token1.transferFrom(msg.sender, address(this), _amount1); - - /* - How much dx, dy to add? - - xy = k - (x + dx)(y + dy) = k' - - No price change, before and after adding liquidity - x / y = (x + dx) / (y + dy) - - x(y + dy) = y(x + dx) - x * dy = y * dx - - x / y = dx / dy - dy = y / x * dx - */ - if (reserve0 > 0 || reserve1 > 0) { - require(reserve0 * _amount1 == reserve1 * _amount0, "x / y != dx / dy"); - bug(); - } - - /* - How much shares to mint? - - f(x, y) = value of liquidity - We will define f(x, y) = sqrt(xy) - - L0 = f(x, y) - L1 = f(x + dx, y + dy) - T = total shares - s = shares to mint - - Total shares should increase proportional to increase in liquidity - L1 / L0 = (T + s) / T - - L1 * T = L0 * (T + s) - - (L1 - L0) * T / L0 = s - */ - - /* - Claim - (L1 - L0) / L0 = dx / x = dy / y - - Proof - --- Equation 1 --- - (L1 - L0) / L0 = (sqrt((x + dx)(y + dy)) - sqrt(xy)) / sqrt(xy) - - dx / dy = x / y so replace dy = dx * y / x - - --- Equation 2 --- - Equation 1 = (sqrt(xy + 2ydx + dx^2 * y / x) - sqrt(xy)) / sqrt(xy) - - Multiply by sqrt(x) / sqrt(x) - Equation 2 = (sqrt(x^2y + 2xydx + dx^2 * y) - sqrt(x^2y)) / sqrt(x^2y) - = (sqrt(y)(sqrt(x^2 + 2xdx + dx^2) - sqrt(x^2)) / (sqrt(y)sqrt(x^2)) - - sqrt(y) on top and bottom cancels out - - --- Equation 3 --- - Equation 2 = (sqrt(x^2 + 2xdx + dx^2) - sqrt(x^2)) / (sqrt(x^2) - = (sqrt((x + dx)^2) - sqrt(x^2)) / sqrt(x^2) - = ((x + dx) - x) / x - = dx / x - - Since dx / dy = x / y, - dx / x = dy / y - - Finally - (L1 - L0) / L0 = dx / x = dy / y - */ - if (totalSupply == 0) { - shares = _sqrt(_amount0 * _amount1); - } else { - shares = _min( - (_amount0 * totalSupply) / reserve0, - (_amount1 * totalSupply) / reserve1 - ); - } - require(shares > 0, "shares = 0"); - _mint(msg.sender, shares); - - _update(token0.balanceOf(address(this)), token1.balanceOf(address(this))); - } - - function removeLiquidity( - uint _shares - ) external returns (uint amount0, uint amount1) { - /* - Claim - dx, dy = amount of liquidity to remove - dx = s / T * x - dy = s / T * y - - Proof - Let's find dx, dy such that - v / L = s / T - - where - v = f(dx, dy) = sqrt(dxdy) - L = total liquidity = sqrt(xy) - s = shares - T = total supply - - --- Equation 1 --- - v = s / T * L - sqrt(dxdy) = s / T * sqrt(xy) - - Amount of liquidity to remove must not change price so - dx / dy = x / y - - replace dy = dx * y / x - sqrt(dxdy) = sqrt(dx * dx * y / x) = dx * sqrt(y / x) - - Divide both sides of Equation 1 with sqrt(y / x) - dx = s / T * sqrt(xy) / sqrt(y / x) - = s / T * sqrt(x^2) = s / T * x - - Likewise - dy = s / T * y - */ - - // bal0 >= reserve0 - // bal1 >= reserve1 - uint bal0 = token0.balanceOf(address(this)); - uint bal1 = token1.balanceOf(address(this)); - - amount0 = (_shares * bal0) / totalSupply; - amount1 = (_shares * bal1) / totalSupply; - require(amount0 > 0 && amount1 > 0, "amount0 or amount1 = 0"); - - _burn(msg.sender, _shares); - _update(bal0 - amount0, bal1 - amount1); - - token0.transfer(msg.sender, amount0); - token1.transfer(msg.sender, amount1); - } - - function _sqrt(uint y) private pure returns (uint z) { - if (y > 3) { - z = y; - uint x = y / 2 + 1; - while (x < z) { - z = x; - x = (y / x + x) / 2; - } - } else if (y != 0) { - z = 1; - } - } - - function _min(uint x, uint y) private pure returns (uint) { - return x <= y ? x : y; - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint amount); - event Approval(address indexed owner, address indexed spender, uint amount); -} diff --git a/tests/evm/sbe@constant-product-amm-addLiquidity2/CPAMM.sol b/tests/evm/sbe@constant-product-amm-addLiquidity2/CPAMM.sol deleted file mode 100644 index 98b7997e3..000000000 --- a/tests/evm/sbe@constant-product-amm-addLiquidity2/CPAMM.sol +++ /dev/null @@ -1,246 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.13; -import "../../../solidity_utils/lib.sol"; -contract CPAMM { - IERC20 public immutable token0; - IERC20 public immutable token1; - - uint public reserve0; - uint public reserve1; - - uint public totalSupply; - mapping(address => uint) public balanceOf; - - constructor(address _token0, address _token1) { - token0 = IERC20(_token0); - token1 = IERC20(_token1); - } - - function _mint(address _to, uint _amount) private { - balanceOf[_to] += _amount; - totalSupply += _amount; - } - - function _burn(address _from, uint _amount) private { - balanceOf[_from] -= _amount; - totalSupply -= _amount; - } - - function _update(uint _reserve0, uint _reserve1) private { - reserve0 = _reserve0; - reserve1 = _reserve1; - } - - function swap(address _tokenIn, uint _amountIn) external returns (uint amountOut) { - require( - _tokenIn == address(token0) || _tokenIn == address(token1), - "invalid token" - ); - require(_amountIn > 0, "amount in = 0"); - - bool isToken0 = _tokenIn == address(token0); - (IERC20 tokenIn, IERC20 tokenOut, uint reserveIn, uint reserveOut) = isToken0 - ? (token0, token1, reserve0, reserve1) - : (token1, token0, reserve1, reserve0); - - tokenIn.transferFrom(msg.sender, address(this), _amountIn); - - /* - How much dy for dx? - - xy = k - (x + dx)(y - dy) = k - y - dy = k / (x + dx) - y - k / (x + dx) = dy - y - xy / (x + dx) = dy - (yx + ydx - xy) / (x + dx) = dy - ydx / (x + dx) = dy - */ - // 0.3% fee - uint amountInWithFee = (_amountIn * 997) / 1000; - amountOut = (reserveOut * amountInWithFee) / (reserveIn + amountInWithFee); - - tokenOut.transfer(msg.sender, amountOut); - - _update(token0.balanceOf(address(this)), token1.balanceOf(address(this))); - } - - function addLiquidity(uint _amount0, uint _amount1) external returns (uint shares) { - token0.transferFrom(msg.sender, address(this), _amount0); - token1.transferFrom(msg.sender, address(this), _amount1); - - /* - How much dx, dy to add? - - xy = k - (x + dx)(y + dy) = k' - - No price change, before and after adding liquidity - x / y = (x + dx) / (y + dy) - - x(y + dy) = y(x + dx) - x * dy = y * dx - - x / y = dx / dy - dy = y / x * dx - */ - if (reserve0 > 0 || reserve1 > 0) { - require(reserve0 * _amount1 == reserve1 * _amount0, "x / y != dx / dy"); - } - - /* - How much shares to mint? - - f(x, y) = value of liquidity - We will define f(x, y) = sqrt(xy) - - L0 = f(x, y) - L1 = f(x + dx, y + dy) - T = total shares - s = shares to mint - - Total shares should increase proportional to increase in liquidity - L1 / L0 = (T + s) / T - - L1 * T = L0 * (T + s) - - (L1 - L0) * T / L0 = s - */ - - /* - Claim - (L1 - L0) / L0 = dx / x = dy / y - - Proof - --- Equation 1 --- - (L1 - L0) / L0 = (sqrt((x + dx)(y + dy)) - sqrt(xy)) / sqrt(xy) - - dx / dy = x / y so replace dy = dx * y / x - - --- Equation 2 --- - Equation 1 = (sqrt(xy + 2ydx + dx^2 * y / x) - sqrt(xy)) / sqrt(xy) - - Multiply by sqrt(x) / sqrt(x) - Equation 2 = (sqrt(x^2y + 2xydx + dx^2 * y) - sqrt(x^2y)) / sqrt(x^2y) - = (sqrt(y)(sqrt(x^2 + 2xdx + dx^2) - sqrt(x^2)) / (sqrt(y)sqrt(x^2)) - - sqrt(y) on top and bottom cancels out - - --- Equation 3 --- - Equation 2 = (sqrt(x^2 + 2xdx + dx^2) - sqrt(x^2)) / (sqrt(x^2) - = (sqrt((x + dx)^2) - sqrt(x^2)) / sqrt(x^2) - = ((x + dx) - x) / x - = dx / x - - Since dx / dy = x / y, - dx / x = dy / y - - Finally - (L1 - L0) / L0 = dx / x = dy / y - */ - if (totalSupply == 0) { - shares = _sqrt(_amount0 * _amount1); - bug(); - } else { - shares = _min( - (_amount0 * totalSupply) / reserve0, - (_amount1 * totalSupply) / reserve1 - ); - } - require(shares > 0, "shares = 0"); - _mint(msg.sender, shares); - - _update(token0.balanceOf(address(this)), token1.balanceOf(address(this))); - } - - function removeLiquidity( - uint _shares - ) external returns (uint amount0, uint amount1) { - /* - Claim - dx, dy = amount of liquidity to remove - dx = s / T * x - dy = s / T * y - - Proof - Let's find dx, dy such that - v / L = s / T - - where - v = f(dx, dy) = sqrt(dxdy) - L = total liquidity = sqrt(xy) - s = shares - T = total supply - - --- Equation 1 --- - v = s / T * L - sqrt(dxdy) = s / T * sqrt(xy) - - Amount of liquidity to remove must not change price so - dx / dy = x / y - - replace dy = dx * y / x - sqrt(dxdy) = sqrt(dx * dx * y / x) = dx * sqrt(y / x) - - Divide both sides of Equation 1 with sqrt(y / x) - dx = s / T * sqrt(xy) / sqrt(y / x) - = s / T * sqrt(x^2) = s / T * x - - Likewise - dy = s / T * y - */ - - // bal0 >= reserve0 - // bal1 >= reserve1 - uint bal0 = token0.balanceOf(address(this)); - uint bal1 = token1.balanceOf(address(this)); - - amount0 = (_shares * bal0) / totalSupply; - amount1 = (_shares * bal1) / totalSupply; - require(amount0 > 0 && amount1 > 0, "amount0 or amount1 = 0"); - - _burn(msg.sender, _shares); - _update(bal0 - amount0, bal1 - amount1); - - token0.transfer(msg.sender, amount0); - token1.transfer(msg.sender, amount1); - } - - function _sqrt(uint y) private pure returns (uint z) { - if (y > 3) { - z = y; - uint x = y / 2 + 1; - while (x < z) { - z = x; - x = (y / x + x) / 2; - } - } else if (y != 0) { - z = 1; - } - } - - function _min(uint x, uint y) private pure returns (uint) { - return x <= y ? x : y; - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint amount); - event Approval(address indexed owner, address indexed spender, uint amount); -} diff --git a/tests/evm/sbe@constant-product-amm-addLiquidity3/CPAMM.sol b/tests/evm/sbe@constant-product-amm-addLiquidity3/CPAMM.sol deleted file mode 100644 index 4de0e09a5..000000000 --- a/tests/evm/sbe@constant-product-amm-addLiquidity3/CPAMM.sol +++ /dev/null @@ -1,246 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.13; -import "../../../solidity_utils/lib.sol"; -contract CPAMM { - IERC20 public immutable token0; - IERC20 public immutable token1; - - uint public reserve0; - uint public reserve1; - - uint public totalSupply; - mapping(address => uint) public balanceOf; - - constructor(address _token0, address _token1) { - token0 = IERC20(_token0); - token1 = IERC20(_token1); - } - - function _mint(address _to, uint _amount) private { - balanceOf[_to] += _amount; - totalSupply += _amount; - } - - function _burn(address _from, uint _amount) private { - balanceOf[_from] -= _amount; - totalSupply -= _amount; - } - - function _update(uint _reserve0, uint _reserve1) private { - reserve0 = _reserve0; - reserve1 = _reserve1; - } - - function swap(address _tokenIn, uint _amountIn) external returns (uint amountOut) { - require( - _tokenIn == address(token0) || _tokenIn == address(token1), - "invalid token" - ); - require(_amountIn > 0, "amount in = 0"); - - bool isToken0 = _tokenIn == address(token0); - (IERC20 tokenIn, IERC20 tokenOut, uint reserveIn, uint reserveOut) = isToken0 - ? (token0, token1, reserve0, reserve1) - : (token1, token0, reserve1, reserve0); - - tokenIn.transferFrom(msg.sender, address(this), _amountIn); - - /* - How much dy for dx? - - xy = k - (x + dx)(y - dy) = k - y - dy = k / (x + dx) - y - k / (x + dx) = dy - y - xy / (x + dx) = dy - (yx + ydx - xy) / (x + dx) = dy - ydx / (x + dx) = dy - */ - // 0.3% fee - uint amountInWithFee = (_amountIn * 997) / 1000; - amountOut = (reserveOut * amountInWithFee) / (reserveIn + amountInWithFee); - - tokenOut.transfer(msg.sender, amountOut); - - _update(token0.balanceOf(address(this)), token1.balanceOf(address(this))); - } - - function addLiquidity(uint _amount0, uint _amount1) external returns (uint shares) { - token0.transferFrom(msg.sender, address(this), _amount0); - token1.transferFrom(msg.sender, address(this), _amount1); - - /* - How much dx, dy to add? - - xy = k - (x + dx)(y + dy) = k' - - No price change, before and after adding liquidity - x / y = (x + dx) / (y + dy) - - x(y + dy) = y(x + dx) - x * dy = y * dx - - x / y = dx / dy - dy = y / x * dx - */ - if (reserve0 > 0 || reserve1 > 0) { - require(reserve0 * _amount1 == reserve1 * _amount0, "x / y != dx / dy"); - } - - /* - How much shares to mint? - - f(x, y) = value of liquidity - We will define f(x, y) = sqrt(xy) - - L0 = f(x, y) - L1 = f(x + dx, y + dy) - T = total shares - s = shares to mint - - Total shares should increase proportional to increase in liquidity - L1 / L0 = (T + s) / T - - L1 * T = L0 * (T + s) - - (L1 - L0) * T / L0 = s - */ - - /* - Claim - (L1 - L0) / L0 = dx / x = dy / y - - Proof - --- Equation 1 --- - (L1 - L0) / L0 = (sqrt((x + dx)(y + dy)) - sqrt(xy)) / sqrt(xy) - - dx / dy = x / y so replace dy = dx * y / x - - --- Equation 2 --- - Equation 1 = (sqrt(xy + 2ydx + dx^2 * y / x) - sqrt(xy)) / sqrt(xy) - - Multiply by sqrt(x) / sqrt(x) - Equation 2 = (sqrt(x^2y + 2xydx + dx^2 * y) - sqrt(x^2y)) / sqrt(x^2y) - = (sqrt(y)(sqrt(x^2 + 2xdx + dx^2) - sqrt(x^2)) / (sqrt(y)sqrt(x^2)) - - sqrt(y) on top and bottom cancels out - - --- Equation 3 --- - Equation 2 = (sqrt(x^2 + 2xdx + dx^2) - sqrt(x^2)) / (sqrt(x^2) - = (sqrt((x + dx)^2) - sqrt(x^2)) / sqrt(x^2) - = ((x + dx) - x) / x - = dx / x - - Since dx / dy = x / y, - dx / x = dy / y - - Finally - (L1 - L0) / L0 = dx / x = dy / y - */ - if (totalSupply == 0) { - shares = _sqrt(_amount0 * _amount1); - } else { - shares = _min( - (_amount0 * totalSupply) / reserve0, - (_amount1 * totalSupply) / reserve1 - ); - bug(); - } - require(shares > 0, "shares = 0"); - _mint(msg.sender, shares); - - _update(token0.balanceOf(address(this)), token1.balanceOf(address(this))); - } - - function removeLiquidity( - uint _shares - ) external returns (uint amount0, uint amount1) { - /* - Claim - dx, dy = amount of liquidity to remove - dx = s / T * x - dy = s / T * y - - Proof - Let's find dx, dy such that - v / L = s / T - - where - v = f(dx, dy) = sqrt(dxdy) - L = total liquidity = sqrt(xy) - s = shares - T = total supply - - --- Equation 1 --- - v = s / T * L - sqrt(dxdy) = s / T * sqrt(xy) - - Amount of liquidity to remove must not change price so - dx / dy = x / y - - replace dy = dx * y / x - sqrt(dxdy) = sqrt(dx * dx * y / x) = dx * sqrt(y / x) - - Divide both sides of Equation 1 with sqrt(y / x) - dx = s / T * sqrt(xy) / sqrt(y / x) - = s / T * sqrt(x^2) = s / T * x - - Likewise - dy = s / T * y - */ - - // bal0 >= reserve0 - // bal1 >= reserve1 - uint bal0 = token0.balanceOf(address(this)); - uint bal1 = token1.balanceOf(address(this)); - - amount0 = (_shares * bal0) / totalSupply; - amount1 = (_shares * bal1) / totalSupply; - require(amount0 > 0 && amount1 > 0, "amount0 or amount1 = 0"); - - _burn(msg.sender, _shares); - _update(bal0 - amount0, bal1 - amount1); - - token0.transfer(msg.sender, amount0); - token1.transfer(msg.sender, amount1); - } - - function _sqrt(uint y) private pure returns (uint z) { - if (y > 3) { - z = y; - uint x = y / 2 + 1; - while (x < z) { - z = x; - x = (y / x + x) / 2; - } - } else if (y != 0) { - z = 1; - } - } - - function _min(uint x, uint y) private pure returns (uint) { - return x <= y ? x : y; - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint amount); - event Approval(address indexed owner, address indexed spender, uint amount); -} diff --git a/tests/evm/sbe@constant-product-amm-addLiquidity4/CPAMM.sol b/tests/evm/sbe@constant-product-amm-addLiquidity4/CPAMM.sol deleted file mode 100644 index aa9ae8252..000000000 --- a/tests/evm/sbe@constant-product-amm-addLiquidity4/CPAMM.sol +++ /dev/null @@ -1,246 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.13; -import "../../../solidity_utils/lib.sol"; -contract CPAMM { - IERC20 public immutable token0; - IERC20 public immutable token1; - - uint public reserve0; - uint public reserve1; - - uint public totalSupply; - mapping(address => uint) public balanceOf; - - constructor(address _token0, address _token1) { - token0 = IERC20(_token0); - token1 = IERC20(_token1); - } - - function _mint(address _to, uint _amount) private { - balanceOf[_to] += _amount; - totalSupply += _amount; - } - - function _burn(address _from, uint _amount) private { - balanceOf[_from] -= _amount; - totalSupply -= _amount; - } - - function _update(uint _reserve0, uint _reserve1) private { - reserve0 = _reserve0; - reserve1 = _reserve1; - } - - function swap(address _tokenIn, uint _amountIn) external returns (uint amountOut) { - require( - _tokenIn == address(token0) || _tokenIn == address(token1), - "invalid token" - ); - require(_amountIn > 0, "amount in = 0"); - - bool isToken0 = _tokenIn == address(token0); - (IERC20 tokenIn, IERC20 tokenOut, uint reserveIn, uint reserveOut) = isToken0 - ? (token0, token1, reserve0, reserve1) - : (token1, token0, reserve1, reserve0); - - tokenIn.transferFrom(msg.sender, address(this), _amountIn); - - /* - How much dy for dx? - - xy = k - (x + dx)(y - dy) = k - y - dy = k / (x + dx) - y - k / (x + dx) = dy - y - xy / (x + dx) = dy - (yx + ydx - xy) / (x + dx) = dy - ydx / (x + dx) = dy - */ - // 0.3% fee - uint amountInWithFee = (_amountIn * 997) / 1000; - amountOut = (reserveOut * amountInWithFee) / (reserveIn + amountInWithFee); - - tokenOut.transfer(msg.sender, amountOut); - - _update(token0.balanceOf(address(this)), token1.balanceOf(address(this))); - } - - function addLiquidity(uint _amount0, uint _amount1) external returns (uint shares) { - token0.transferFrom(msg.sender, address(this), _amount0); - token1.transferFrom(msg.sender, address(this), _amount1); - - /* - How much dx, dy to add? - - xy = k - (x + dx)(y + dy) = k' - - No price change, before and after adding liquidity - x / y = (x + dx) / (y + dy) - - x(y + dy) = y(x + dx) - x * dy = y * dx - - x / y = dx / dy - dy = y / x * dx - */ - if (reserve0 > 0 || reserve1 > 0) { - require(reserve0 * _amount1 == reserve1 * _amount0, "x / y != dx / dy"); - } - - /* - How much shares to mint? - - f(x, y) = value of liquidity - We will define f(x, y) = sqrt(xy) - - L0 = f(x, y) - L1 = f(x + dx, y + dy) - T = total shares - s = shares to mint - - Total shares should increase proportional to increase in liquidity - L1 / L0 = (T + s) / T - - L1 * T = L0 * (T + s) - - (L1 - L0) * T / L0 = s - */ - - /* - Claim - (L1 - L0) / L0 = dx / x = dy / y - - Proof - --- Equation 1 --- - (L1 - L0) / L0 = (sqrt((x + dx)(y + dy)) - sqrt(xy)) / sqrt(xy) - - dx / dy = x / y so replace dy = dx * y / x - - --- Equation 2 --- - Equation 1 = (sqrt(xy + 2ydx + dx^2 * y / x) - sqrt(xy)) / sqrt(xy) - - Multiply by sqrt(x) / sqrt(x) - Equation 2 = (sqrt(x^2y + 2xydx + dx^2 * y) - sqrt(x^2y)) / sqrt(x^2y) - = (sqrt(y)(sqrt(x^2 + 2xdx + dx^2) - sqrt(x^2)) / (sqrt(y)sqrt(x^2)) - - sqrt(y) on top and bottom cancels out - - --- Equation 3 --- - Equation 2 = (sqrt(x^2 + 2xdx + dx^2) - sqrt(x^2)) / (sqrt(x^2) - = (sqrt((x + dx)^2) - sqrt(x^2)) / sqrt(x^2) - = ((x + dx) - x) / x - = dx / x - - Since dx / dy = x / y, - dx / x = dy / y - - Finally - (L1 - L0) / L0 = dx / x = dy / y - */ - if (totalSupply == 0) { - shares = _sqrt(_amount0 * _amount1); - } else { - shares = _min( - (_amount0 * totalSupply) / reserve0, - (_amount1 * totalSupply) / reserve1 - ); - } - require(shares > 0, "shares = 0"); - _mint(msg.sender, shares); - - _update(token0.balanceOf(address(this)), token1.balanceOf(address(this))); - bug(); - } - - function removeLiquidity( - uint _shares - ) external returns (uint amount0, uint amount1) { - /* - Claim - dx, dy = amount of liquidity to remove - dx = s / T * x - dy = s / T * y - - Proof - Let's find dx, dy such that - v / L = s / T - - where - v = f(dx, dy) = sqrt(dxdy) - L = total liquidity = sqrt(xy) - s = shares - T = total supply - - --- Equation 1 --- - v = s / T * L - sqrt(dxdy) = s / T * sqrt(xy) - - Amount of liquidity to remove must not change price so - dx / dy = x / y - - replace dy = dx * y / x - sqrt(dxdy) = sqrt(dx * dx * y / x) = dx * sqrt(y / x) - - Divide both sides of Equation 1 with sqrt(y / x) - dx = s / T * sqrt(xy) / sqrt(y / x) - = s / T * sqrt(x^2) = s / T * x - - Likewise - dy = s / T * y - */ - - // bal0 >= reserve0 - // bal1 >= reserve1 - uint bal0 = token0.balanceOf(address(this)); - uint bal1 = token1.balanceOf(address(this)); - - amount0 = (_shares * bal0) / totalSupply; - amount1 = (_shares * bal1) / totalSupply; - require(amount0 > 0 && amount1 > 0, "amount0 or amount1 = 0"); - - _burn(msg.sender, _shares); - _update(bal0 - amount0, bal1 - amount1); - - token0.transfer(msg.sender, amount0); - token1.transfer(msg.sender, amount1); - } - - function _sqrt(uint y) private pure returns (uint z) { - if (y > 3) { - z = y; - uint x = y / 2 + 1; - while (x < z) { - z = x; - x = (y / x + x) / 2; - } - } else if (y != 0) { - z = 1; - } - } - - function _min(uint x, uint y) private pure returns (uint) { - return x <= y ? x : y; - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint amount); - event Approval(address indexed owner, address indexed spender, uint amount); -} diff --git a/tests/evm/sbe@constant-product-amm-burn/CPAMM.sol b/tests/evm/sbe@constant-product-amm-burn/CPAMM.sol deleted file mode 100644 index 5e02a7eb1..000000000 --- a/tests/evm/sbe@constant-product-amm-burn/CPAMM.sol +++ /dev/null @@ -1,246 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.13; -import "../../../solidity_utils/lib.sol"; -contract CPAMM { - IERC20 public immutable token0; - IERC20 public immutable token1; - - uint public reserve0; - uint public reserve1; - - uint public totalSupply; - mapping(address => uint) public balanceOf; - - constructor(address _token0, address _token1) { - token0 = IERC20(_token0); - token1 = IERC20(_token1); - } - - function _mint(address _to, uint _amount) private { - balanceOf[_to] += _amount; - totalSupply += _amount; - } - - function _burn(address _from, uint _amount) private { - balanceOf[_from] -= _amount; - totalSupply -= _amount; - bug(); - } - - function _update(uint _reserve0, uint _reserve1) private { - reserve0 = _reserve0; - reserve1 = _reserve1; - } - - function swap(address _tokenIn, uint _amountIn) external returns (uint amountOut) { - require( - _tokenIn == address(token0) || _tokenIn == address(token1), - "invalid token" - ); - require(_amountIn > 0, "amount in = 0"); - - bool isToken0 = _tokenIn == address(token0); - (IERC20 tokenIn, IERC20 tokenOut, uint reserveIn, uint reserveOut) = isToken0 - ? (token0, token1, reserve0, reserve1) - : (token1, token0, reserve1, reserve0); - - tokenIn.transferFrom(msg.sender, address(this), _amountIn); - - /* - How much dy for dx? - - xy = k - (x + dx)(y - dy) = k - y - dy = k / (x + dx) - y - k / (x + dx) = dy - y - xy / (x + dx) = dy - (yx + ydx - xy) / (x + dx) = dy - ydx / (x + dx) = dy - */ - // 0.3% fee - uint amountInWithFee = (_amountIn * 997) / 1000; - amountOut = (reserveOut * amountInWithFee) / (reserveIn + amountInWithFee); - - tokenOut.transfer(msg.sender, amountOut); - - _update(token0.balanceOf(address(this)), token1.balanceOf(address(this))); - } - - function addLiquidity(uint _amount0, uint _amount1) external returns (uint shares) { - token0.transferFrom(msg.sender, address(this), _amount0); - token1.transferFrom(msg.sender, address(this), _amount1); - - /* - How much dx, dy to add? - - xy = k - (x + dx)(y + dy) = k' - - No price change, before and after adding liquidity - x / y = (x + dx) / (y + dy) - - x(y + dy) = y(x + dx) - x * dy = y * dx - - x / y = dx / dy - dy = y / x * dx - */ - if (reserve0 > 0 || reserve1 > 0) { - require(reserve0 * _amount1 == reserve1 * _amount0, "x / y != dx / dy"); - } - - /* - How much shares to mint? - - f(x, y) = value of liquidity - We will define f(x, y) = sqrt(xy) - - L0 = f(x, y) - L1 = f(x + dx, y + dy) - T = total shares - s = shares to mint - - Total shares should increase proportional to increase in liquidity - L1 / L0 = (T + s) / T - - L1 * T = L0 * (T + s) - - (L1 - L0) * T / L0 = s - */ - - /* - Claim - (L1 - L0) / L0 = dx / x = dy / y - - Proof - --- Equation 1 --- - (L1 - L0) / L0 = (sqrt((x + dx)(y + dy)) - sqrt(xy)) / sqrt(xy) - - dx / dy = x / y so replace dy = dx * y / x - - --- Equation 2 --- - Equation 1 = (sqrt(xy + 2ydx + dx^2 * y / x) - sqrt(xy)) / sqrt(xy) - - Multiply by sqrt(x) / sqrt(x) - Equation 2 = (sqrt(x^2y + 2xydx + dx^2 * y) - sqrt(x^2y)) / sqrt(x^2y) - = (sqrt(y)(sqrt(x^2 + 2xdx + dx^2) - sqrt(x^2)) / (sqrt(y)sqrt(x^2)) - - sqrt(y) on top and bottom cancels out - - --- Equation 3 --- - Equation 2 = (sqrt(x^2 + 2xdx + dx^2) - sqrt(x^2)) / (sqrt(x^2) - = (sqrt((x + dx)^2) - sqrt(x^2)) / sqrt(x^2) - = ((x + dx) - x) / x - = dx / x - - Since dx / dy = x / y, - dx / x = dy / y - - Finally - (L1 - L0) / L0 = dx / x = dy / y - */ - if (totalSupply == 0) { - shares = _sqrt(_amount0 * _amount1); - } else { - shares = _min( - (_amount0 * totalSupply) / reserve0, - (_amount1 * totalSupply) / reserve1 - ); - } - require(shares > 0, "shares = 0"); - _mint(msg.sender, shares); - - _update(token0.balanceOf(address(this)), token1.balanceOf(address(this))); - } - - function removeLiquidity( - uint _shares - ) external returns (uint amount0, uint amount1) { - /* - Claim - dx, dy = amount of liquidity to remove - dx = s / T * x - dy = s / T * y - - Proof - Let's find dx, dy such that - v / L = s / T - - where - v = f(dx, dy) = sqrt(dxdy) - L = total liquidity = sqrt(xy) - s = shares - T = total supply - - --- Equation 1 --- - v = s / T * L - sqrt(dxdy) = s / T * sqrt(xy) - - Amount of liquidity to remove must not change price so - dx / dy = x / y - - replace dy = dx * y / x - sqrt(dxdy) = sqrt(dx * dx * y / x) = dx * sqrt(y / x) - - Divide both sides of Equation 1 with sqrt(y / x) - dx = s / T * sqrt(xy) / sqrt(y / x) - = s / T * sqrt(x^2) = s / T * x - - Likewise - dy = s / T * y - */ - - // bal0 >= reserve0 - // bal1 >= reserve1 - uint bal0 = token0.balanceOf(address(this)); - uint bal1 = token1.balanceOf(address(this)); - - amount0 = (_shares * bal0) / totalSupply; - amount1 = (_shares * bal1) / totalSupply; - require(amount0 > 0 && amount1 > 0, "amount0 or amount1 = 0"); - - _burn(msg.sender, _shares); - _update(bal0 - amount0, bal1 - amount1); - - token0.transfer(msg.sender, amount0); - token1.transfer(msg.sender, amount1); - } - - function _sqrt(uint y) private pure returns (uint z) { - if (y > 3) { - z = y; - uint x = y / 2 + 1; - while (x < z) { - z = x; - x = (y / x + x) / 2; - } - } else if (y != 0) { - z = 1; - } - } - - function _min(uint x, uint y) private pure returns (uint) { - return x <= y ? x : y; - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint amount); - event Approval(address indexed owner, address indexed spender, uint amount); -} diff --git a/tests/evm/sbe@constant-product-amm-mint/CPAMM.sol b/tests/evm/sbe@constant-product-amm-mint/CPAMM.sol deleted file mode 100644 index 1c0694796..000000000 --- a/tests/evm/sbe@constant-product-amm-mint/CPAMM.sol +++ /dev/null @@ -1,246 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.13; -import "../../../solidity_utils/lib.sol"; -contract CPAMM { - IERC20 public immutable token0; - IERC20 public immutable token1; - - uint public reserve0; - uint public reserve1; - - uint public totalSupply; - mapping(address => uint) public balanceOf; - - constructor(address _token0, address _token1) { - token0 = IERC20(_token0); - token1 = IERC20(_token1); - } - - function _mint(address _to, uint _amount) private { - balanceOf[_to] += _amount; - totalSupply += _amount; - bug(); - } - - function _burn(address _from, uint _amount) private { - balanceOf[_from] -= _amount; - totalSupply -= _amount; - } - - function _update(uint _reserve0, uint _reserve1) private { - reserve0 = _reserve0; - reserve1 = _reserve1; - } - - function swap(address _tokenIn, uint _amountIn) external returns (uint amountOut) { - require( - _tokenIn == address(token0) || _tokenIn == address(token1), - "invalid token" - ); - require(_amountIn > 0, "amount in = 0"); - - bool isToken0 = _tokenIn == address(token0); - (IERC20 tokenIn, IERC20 tokenOut, uint reserveIn, uint reserveOut) = isToken0 - ? (token0, token1, reserve0, reserve1) - : (token1, token0, reserve1, reserve0); - - tokenIn.transferFrom(msg.sender, address(this), _amountIn); - - /* - How much dy for dx? - - xy = k - (x + dx)(y - dy) = k - y - dy = k / (x + dx) - y - k / (x + dx) = dy - y - xy / (x + dx) = dy - (yx + ydx - xy) / (x + dx) = dy - ydx / (x + dx) = dy - */ - // 0.3% fee - uint amountInWithFee = (_amountIn * 997) / 1000; - amountOut = (reserveOut * amountInWithFee) / (reserveIn + amountInWithFee); - - tokenOut.transfer(msg.sender, amountOut); - - _update(token0.balanceOf(address(this)), token1.balanceOf(address(this))); - } - - function addLiquidity(uint _amount0, uint _amount1) external returns (uint shares) { - token0.transferFrom(msg.sender, address(this), _amount0); - token1.transferFrom(msg.sender, address(this), _amount1); - - /* - How much dx, dy to add? - - xy = k - (x + dx)(y + dy) = k' - - No price change, before and after adding liquidity - x / y = (x + dx) / (y + dy) - - x(y + dy) = y(x + dx) - x * dy = y * dx - - x / y = dx / dy - dy = y / x * dx - */ - if (reserve0 > 0 || reserve1 > 0) { - require(reserve0 * _amount1 == reserve1 * _amount0, "x / y != dx / dy"); - } - - /* - How much shares to mint? - - f(x, y) = value of liquidity - We will define f(x, y) = sqrt(xy) - - L0 = f(x, y) - L1 = f(x + dx, y + dy) - T = total shares - s = shares to mint - - Total shares should increase proportional to increase in liquidity - L1 / L0 = (T + s) / T - - L1 * T = L0 * (T + s) - - (L1 - L0) * T / L0 = s - */ - - /* - Claim - (L1 - L0) / L0 = dx / x = dy / y - - Proof - --- Equation 1 --- - (L1 - L0) / L0 = (sqrt((x + dx)(y + dy)) - sqrt(xy)) / sqrt(xy) - - dx / dy = x / y so replace dy = dx * y / x - - --- Equation 2 --- - Equation 1 = (sqrt(xy + 2ydx + dx^2 * y / x) - sqrt(xy)) / sqrt(xy) - - Multiply by sqrt(x) / sqrt(x) - Equation 2 = (sqrt(x^2y + 2xydx + dx^2 * y) - sqrt(x^2y)) / sqrt(x^2y) - = (sqrt(y)(sqrt(x^2 + 2xdx + dx^2) - sqrt(x^2)) / (sqrt(y)sqrt(x^2)) - - sqrt(y) on top and bottom cancels out - - --- Equation 3 --- - Equation 2 = (sqrt(x^2 + 2xdx + dx^2) - sqrt(x^2)) / (sqrt(x^2) - = (sqrt((x + dx)^2) - sqrt(x^2)) / sqrt(x^2) - = ((x + dx) - x) / x - = dx / x - - Since dx / dy = x / y, - dx / x = dy / y - - Finally - (L1 - L0) / L0 = dx / x = dy / y - */ - if (totalSupply == 0) { - shares = _sqrt(_amount0 * _amount1); - } else { - shares = _min( - (_amount0 * totalSupply) / reserve0, - (_amount1 * totalSupply) / reserve1 - ); - } - require(shares > 0, "shares = 0"); - _mint(msg.sender, shares); - - _update(token0.balanceOf(address(this)), token1.balanceOf(address(this))); - } - - function removeLiquidity( - uint _shares - ) external returns (uint amount0, uint amount1) { - /* - Claim - dx, dy = amount of liquidity to remove - dx = s / T * x - dy = s / T * y - - Proof - Let's find dx, dy such that - v / L = s / T - - where - v = f(dx, dy) = sqrt(dxdy) - L = total liquidity = sqrt(xy) - s = shares - T = total supply - - --- Equation 1 --- - v = s / T * L - sqrt(dxdy) = s / T * sqrt(xy) - - Amount of liquidity to remove must not change price so - dx / dy = x / y - - replace dy = dx * y / x - sqrt(dxdy) = sqrt(dx * dx * y / x) = dx * sqrt(y / x) - - Divide both sides of Equation 1 with sqrt(y / x) - dx = s / T * sqrt(xy) / sqrt(y / x) - = s / T * sqrt(x^2) = s / T * x - - Likewise - dy = s / T * y - */ - - // bal0 >= reserve0 - // bal1 >= reserve1 - uint bal0 = token0.balanceOf(address(this)); - uint bal1 = token1.balanceOf(address(this)); - - amount0 = (_shares * bal0) / totalSupply; - amount1 = (_shares * bal1) / totalSupply; - require(amount0 > 0 && amount1 > 0, "amount0 or amount1 = 0"); - - _burn(msg.sender, _shares); - _update(bal0 - amount0, bal1 - amount1); - - token0.transfer(msg.sender, amount0); - token1.transfer(msg.sender, amount1); - } - - function _sqrt(uint y) private pure returns (uint z) { - if (y > 3) { - z = y; - uint x = y / 2 + 1; - while (x < z) { - z = x; - x = (y / x + x) / 2; - } - } else if (y != 0) { - z = 1; - } - } - - function _min(uint x, uint y) private pure returns (uint) { - return x <= y ? x : y; - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint amount); - event Approval(address indexed owner, address indexed spender, uint amount); -} diff --git a/tests/evm/sbe@constant-product-amm-removeLiquidity/CPAMM.sol b/tests/evm/sbe@constant-product-amm-removeLiquidity/CPAMM.sol deleted file mode 100644 index c545c1812..000000000 --- a/tests/evm/sbe@constant-product-amm-removeLiquidity/CPAMM.sol +++ /dev/null @@ -1,246 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.13; -import "../../../solidity_utils/lib.sol"; -contract CPAMM { - IERC20 public immutable token0; - IERC20 public immutable token1; - - uint public reserve0; - uint public reserve1; - - uint public totalSupply; - mapping(address => uint) public balanceOf; - - constructor(address _token0, address _token1) { - token0 = IERC20(_token0); - token1 = IERC20(_token1); - } - - function _mint(address _to, uint _amount) private { - balanceOf[_to] += _amount; - totalSupply += _amount; - } - - function _burn(address _from, uint _amount) private { - balanceOf[_from] -= _amount; - totalSupply -= _amount; - } - - function _update(uint _reserve0, uint _reserve1) private { - reserve0 = _reserve0; - reserve1 = _reserve1; - } - - function swap(address _tokenIn, uint _amountIn) external returns (uint amountOut) { - require( - _tokenIn == address(token0) || _tokenIn == address(token1), - "invalid token" - ); - require(_amountIn > 0, "amount in = 0"); - - bool isToken0 = _tokenIn == address(token0); - (IERC20 tokenIn, IERC20 tokenOut, uint reserveIn, uint reserveOut) = isToken0 - ? (token0, token1, reserve0, reserve1) - : (token1, token0, reserve1, reserve0); - - tokenIn.transferFrom(msg.sender, address(this), _amountIn); - - /* - How much dy for dx? - - xy = k - (x + dx)(y - dy) = k - y - dy = k / (x + dx) - y - k / (x + dx) = dy - y - xy / (x + dx) = dy - (yx + ydx - xy) / (x + dx) = dy - ydx / (x + dx) = dy - */ - // 0.3% fee - uint amountInWithFee = (_amountIn * 997) / 1000; - amountOut = (reserveOut * amountInWithFee) / (reserveIn + amountInWithFee); - - tokenOut.transfer(msg.sender, amountOut); - - _update(token0.balanceOf(address(this)), token1.balanceOf(address(this))); - } - - function addLiquidity(uint _amount0, uint _amount1) external returns (uint shares) { - token0.transferFrom(msg.sender, address(this), _amount0); - token1.transferFrom(msg.sender, address(this), _amount1); - - /* - How much dx, dy to add? - - xy = k - (x + dx)(y + dy) = k' - - No price change, before and after adding liquidity - x / y = (x + dx) / (y + dy) - - x(y + dy) = y(x + dx) - x * dy = y * dx - - x / y = dx / dy - dy = y / x * dx - */ - if (reserve0 > 0 || reserve1 > 0) { - require(reserve0 * _amount1 == reserve1 * _amount0, "x / y != dx / dy"); - } - - /* - How much shares to mint? - - f(x, y) = value of liquidity - We will define f(x, y) = sqrt(xy) - - L0 = f(x, y) - L1 = f(x + dx, y + dy) - T = total shares - s = shares to mint - - Total shares should increase proportional to increase in liquidity - L1 / L0 = (T + s) / T - - L1 * T = L0 * (T + s) - - (L1 - L0) * T / L0 = s - */ - - /* - Claim - (L1 - L0) / L0 = dx / x = dy / y - - Proof - --- Equation 1 --- - (L1 - L0) / L0 = (sqrt((x + dx)(y + dy)) - sqrt(xy)) / sqrt(xy) - - dx / dy = x / y so replace dy = dx * y / x - - --- Equation 2 --- - Equation 1 = (sqrt(xy + 2ydx + dx^2 * y / x) - sqrt(xy)) / sqrt(xy) - - Multiply by sqrt(x) / sqrt(x) - Equation 2 = (sqrt(x^2y + 2xydx + dx^2 * y) - sqrt(x^2y)) / sqrt(x^2y) - = (sqrt(y)(sqrt(x^2 + 2xdx + dx^2) - sqrt(x^2)) / (sqrt(y)sqrt(x^2)) - - sqrt(y) on top and bottom cancels out - - --- Equation 3 --- - Equation 2 = (sqrt(x^2 + 2xdx + dx^2) - sqrt(x^2)) / (sqrt(x^2) - = (sqrt((x + dx)^2) - sqrt(x^2)) / sqrt(x^2) - = ((x + dx) - x) / x - = dx / x - - Since dx / dy = x / y, - dx / x = dy / y - - Finally - (L1 - L0) / L0 = dx / x = dy / y - */ - if (totalSupply == 0) { - shares = _sqrt(_amount0 * _amount1); - } else { - shares = _min( - (_amount0 * totalSupply) / reserve0, - (_amount1 * totalSupply) / reserve1 - ); - } - require(shares > 0, "shares = 0"); - _mint(msg.sender, shares); - - _update(token0.balanceOf(address(this)), token1.balanceOf(address(this))); - } - - function removeLiquidity( - uint _shares - ) external returns (uint amount0, uint amount1) { - /* - Claim - dx, dy = amount of liquidity to remove - dx = s / T * x - dy = s / T * y - - Proof - Let's find dx, dy such that - v / L = s / T - - where - v = f(dx, dy) = sqrt(dxdy) - L = total liquidity = sqrt(xy) - s = shares - T = total supply - - --- Equation 1 --- - v = s / T * L - sqrt(dxdy) = s / T * sqrt(xy) - - Amount of liquidity to remove must not change price so - dx / dy = x / y - - replace dy = dx * y / x - sqrt(dxdy) = sqrt(dx * dx * y / x) = dx * sqrt(y / x) - - Divide both sides of Equation 1 with sqrt(y / x) - dx = s / T * sqrt(xy) / sqrt(y / x) - = s / T * sqrt(x^2) = s / T * x - - Likewise - dy = s / T * y - */ - - // bal0 >= reserve0 - // bal1 >= reserve1 - uint bal0 = token0.balanceOf(address(this)); - uint bal1 = token1.balanceOf(address(this)); - - amount0 = (_shares * bal0) / totalSupply; - amount1 = (_shares * bal1) / totalSupply; - require(amount0 > 0 && amount1 > 0, "amount0 or amount1 = 0"); - - _burn(msg.sender, _shares); - _update(bal0 - amount0, bal1 - amount1); - - token0.transfer(msg.sender, amount0); - token1.transfer(msg.sender, amount1); - bug(); - } - - function _sqrt(uint y) private pure returns (uint z) { - if (y > 3) { - z = y; - uint x = y / 2 + 1; - while (x < z) { - z = x; - x = (y / x + x) / 2; - } - } else if (y != 0) { - z = 1; - } - } - - function _min(uint x, uint y) private pure returns (uint) { - return x <= y ? x : y; - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint amount); - event Approval(address indexed owner, address indexed spender, uint amount); -} diff --git a/tests/evm/sbe@discrete-staking-rewards-claim/DiscreteStakingRewards.sol b/tests/evm/sbe@discrete-staking-rewards-claim/DiscreteStakingRewards.sol deleted file mode 100644 index 59b772baa..000000000 --- a/tests/evm/sbe@discrete-staking-rewards-claim/DiscreteStakingRewards.sol +++ /dev/null @@ -1,92 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.13; -import "../../../solidity_utils/lib.sol"; - -contract DiscreteStakingRewards { - IERC20 public immutable stakingToken; - IERC20 public immutable rewardToken; - - mapping(address => uint) public balanceOf; - uint public totalSupply; - - uint private constant MULTIPLIER = 1e18; - uint private rewardIndex; - mapping(address => uint) private rewardIndexOf; - mapping(address => uint) private earned; - - constructor(address _stakingToken, address _rewardToken) { - stakingToken = IERC20(_stakingToken); - rewardToken = IERC20(_rewardToken); - } - - function updateRewardIndex(uint reward) external { - rewardToken.transferFrom(msg.sender, address(this), reward); - rewardIndex += (reward * MULTIPLIER) / totalSupply; - } - - function _calculateRewards(address account) private view returns (uint) { - uint shares = balanceOf[account]; - return (shares * (rewardIndex - rewardIndexOf[account])) / MULTIPLIER; - } - - function calculateRewardsEarned(address account) external view returns (uint) { - return earned[account] + _calculateRewards(account); - } - - function _updateRewards(address account) private { - earned[account] += _calculateRewards(account); - rewardIndexOf[account] = rewardIndex; - } - - function stake(uint amount) external { - _updateRewards(msg.sender); - - balanceOf[msg.sender] += amount; - totalSupply += amount; - - stakingToken.transferFrom(msg.sender, address(this), amount); - } - - function unstake(uint amount) external { - _updateRewards(msg.sender); - - balanceOf[msg.sender] -= amount; - totalSupply -= amount; - - stakingToken.transfer(msg.sender, amount); - } - - function claim() external returns (uint) { - _updateRewards(msg.sender); - - uint reward = earned[msg.sender]; - if (reward > 0) { - earned[msg.sender] = 0; - rewardToken.transfer(msg.sender, reward); - bug(); - } - - return reward; - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint value); - event Approval(address indexed owner, address indexed spender, uint value); -} diff --git a/tests/evm/sbe@english-auction-withdraw/EnglishAuction.sol b/tests/evm/sbe@english-auction-withdraw/EnglishAuction.sol deleted file mode 100644 index 2520e1a86..000000000 --- a/tests/evm/sbe@english-auction-withdraw/EnglishAuction.sol +++ /dev/null @@ -1,87 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.13; -import "../../../solidity_utils/lib.sol"; - -interface IERC721 { - function safeTransferFrom(address from, address to, uint tokenId) external; - - function transferFrom(address, address, uint) external; -} - -contract EnglishAuction { - event Start(); - event Bid(address indexed sender, uint amount); - event Withdraw(address indexed bidder, uint amount); - event End(address winner, uint amount); - - IERC721 public nft; - uint public nftId; - - address payable public seller; - uint public endAt; - bool public started; - bool public ended; - - address public highestBidder; - uint public highestBid; - mapping(address => uint) public bids; - - constructor(address _nft, uint _nftId, uint _startingBid) { - nft = IERC721(_nft); - nftId = _nftId; - - seller = payable(msg.sender); - highestBid = _startingBid; - } - - function start() external { - require(!started, "started"); - require(msg.sender == seller, "not seller"); - - nft.transferFrom(msg.sender, address(this), nftId); - started = true; - endAt = block.timestamp + 7 days; - - emit Start(); - } - - function bid() external payable { - require(started, "not started"); - require(block.timestamp < endAt, "ended"); - require(msg.value > highestBid, "value < highest"); - - if (highestBidder != address(0)) { - bids[highestBidder] += highestBid; - } - - highestBidder = msg.sender; - highestBid = msg.value; - - emit Bid(msg.sender, msg.value); - } - - function withdraw() external { - uint bal = bids[msg.sender]; - bids[msg.sender] = 0; - payable(msg.sender).transfer(bal); - - emit Withdraw(msg.sender, bal); - bug(); - } - - function end() external { - require(started, "not started"); - require(block.timestamp >= endAt, "not ended"); - require(!ended, "ended"); - - ended = true; - if (highestBidder != address(0)) { - nft.safeTransferFrom(address(this), highestBidder, nftId); - seller.transfer(highestBid); - } else { - nft.safeTransferFrom(address(this), seller, nftId); - } - - emit End(highestBidder, highestBid); - } -} diff --git a/tests/evm/sbe@stable-swap-amm-addLiquidity1/StableSwap.sol b/tests/evm/sbe@stable-swap-amm-addLiquidity1/StableSwap.sol deleted file mode 100644 index 9e206ba2d..000000000 --- a/tests/evm/sbe@stable-swap-amm-addLiquidity1/StableSwap.sol +++ /dev/null @@ -1,444 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8; -import "../../../solidity_utils/lib.sol"; - -/* -Invariant - price of trade and amount of liquidity are determined by this equation - -An^n sum(x_i) + D = ADn^n + D^(n + 1) / (n^n prod(x_i)) - -Topics -0. Newton's method x_(n + 1) = x_n - f(x_n) / f'(x_n) -1. Invariant -2. Swap - - Calculate Y - - Calculate D -3. Get virtual price -4. Add liquidity - - Imbalance fee -5. Remove liquidity -6. Remove liquidity one token - - Calculate withdraw one token - - getYD -TODO: test? -*/ - -library Math { - function abs(uint x, uint y) internal pure returns (uint) { - return x >= y ? x - y : y - x; - } -} - -contract StableSwap { - // Number of tokens - uint private constant N = 3; - // Amplification coefficient multiplied by N^(N - 1) - // Higher value makes the curve more flat - // Lower value makes the curve more like constant product AMM - uint private constant A = 1000 * (N ** (N - 1)); - // 0.03% - uint private constant SWAP_FEE = 300; - // Liquidity fee is derived from 2 constraints - // 1. Fee is 0 for adding / removing liquidity that results in a balanced pool - // 2. Swapping in a balanced pool is like adding and then removing liquidity - // from a balanced pool - // swap fee = add liquidity fee + remove liquidity fee - uint private constant LIQUIDITY_FEE = (SWAP_FEE * N) / (4 * (N - 1)); - uint private constant FEE_DENOMINATOR = 1e6; - - address[N] public tokens; - // Normalize each token to 18 decimals - // Example - DAI (18 decimals), USDC (6 decimals), USDT (6 decimals) - uint[N] private multipliers = [1, 1e12, 1e12]; - uint[N] public balances; - - // 1 share = 1e18, 18 decimals - uint private constant DECIMALS = 18; - uint public totalSupply; - mapping(address => uint) public balanceOf; - - function _mint(address _to, uint _amount) private { - balanceOf[_to] += _amount; - totalSupply += _amount; - } - - function _burn(address _from, uint _amount) private { - balanceOf[_from] -= _amount; - totalSupply -= _amount; - } - - // Return precision-adjusted balances, adjusted to 18 decimals - function _xp() private view returns (uint[N] memory xp) { - for (uint i; i < N; ++i) { - xp[i] = balances[i] * multipliers[i]; - } - } - - /** - * @notice Calculate D, sum of balances in a perfectly balanced pool - * If balances of x_0, x_1, ... x_(n-1) then sum(x_i) = D - * @param xp Precision-adjusted balances - * @return D - */ - function _getD(uint[N] memory xp) private pure returns (uint) { - /* - Newton's method to compute D - ----------------------------- - f(D) = ADn^n + D^(n + 1) / (n^n prod(x_i)) - An^n sum(x_i) - D - f'(D) = An^n + (n + 1) D^n / (n^n prod(x_i)) - 1 - - (as + np)D_n - D_(n+1) = ----------------------- - (a - 1)D_n + (n + 1)p - - a = An^n - s = sum(x_i) - p = (D_n)^(n + 1) / (n^n prod(x_i)) - */ - uint a = A * N; // An^n - - uint s; // x_0 + x_1 + ... + x_(n-1) - for (uint i; i < N; ++i) { - s += xp[i]; - } - - // Newton's method - // Initial guess, d <= s - uint d = s; - uint d_prev; - for (uint i; i < 255; ++i) { - // p = D^(n + 1) / (n^n * x_0 * ... * x_(n-1)) - uint p = d; - for (uint j; j < N; ++j) { - p = (p * d) / (N * xp[j]); - } - d_prev = d; - d = ((a * s + N * p) * d) / ((a - 1) * d + (N + 1) * p); - - if (Math.abs(d, d_prev) <= 1) { - return d; - } - } - revert("D didn't converge"); - } - - /** - * @notice Calculate the new balance of token j given the new balance of token i - * @param i Index of token in - * @param j Index of token out - * @param x New balance of token i - * @param xp Current precision-adjusted balances - */ - function _getY( - uint i, - uint j, - uint x, - uint[N] memory xp - ) private pure returns (uint) { - /* - Newton's method to compute y - ----------------------------- - y = x_j - - f(y) = y^2 + y(b - D) - c - - y_n^2 + c - y_(n+1) = -------------- - 2y_n + b - D - - where - s = sum(x_k), k != j - p = prod(x_k), k != j - b = s + D / (An^n) - c = D^(n + 1) / (n^n * p * An^n) - */ - uint a = A * N; - uint d = _getD(xp); - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k == i) { - _x = x; - } else if (k == j) { - continue; - } else { - _x = xp[k]; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - /** - * @notice Calculate the new balance of token i given precision-adjusted - * balances xp and liquidity d - * @dev Equation is calculate y is same as _getY - * @param i Index of token to calculate the new balance - * @param xp Precision-adjusted balances - * @param d Liquidity d - * @return New balance of token i - */ - function _getYD(uint i, uint[N] memory xp, uint d) private pure returns (uint) { - uint a = A * N; - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k != i) { - _x = xp[k]; - } else { - continue; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - // Estimate value of 1 share - // How many tokens is one share worth? - function getVirtualPrice() external view returns (uint) { - uint d = _getD(_xp()); - uint _totalSupply = totalSupply; - if (_totalSupply > 0) { - return (d * 10 ** DECIMALS) / _totalSupply; - } - return 0; - } - - /** - * @notice Swap dx amount of token i for token j - * @param i Index of token in - * @param j Index of token out - * @param dx Token in amount - * @param minDy Minimum token out - */ - function swap(uint i, uint j, uint dx, uint minDy) external returns (uint dy) { - require(i != j, "i = j"); - - IERC20(tokens[i]).transferFrom(msg.sender, address(this), dx); - - // Calculate dy - uint[N] memory xp = _xp(); - uint x = xp[i] + dx * multipliers[i]; - - uint y0 = xp[j]; - uint y1 = _getY(i, j, x, xp); - // y0 must be >= y1, since x has increased - // -1 to round down - dy = (y0 - y1 - 1) / multipliers[j]; - - // Subtract fee from dy - uint fee = (dy * SWAP_FEE) / FEE_DENOMINATOR; - dy -= fee; - require(dy >= minDy, "dy < min"); - - balances[i] += dx; - balances[j] -= dy; - - IERC20(tokens[j]).transfer(msg.sender, dy); - } - - function addLiquidity( - uint[N] calldata amounts, - uint minShares - ) external returns (uint shares) { - // calculate current liquidity d0 - uint _totalSupply = totalSupply; - uint d0; - uint[N] memory old_xs = _xp(); - if (_totalSupply > 0) { - d0 = _getD(old_xs); - bug(); - } - - // Transfer tokens in - uint[N] memory new_xs; - for (uint i; i < N; ++i) { - uint amount = amounts[i]; - if (amount > 0) { - IERC20(tokens[i]).transferFrom(msg.sender, address(this), amount); - new_xs[i] = old_xs[i] + amount * multipliers[i]; - } else { - new_xs[i] = old_xs[i]; - } - } - - // Calculate new liquidity d1 - uint d1 = _getD(new_xs); - require(d1 > d0, "liquidity didn't increase"); - - // Reccalcuate D accounting for fee on imbalance - uint d2; - if (_totalSupply > 0) { - for (uint i; i < N; ++i) { - // TODO: why old_xs[i] * d1 / d0? why not d1 / N? - uint idealBalance = (old_xs[i] * d1) / d0; - uint diff = Math.abs(new_xs[i], idealBalance); - new_xs[i] -= (LIQUIDITY_FEE * diff) / FEE_DENOMINATOR; - } - - d2 = _getD(new_xs); - } else { - d2 = d1; - } - - // Update balances - for (uint i; i < N; ++i) { - balances[i] += amounts[i]; - } - - // Shares to mint = (d2 - d0) / d0 * total supply - // d1 >= d2 >= d0 - if (_totalSupply > 0) { - shares = ((d2 - d0) * _totalSupply) / d0; - } else { - shares = d2; - } - require(shares >= minShares, "shares < min"); - _mint(msg.sender, shares); - } - - function removeLiquidity( - uint shares, - uint[N] calldata minAmountsOut - ) external returns (uint[N] memory amountsOut) { - uint _totalSupply = totalSupply; - - for (uint i; i < N; ++i) { - uint amountOut = (balances[i] * shares) / _totalSupply; - require(amountOut >= minAmountsOut[i], "out < min"); - - balances[i] -= amountOut; - amountsOut[i] = amountOut; - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } - - _burn(msg.sender, shares); - } - - /** - * @notice Calculate amount of token i to receive for shares - * @param shares Shares to burn - * @param i Index of token to withdraw - * @return dy Amount of token i to receive - * fee Fee for withdraw. Fee already included in dy - */ - function _calcWithdrawOneToken( - uint shares, - uint i - ) private view returns (uint dy, uint fee) { - uint _totalSupply = totalSupply; - uint[N] memory xp = _xp(); - - // Calculate d0 and d1 - uint d0 = _getD(xp); - uint d1 = d0 - (d0 * shares) / _totalSupply; - - // Calculate reduction in y if D = d1 - uint y0 = _getYD(i, xp, d1); - // d1 <= d0 so y must be <= xp[i] - uint dy0 = (xp[i] - y0) / multipliers[i]; - - // Calculate imbalance fee, update xp with fees - uint dx; - for (uint j; j < N; ++j) { - if (j == i) { - dx = (xp[j] * d1) / d0 - y0; - } else { - // d1 / d0 <= 1 - dx = xp[j] - (xp[j] * d1) / d0; - } - xp[j] -= (LIQUIDITY_FEE * dx) / FEE_DENOMINATOR; - } - - // Recalculate y with xp including imbalance fees - uint y1 = _getYD(i, xp, d1); - // - 1 to round down - dy = (xp[i] - y1 - 1) / multipliers[i]; - fee = dy0 - dy; - } - - function calcWithdrawOneToken( - uint shares, - uint i - ) external view returns (uint dy, uint fee) { - return _calcWithdrawOneToken(shares, i); - } - - /** - * @notice Withdraw liquidity in token i - * @param shares Shares to burn - * @param i Token to withdraw - * @param minAmountOut Minimum amount of token i that must be withdrawn - */ - function removeLiquidityOneToken( - uint shares, - uint i, - uint minAmountOut - ) external returns (uint amountOut) { - (amountOut, ) = _calcWithdrawOneToken(shares, i); - require(amountOut >= minAmountOut, "out < min"); - - balances[i] -= amountOut; - _burn(msg.sender, shares); - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint amount); - event Approval(address indexed owner, address indexed spender, uint amount); -} diff --git a/tests/evm/sbe@stable-swap-amm-addLiquidity2/StableSwap.sol b/tests/evm/sbe@stable-swap-amm-addLiquidity2/StableSwap.sol deleted file mode 100644 index 2d3cd2f3f..000000000 --- a/tests/evm/sbe@stable-swap-amm-addLiquidity2/StableSwap.sol +++ /dev/null @@ -1,444 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8; -import "../../../solidity_utils/lib.sol"; - -/* -Invariant - price of trade and amount of liquidity are determined by this equation - -An^n sum(x_i) + D = ADn^n + D^(n + 1) / (n^n prod(x_i)) - -Topics -0. Newton's method x_(n + 1) = x_n - f(x_n) / f'(x_n) -1. Invariant -2. Swap - - Calculate Y - - Calculate D -3. Get virtual price -4. Add liquidity - - Imbalance fee -5. Remove liquidity -6. Remove liquidity one token - - Calculate withdraw one token - - getYD -TODO: test? -*/ - -library Math { - function abs(uint x, uint y) internal pure returns (uint) { - return x >= y ? x - y : y - x; - } -} - -contract StableSwap { - // Number of tokens - uint private constant N = 3; - // Amplification coefficient multiplied by N^(N - 1) - // Higher value makes the curve more flat - // Lower value makes the curve more like constant product AMM - uint private constant A = 1000 * (N ** (N - 1)); - // 0.03% - uint private constant SWAP_FEE = 300; - // Liquidity fee is derived from 2 constraints - // 1. Fee is 0 for adding / removing liquidity that results in a balanced pool - // 2. Swapping in a balanced pool is like adding and then removing liquidity - // from a balanced pool - // swap fee = add liquidity fee + remove liquidity fee - uint private constant LIQUIDITY_FEE = (SWAP_FEE * N) / (4 * (N - 1)); - uint private constant FEE_DENOMINATOR = 1e6; - - address[N] public tokens; - // Normalize each token to 18 decimals - // Example - DAI (18 decimals), USDC (6 decimals), USDT (6 decimals) - uint[N] private multipliers = [1, 1e12, 1e12]; - uint[N] public balances; - - // 1 share = 1e18, 18 decimals - uint private constant DECIMALS = 18; - uint public totalSupply; - mapping(address => uint) public balanceOf; - - function _mint(address _to, uint _amount) private { - balanceOf[_to] += _amount; - totalSupply += _amount; - } - - function _burn(address _from, uint _amount) private { - balanceOf[_from] -= _amount; - totalSupply -= _amount; - } - - // Return precision-adjusted balances, adjusted to 18 decimals - function _xp() private view returns (uint[N] memory xp) { - for (uint i; i < N; ++i) { - xp[i] = balances[i] * multipliers[i]; - } - } - - /** - * @notice Calculate D, sum of balances in a perfectly balanced pool - * If balances of x_0, x_1, ... x_(n-1) then sum(x_i) = D - * @param xp Precision-adjusted balances - * @return D - */ - function _getD(uint[N] memory xp) private pure returns (uint) { - /* - Newton's method to compute D - ----------------------------- - f(D) = ADn^n + D^(n + 1) / (n^n prod(x_i)) - An^n sum(x_i) - D - f'(D) = An^n + (n + 1) D^n / (n^n prod(x_i)) - 1 - - (as + np)D_n - D_(n+1) = ----------------------- - (a - 1)D_n + (n + 1)p - - a = An^n - s = sum(x_i) - p = (D_n)^(n + 1) / (n^n prod(x_i)) - */ - uint a = A * N; // An^n - - uint s; // x_0 + x_1 + ... + x_(n-1) - for (uint i; i < N; ++i) { - s += xp[i]; - } - - // Newton's method - // Initial guess, d <= s - uint d = s; - uint d_prev; - for (uint i; i < 255; ++i) { - // p = D^(n + 1) / (n^n * x_0 * ... * x_(n-1)) - uint p = d; - for (uint j; j < N; ++j) { - p = (p * d) / (N * xp[j]); - } - d_prev = d; - d = ((a * s + N * p) * d) / ((a - 1) * d + (N + 1) * p); - - if (Math.abs(d, d_prev) <= 1) { - return d; - } - } - revert("D didn't converge"); - } - - /** - * @notice Calculate the new balance of token j given the new balance of token i - * @param i Index of token in - * @param j Index of token out - * @param x New balance of token i - * @param xp Current precision-adjusted balances - */ - function _getY( - uint i, - uint j, - uint x, - uint[N] memory xp - ) private pure returns (uint) { - /* - Newton's method to compute y - ----------------------------- - y = x_j - - f(y) = y^2 + y(b - D) - c - - y_n^2 + c - y_(n+1) = -------------- - 2y_n + b - D - - where - s = sum(x_k), k != j - p = prod(x_k), k != j - b = s + D / (An^n) - c = D^(n + 1) / (n^n * p * An^n) - */ - uint a = A * N; - uint d = _getD(xp); - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k == i) { - _x = x; - } else if (k == j) { - continue; - } else { - _x = xp[k]; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - /** - * @notice Calculate the new balance of token i given precision-adjusted - * balances xp and liquidity d - * @dev Equation is calculate y is same as _getY - * @param i Index of token to calculate the new balance - * @param xp Precision-adjusted balances - * @param d Liquidity d - * @return New balance of token i - */ - function _getYD(uint i, uint[N] memory xp, uint d) private pure returns (uint) { - uint a = A * N; - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k != i) { - _x = xp[k]; - } else { - continue; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - // Estimate value of 1 share - // How many tokens is one share worth? - function getVirtualPrice() external view returns (uint) { - uint d = _getD(_xp()); - uint _totalSupply = totalSupply; - if (_totalSupply > 0) { - return (d * 10 ** DECIMALS) / _totalSupply; - } - return 0; - } - - /** - * @notice Swap dx amount of token i for token j - * @param i Index of token in - * @param j Index of token out - * @param dx Token in amount - * @param minDy Minimum token out - */ - function swap(uint i, uint j, uint dx, uint minDy) external returns (uint dy) { - require(i != j, "i = j"); - - IERC20(tokens[i]).transferFrom(msg.sender, address(this), dx); - - // Calculate dy - uint[N] memory xp = _xp(); - uint x = xp[i] + dx * multipliers[i]; - - uint y0 = xp[j]; - uint y1 = _getY(i, j, x, xp); - // y0 must be >= y1, since x has increased - // -1 to round down - dy = (y0 - y1 - 1) / multipliers[j]; - - // Subtract fee from dy - uint fee = (dy * SWAP_FEE) / FEE_DENOMINATOR; - dy -= fee; - require(dy >= minDy, "dy < min"); - - balances[i] += dx; - balances[j] -= dy; - - IERC20(tokens[j]).transfer(msg.sender, dy); - } - - function addLiquidity( - uint[N] calldata amounts, - uint minShares - ) external returns (uint shares) { - // calculate current liquidity d0 - uint _totalSupply = totalSupply; - uint d0; - uint[N] memory old_xs = _xp(); - if (_totalSupply > 0) { - d0 = _getD(old_xs); - } - - // Transfer tokens in - uint[N] memory new_xs; - for (uint i; i < N; ++i) { - uint amount = amounts[i]; - if (amount > 0) { - IERC20(tokens[i]).transferFrom(msg.sender, address(this), amount); - new_xs[i] = old_xs[i] + amount * multipliers[i]; - bug(); - } else { - new_xs[i] = old_xs[i]; - } - } - - // Calculate new liquidity d1 - uint d1 = _getD(new_xs); - require(d1 > d0, "liquidity didn't increase"); - - // Reccalcuate D accounting for fee on imbalance - uint d2; - if (_totalSupply > 0) { - for (uint i; i < N; ++i) { - // TODO: why old_xs[i] * d1 / d0? why not d1 / N? - uint idealBalance = (old_xs[i] * d1) / d0; - uint diff = Math.abs(new_xs[i], idealBalance); - new_xs[i] -= (LIQUIDITY_FEE * diff) / FEE_DENOMINATOR; - } - - d2 = _getD(new_xs); - } else { - d2 = d1; - } - - // Update balances - for (uint i; i < N; ++i) { - balances[i] += amounts[i]; - } - - // Shares to mint = (d2 - d0) / d0 * total supply - // d1 >= d2 >= d0 - if (_totalSupply > 0) { - shares = ((d2 - d0) * _totalSupply) / d0; - } else { - shares = d2; - } - require(shares >= minShares, "shares < min"); - _mint(msg.sender, shares); - } - - function removeLiquidity( - uint shares, - uint[N] calldata minAmountsOut - ) external returns (uint[N] memory amountsOut) { - uint _totalSupply = totalSupply; - - for (uint i; i < N; ++i) { - uint amountOut = (balances[i] * shares) / _totalSupply; - require(amountOut >= minAmountsOut[i], "out < min"); - - balances[i] -= amountOut; - amountsOut[i] = amountOut; - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } - - _burn(msg.sender, shares); - } - - /** - * @notice Calculate amount of token i to receive for shares - * @param shares Shares to burn - * @param i Index of token to withdraw - * @return dy Amount of token i to receive - * fee Fee for withdraw. Fee already included in dy - */ - function _calcWithdrawOneToken( - uint shares, - uint i - ) private view returns (uint dy, uint fee) { - uint _totalSupply = totalSupply; - uint[N] memory xp = _xp(); - - // Calculate d0 and d1 - uint d0 = _getD(xp); - uint d1 = d0 - (d0 * shares) / _totalSupply; - - // Calculate reduction in y if D = d1 - uint y0 = _getYD(i, xp, d1); - // d1 <= d0 so y must be <= xp[i] - uint dy0 = (xp[i] - y0) / multipliers[i]; - - // Calculate imbalance fee, update xp with fees - uint dx; - for (uint j; j < N; ++j) { - if (j == i) { - dx = (xp[j] * d1) / d0 - y0; - } else { - // d1 / d0 <= 1 - dx = xp[j] - (xp[j] * d1) / d0; - } - xp[j] -= (LIQUIDITY_FEE * dx) / FEE_DENOMINATOR; - } - - // Recalculate y with xp including imbalance fees - uint y1 = _getYD(i, xp, d1); - // - 1 to round down - dy = (xp[i] - y1 - 1) / multipliers[i]; - fee = dy0 - dy; - } - - function calcWithdrawOneToken( - uint shares, - uint i - ) external view returns (uint dy, uint fee) { - return _calcWithdrawOneToken(shares, i); - } - - /** - * @notice Withdraw liquidity in token i - * @param shares Shares to burn - * @param i Token to withdraw - * @param minAmountOut Minimum amount of token i that must be withdrawn - */ - function removeLiquidityOneToken( - uint shares, - uint i, - uint minAmountOut - ) external returns (uint amountOut) { - (amountOut, ) = _calcWithdrawOneToken(shares, i); - require(amountOut >= minAmountOut, "out < min"); - - balances[i] -= amountOut; - _burn(msg.sender, shares); - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint amount); - event Approval(address indexed owner, address indexed spender, uint amount); -} diff --git a/tests/evm/sbe@stable-swap-amm-addLiquidity3/StableSwap.sol b/tests/evm/sbe@stable-swap-amm-addLiquidity3/StableSwap.sol deleted file mode 100644 index c16f50b0c..000000000 --- a/tests/evm/sbe@stable-swap-amm-addLiquidity3/StableSwap.sol +++ /dev/null @@ -1,444 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8; -import "../../../solidity_utils/lib.sol"; - -/* -Invariant - price of trade and amount of liquidity are determined by this equation - -An^n sum(x_i) + D = ADn^n + D^(n + 1) / (n^n prod(x_i)) - -Topics -0. Newton's method x_(n + 1) = x_n - f(x_n) / f'(x_n) -1. Invariant -2. Swap - - Calculate Y - - Calculate D -3. Get virtual price -4. Add liquidity - - Imbalance fee -5. Remove liquidity -6. Remove liquidity one token - - Calculate withdraw one token - - getYD -TODO: test? -*/ - -library Math { - function abs(uint x, uint y) internal pure returns (uint) { - return x >= y ? x - y : y - x; - } -} - -contract StableSwap { - // Number of tokens - uint private constant N = 3; - // Amplification coefficient multiplied by N^(N - 1) - // Higher value makes the curve more flat - // Lower value makes the curve more like constant product AMM - uint private constant A = 1000 * (N ** (N - 1)); - // 0.03% - uint private constant SWAP_FEE = 300; - // Liquidity fee is derived from 2 constraints - // 1. Fee is 0 for adding / removing liquidity that results in a balanced pool - // 2. Swapping in a balanced pool is like adding and then removing liquidity - // from a balanced pool - // swap fee = add liquidity fee + remove liquidity fee - uint private constant LIQUIDITY_FEE = (SWAP_FEE * N) / (4 * (N - 1)); - uint private constant FEE_DENOMINATOR = 1e6; - - address[N] public tokens; - // Normalize each token to 18 decimals - // Example - DAI (18 decimals), USDC (6 decimals), USDT (6 decimals) - uint[N] private multipliers = [1, 1e12, 1e12]; - uint[N] public balances; - - // 1 share = 1e18, 18 decimals - uint private constant DECIMALS = 18; - uint public totalSupply; - mapping(address => uint) public balanceOf; - - function _mint(address _to, uint _amount) private { - balanceOf[_to] += _amount; - totalSupply += _amount; - } - - function _burn(address _from, uint _amount) private { - balanceOf[_from] -= _amount; - totalSupply -= _amount; - } - - // Return precision-adjusted balances, adjusted to 18 decimals - function _xp() private view returns (uint[N] memory xp) { - for (uint i; i < N; ++i) { - xp[i] = balances[i] * multipliers[i]; - } - } - - /** - * @notice Calculate D, sum of balances in a perfectly balanced pool - * If balances of x_0, x_1, ... x_(n-1) then sum(x_i) = D - * @param xp Precision-adjusted balances - * @return D - */ - function _getD(uint[N] memory xp) private pure returns (uint) { - /* - Newton's method to compute D - ----------------------------- - f(D) = ADn^n + D^(n + 1) / (n^n prod(x_i)) - An^n sum(x_i) - D - f'(D) = An^n + (n + 1) D^n / (n^n prod(x_i)) - 1 - - (as + np)D_n - D_(n+1) = ----------------------- - (a - 1)D_n + (n + 1)p - - a = An^n - s = sum(x_i) - p = (D_n)^(n + 1) / (n^n prod(x_i)) - */ - uint a = A * N; // An^n - - uint s; // x_0 + x_1 + ... + x_(n-1) - for (uint i; i < N; ++i) { - s += xp[i]; - } - - // Newton's method - // Initial guess, d <= s - uint d = s; - uint d_prev; - for (uint i; i < 255; ++i) { - // p = D^(n + 1) / (n^n * x_0 * ... * x_(n-1)) - uint p = d; - for (uint j; j < N; ++j) { - p = (p * d) / (N * xp[j]); - } - d_prev = d; - d = ((a * s + N * p) * d) / ((a - 1) * d + (N + 1) * p); - - if (Math.abs(d, d_prev) <= 1) { - return d; - } - } - revert("D didn't converge"); - } - - /** - * @notice Calculate the new balance of token j given the new balance of token i - * @param i Index of token in - * @param j Index of token out - * @param x New balance of token i - * @param xp Current precision-adjusted balances - */ - function _getY( - uint i, - uint j, - uint x, - uint[N] memory xp - ) private pure returns (uint) { - /* - Newton's method to compute y - ----------------------------- - y = x_j - - f(y) = y^2 + y(b - D) - c - - y_n^2 + c - y_(n+1) = -------------- - 2y_n + b - D - - where - s = sum(x_k), k != j - p = prod(x_k), k != j - b = s + D / (An^n) - c = D^(n + 1) / (n^n * p * An^n) - */ - uint a = A * N; - uint d = _getD(xp); - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k == i) { - _x = x; - } else if (k == j) { - continue; - } else { - _x = xp[k]; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - /** - * @notice Calculate the new balance of token i given precision-adjusted - * balances xp and liquidity d - * @dev Equation is calculate y is same as _getY - * @param i Index of token to calculate the new balance - * @param xp Precision-adjusted balances - * @param d Liquidity d - * @return New balance of token i - */ - function _getYD(uint i, uint[N] memory xp, uint d) private pure returns (uint) { - uint a = A * N; - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k != i) { - _x = xp[k]; - } else { - continue; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - // Estimate value of 1 share - // How many tokens is one share worth? - function getVirtualPrice() external view returns (uint) { - uint d = _getD(_xp()); - uint _totalSupply = totalSupply; - if (_totalSupply > 0) { - return (d * 10 ** DECIMALS) / _totalSupply; - } - return 0; - } - - /** - * @notice Swap dx amount of token i for token j - * @param i Index of token in - * @param j Index of token out - * @param dx Token in amount - * @param minDy Minimum token out - */ - function swap(uint i, uint j, uint dx, uint minDy) external returns (uint dy) { - require(i != j, "i = j"); - - IERC20(tokens[i]).transferFrom(msg.sender, address(this), dx); - - // Calculate dy - uint[N] memory xp = _xp(); - uint x = xp[i] + dx * multipliers[i]; - - uint y0 = xp[j]; - uint y1 = _getY(i, j, x, xp); - // y0 must be >= y1, since x has increased - // -1 to round down - dy = (y0 - y1 - 1) / multipliers[j]; - - // Subtract fee from dy - uint fee = (dy * SWAP_FEE) / FEE_DENOMINATOR; - dy -= fee; - require(dy >= minDy, "dy < min"); - - balances[i] += dx; - balances[j] -= dy; - - IERC20(tokens[j]).transfer(msg.sender, dy); - } - - function addLiquidity( - uint[N] calldata amounts, - uint minShares - ) external returns (uint shares) { - // calculate current liquidity d0 - uint _totalSupply = totalSupply; - uint d0; - uint[N] memory old_xs = _xp(); - if (_totalSupply > 0) { - d0 = _getD(old_xs); - } - - // Transfer tokens in - uint[N] memory new_xs; - for (uint i; i < N; ++i) { - uint amount = amounts[i]; - if (amount > 0) { - IERC20(tokens[i]).transferFrom(msg.sender, address(this), amount); - new_xs[i] = old_xs[i] + amount * multipliers[i]; - } else { - new_xs[i] = old_xs[i]; - } - } - - // Calculate new liquidity d1 - uint d1 = _getD(new_xs); - require(d1 > d0, "liquidity didn't increase"); - - // Reccalcuate D accounting for fee on imbalance - uint d2; - if (_totalSupply > 0) { - for (uint i; i < N; ++i) { - // TODO: why old_xs[i] * d1 / d0? why not d1 / N? - uint idealBalance = (old_xs[i] * d1) / d0; - uint diff = Math.abs(new_xs[i], idealBalance); - new_xs[i] -= (LIQUIDITY_FEE * diff) / FEE_DENOMINATOR; - bug(); - } - - d2 = _getD(new_xs); - } else { - d2 = d1; - } - - // Update balances - for (uint i; i < N; ++i) { - balances[i] += amounts[i]; - } - - // Shares to mint = (d2 - d0) / d0 * total supply - // d1 >= d2 >= d0 - if (_totalSupply > 0) { - shares = ((d2 - d0) * _totalSupply) / d0; - } else { - shares = d2; - } - require(shares >= minShares, "shares < min"); - _mint(msg.sender, shares); - } - - function removeLiquidity( - uint shares, - uint[N] calldata minAmountsOut - ) external returns (uint[N] memory amountsOut) { - uint _totalSupply = totalSupply; - - for (uint i; i < N; ++i) { - uint amountOut = (balances[i] * shares) / _totalSupply; - require(amountOut >= minAmountsOut[i], "out < min"); - - balances[i] -= amountOut; - amountsOut[i] = amountOut; - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } - - _burn(msg.sender, shares); - } - - /** - * @notice Calculate amount of token i to receive for shares - * @param shares Shares to burn - * @param i Index of token to withdraw - * @return dy Amount of token i to receive - * fee Fee for withdraw. Fee already included in dy - */ - function _calcWithdrawOneToken( - uint shares, - uint i - ) private view returns (uint dy, uint fee) { - uint _totalSupply = totalSupply; - uint[N] memory xp = _xp(); - - // Calculate d0 and d1 - uint d0 = _getD(xp); - uint d1 = d0 - (d0 * shares) / _totalSupply; - - // Calculate reduction in y if D = d1 - uint y0 = _getYD(i, xp, d1); - // d1 <= d0 so y must be <= xp[i] - uint dy0 = (xp[i] - y0) / multipliers[i]; - - // Calculate imbalance fee, update xp with fees - uint dx; - for (uint j; j < N; ++j) { - if (j == i) { - dx = (xp[j] * d1) / d0 - y0; - } else { - // d1 / d0 <= 1 - dx = xp[j] - (xp[j] * d1) / d0; - } - xp[j] -= (LIQUIDITY_FEE * dx) / FEE_DENOMINATOR; - } - - // Recalculate y with xp including imbalance fees - uint y1 = _getYD(i, xp, d1); - // - 1 to round down - dy = (xp[i] - y1 - 1) / multipliers[i]; - fee = dy0 - dy; - } - - function calcWithdrawOneToken( - uint shares, - uint i - ) external view returns (uint dy, uint fee) { - return _calcWithdrawOneToken(shares, i); - } - - /** - * @notice Withdraw liquidity in token i - * @param shares Shares to burn - * @param i Token to withdraw - * @param minAmountOut Minimum amount of token i that must be withdrawn - */ - function removeLiquidityOneToken( - uint shares, - uint i, - uint minAmountOut - ) external returns (uint amountOut) { - (amountOut, ) = _calcWithdrawOneToken(shares, i); - require(amountOut >= minAmountOut, "out < min"); - - balances[i] -= amountOut; - _burn(msg.sender, shares); - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint amount); - event Approval(address indexed owner, address indexed spender, uint amount); -} diff --git a/tests/evm/sbe@stable-swap-amm-addLiquidity4/StableSwap.sol b/tests/evm/sbe@stable-swap-amm-addLiquidity4/StableSwap.sol deleted file mode 100644 index 5be5ffd15..000000000 --- a/tests/evm/sbe@stable-swap-amm-addLiquidity4/StableSwap.sol +++ /dev/null @@ -1,444 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8; -import "../../../solidity_utils/lib.sol"; - -/* -Invariant - price of trade and amount of liquidity are determined by this equation - -An^n sum(x_i) + D = ADn^n + D^(n + 1) / (n^n prod(x_i)) - -Topics -0. Newton's method x_(n + 1) = x_n - f(x_n) / f'(x_n) -1. Invariant -2. Swap - - Calculate Y - - Calculate D -3. Get virtual price -4. Add liquidity - - Imbalance fee -5. Remove liquidity -6. Remove liquidity one token - - Calculate withdraw one token - - getYD -TODO: test? -*/ - -library Math { - function abs(uint x, uint y) internal pure returns (uint) { - return x >= y ? x - y : y - x; - } -} - -contract StableSwap { - // Number of tokens - uint private constant N = 3; - // Amplification coefficient multiplied by N^(N - 1) - // Higher value makes the curve more flat - // Lower value makes the curve more like constant product AMM - uint private constant A = 1000 * (N ** (N - 1)); - // 0.03% - uint private constant SWAP_FEE = 300; - // Liquidity fee is derived from 2 constraints - // 1. Fee is 0 for adding / removing liquidity that results in a balanced pool - // 2. Swapping in a balanced pool is like adding and then removing liquidity - // from a balanced pool - // swap fee = add liquidity fee + remove liquidity fee - uint private constant LIQUIDITY_FEE = (SWAP_FEE * N) / (4 * (N - 1)); - uint private constant FEE_DENOMINATOR = 1e6; - - address[N] public tokens; - // Normalize each token to 18 decimals - // Example - DAI (18 decimals), USDC (6 decimals), USDT (6 decimals) - uint[N] private multipliers = [1, 1e12, 1e12]; - uint[N] public balances; - - // 1 share = 1e18, 18 decimals - uint private constant DECIMALS = 18; - uint public totalSupply; - mapping(address => uint) public balanceOf; - - function _mint(address _to, uint _amount) private { - balanceOf[_to] += _amount; - totalSupply += _amount; - } - - function _burn(address _from, uint _amount) private { - balanceOf[_from] -= _amount; - totalSupply -= _amount; - } - - // Return precision-adjusted balances, adjusted to 18 decimals - function _xp() private view returns (uint[N] memory xp) { - for (uint i; i < N; ++i) { - xp[i] = balances[i] * multipliers[i]; - } - } - - /** - * @notice Calculate D, sum of balances in a perfectly balanced pool - * If balances of x_0, x_1, ... x_(n-1) then sum(x_i) = D - * @param xp Precision-adjusted balances - * @return D - */ - function _getD(uint[N] memory xp) private pure returns (uint) { - /* - Newton's method to compute D - ----------------------------- - f(D) = ADn^n + D^(n + 1) / (n^n prod(x_i)) - An^n sum(x_i) - D - f'(D) = An^n + (n + 1) D^n / (n^n prod(x_i)) - 1 - - (as + np)D_n - D_(n+1) = ----------------------- - (a - 1)D_n + (n + 1)p - - a = An^n - s = sum(x_i) - p = (D_n)^(n + 1) / (n^n prod(x_i)) - */ - uint a = A * N; // An^n - - uint s; // x_0 + x_1 + ... + x_(n-1) - for (uint i; i < N; ++i) { - s += xp[i]; - } - - // Newton's method - // Initial guess, d <= s - uint d = s; - uint d_prev; - for (uint i; i < 255; ++i) { - // p = D^(n + 1) / (n^n * x_0 * ... * x_(n-1)) - uint p = d; - for (uint j; j < N; ++j) { - p = (p * d) / (N * xp[j]); - } - d_prev = d; - d = ((a * s + N * p) * d) / ((a - 1) * d + (N + 1) * p); - - if (Math.abs(d, d_prev) <= 1) { - return d; - } - } - revert("D didn't converge"); - } - - /** - * @notice Calculate the new balance of token j given the new balance of token i - * @param i Index of token in - * @param j Index of token out - * @param x New balance of token i - * @param xp Current precision-adjusted balances - */ - function _getY( - uint i, - uint j, - uint x, - uint[N] memory xp - ) private pure returns (uint) { - /* - Newton's method to compute y - ----------------------------- - y = x_j - - f(y) = y^2 + y(b - D) - c - - y_n^2 + c - y_(n+1) = -------------- - 2y_n + b - D - - where - s = sum(x_k), k != j - p = prod(x_k), k != j - b = s + D / (An^n) - c = D^(n + 1) / (n^n * p * An^n) - */ - uint a = A * N; - uint d = _getD(xp); - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k == i) { - _x = x; - } else if (k == j) { - continue; - } else { - _x = xp[k]; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - /** - * @notice Calculate the new balance of token i given precision-adjusted - * balances xp and liquidity d - * @dev Equation is calculate y is same as _getY - * @param i Index of token to calculate the new balance - * @param xp Precision-adjusted balances - * @param d Liquidity d - * @return New balance of token i - */ - function _getYD(uint i, uint[N] memory xp, uint d) private pure returns (uint) { - uint a = A * N; - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k != i) { - _x = xp[k]; - } else { - continue; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - // Estimate value of 1 share - // How many tokens is one share worth? - function getVirtualPrice() external view returns (uint) { - uint d = _getD(_xp()); - uint _totalSupply = totalSupply; - if (_totalSupply > 0) { - return (d * 10 ** DECIMALS) / _totalSupply; - } - return 0; - } - - /** - * @notice Swap dx amount of token i for token j - * @param i Index of token in - * @param j Index of token out - * @param dx Token in amount - * @param minDy Minimum token out - */ - function swap(uint i, uint j, uint dx, uint minDy) external returns (uint dy) { - require(i != j, "i = j"); - - IERC20(tokens[i]).transferFrom(msg.sender, address(this), dx); - - // Calculate dy - uint[N] memory xp = _xp(); - uint x = xp[i] + dx * multipliers[i]; - - uint y0 = xp[j]; - uint y1 = _getY(i, j, x, xp); - // y0 must be >= y1, since x has increased - // -1 to round down - dy = (y0 - y1 - 1) / multipliers[j]; - - // Subtract fee from dy - uint fee = (dy * SWAP_FEE) / FEE_DENOMINATOR; - dy -= fee; - require(dy >= minDy, "dy < min"); - - balances[i] += dx; - balances[j] -= dy; - - IERC20(tokens[j]).transfer(msg.sender, dy); - } - - function addLiquidity( - uint[N] calldata amounts, - uint minShares - ) external returns (uint shares) { - // calculate current liquidity d0 - uint _totalSupply = totalSupply; - uint d0; - uint[N] memory old_xs = _xp(); - if (_totalSupply > 0) { - d0 = _getD(old_xs); - } - - // Transfer tokens in - uint[N] memory new_xs; - for (uint i; i < N; ++i) { - uint amount = amounts[i]; - if (amount > 0) { - IERC20(tokens[i]).transferFrom(msg.sender, address(this), amount); - new_xs[i] = old_xs[i] + amount * multipliers[i]; - } else { - new_xs[i] = old_xs[i]; - } - } - - // Calculate new liquidity d1 - uint d1 = _getD(new_xs); - require(d1 > d0, "liquidity didn't increase"); - - // Reccalcuate D accounting for fee on imbalance - uint d2; - if (_totalSupply > 0) { - for (uint i; i < N; ++i) { - // TODO: why old_xs[i] * d1 / d0? why not d1 / N? - uint idealBalance = (old_xs[i] * d1) / d0; - uint diff = Math.abs(new_xs[i], idealBalance); - new_xs[i] -= (LIQUIDITY_FEE * diff) / FEE_DENOMINATOR; - } - - d2 = _getD(new_xs); - bug(); - } else { - d2 = d1; - } - - // Update balances - for (uint i; i < N; ++i) { - balances[i] += amounts[i]; - } - - // Shares to mint = (d2 - d0) / d0 * total supply - // d1 >= d2 >= d0 - if (_totalSupply > 0) { - shares = ((d2 - d0) * _totalSupply) / d0; - } else { - shares = d2; - } - require(shares >= minShares, "shares < min"); - _mint(msg.sender, shares); - } - - function removeLiquidity( - uint shares, - uint[N] calldata minAmountsOut - ) external returns (uint[N] memory amountsOut) { - uint _totalSupply = totalSupply; - - for (uint i; i < N; ++i) { - uint amountOut = (balances[i] * shares) / _totalSupply; - require(amountOut >= minAmountsOut[i], "out < min"); - - balances[i] -= amountOut; - amountsOut[i] = amountOut; - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } - - _burn(msg.sender, shares); - } - - /** - * @notice Calculate amount of token i to receive for shares - * @param shares Shares to burn - * @param i Index of token to withdraw - * @return dy Amount of token i to receive - * fee Fee for withdraw. Fee already included in dy - */ - function _calcWithdrawOneToken( - uint shares, - uint i - ) private view returns (uint dy, uint fee) { - uint _totalSupply = totalSupply; - uint[N] memory xp = _xp(); - - // Calculate d0 and d1 - uint d0 = _getD(xp); - uint d1 = d0 - (d0 * shares) / _totalSupply; - - // Calculate reduction in y if D = d1 - uint y0 = _getYD(i, xp, d1); - // d1 <= d0 so y must be <= xp[i] - uint dy0 = (xp[i] - y0) / multipliers[i]; - - // Calculate imbalance fee, update xp with fees - uint dx; - for (uint j; j < N; ++j) { - if (j == i) { - dx = (xp[j] * d1) / d0 - y0; - } else { - // d1 / d0 <= 1 - dx = xp[j] - (xp[j] * d1) / d0; - } - xp[j] -= (LIQUIDITY_FEE * dx) / FEE_DENOMINATOR; - } - - // Recalculate y with xp including imbalance fees - uint y1 = _getYD(i, xp, d1); - // - 1 to round down - dy = (xp[i] - y1 - 1) / multipliers[i]; - fee = dy0 - dy; - } - - function calcWithdrawOneToken( - uint shares, - uint i - ) external view returns (uint dy, uint fee) { - return _calcWithdrawOneToken(shares, i); - } - - /** - * @notice Withdraw liquidity in token i - * @param shares Shares to burn - * @param i Token to withdraw - * @param minAmountOut Minimum amount of token i that must be withdrawn - */ - function removeLiquidityOneToken( - uint shares, - uint i, - uint minAmountOut - ) external returns (uint amountOut) { - (amountOut, ) = _calcWithdrawOneToken(shares, i); - require(amountOut >= minAmountOut, "out < min"); - - balances[i] -= amountOut; - _burn(msg.sender, shares); - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint amount); - event Approval(address indexed owner, address indexed spender, uint amount); -} diff --git a/tests/evm/sbe@stable-swap-amm-addLiquidity5/StableSwap.sol b/tests/evm/sbe@stable-swap-amm-addLiquidity5/StableSwap.sol deleted file mode 100644 index 06c1ff22f..000000000 --- a/tests/evm/sbe@stable-swap-amm-addLiquidity5/StableSwap.sol +++ /dev/null @@ -1,444 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8; -import "../../../solidity_utils/lib.sol"; - -/* -Invariant - price of trade and amount of liquidity are determined by this equation - -An^n sum(x_i) + D = ADn^n + D^(n + 1) / (n^n prod(x_i)) - -Topics -0. Newton's method x_(n + 1) = x_n - f(x_n) / f'(x_n) -1. Invariant -2. Swap - - Calculate Y - - Calculate D -3. Get virtual price -4. Add liquidity - - Imbalance fee -5. Remove liquidity -6. Remove liquidity one token - - Calculate withdraw one token - - getYD -TODO: test? -*/ - -library Math { - function abs(uint x, uint y) internal pure returns (uint) { - return x >= y ? x - y : y - x; - } -} - -contract StableSwap { - // Number of tokens - uint private constant N = 3; - // Amplification coefficient multiplied by N^(N - 1) - // Higher value makes the curve more flat - // Lower value makes the curve more like constant product AMM - uint private constant A = 1000 * (N ** (N - 1)); - // 0.03% - uint private constant SWAP_FEE = 300; - // Liquidity fee is derived from 2 constraints - // 1. Fee is 0 for adding / removing liquidity that results in a balanced pool - // 2. Swapping in a balanced pool is like adding and then removing liquidity - // from a balanced pool - // swap fee = add liquidity fee + remove liquidity fee - uint private constant LIQUIDITY_FEE = (SWAP_FEE * N) / (4 * (N - 1)); - uint private constant FEE_DENOMINATOR = 1e6; - - address[N] public tokens; - // Normalize each token to 18 decimals - // Example - DAI (18 decimals), USDC (6 decimals), USDT (6 decimals) - uint[N] private multipliers = [1, 1e12, 1e12]; - uint[N] public balances; - - // 1 share = 1e18, 18 decimals - uint private constant DECIMALS = 18; - uint public totalSupply; - mapping(address => uint) public balanceOf; - - function _mint(address _to, uint _amount) private { - balanceOf[_to] += _amount; - totalSupply += _amount; - } - - function _burn(address _from, uint _amount) private { - balanceOf[_from] -= _amount; - totalSupply -= _amount; - } - - // Return precision-adjusted balances, adjusted to 18 decimals - function _xp() private view returns (uint[N] memory xp) { - for (uint i; i < N; ++i) { - xp[i] = balances[i] * multipliers[i]; - } - } - - /** - * @notice Calculate D, sum of balances in a perfectly balanced pool - * If balances of x_0, x_1, ... x_(n-1) then sum(x_i) = D - * @param xp Precision-adjusted balances - * @return D - */ - function _getD(uint[N] memory xp) private pure returns (uint) { - /* - Newton's method to compute D - ----------------------------- - f(D) = ADn^n + D^(n + 1) / (n^n prod(x_i)) - An^n sum(x_i) - D - f'(D) = An^n + (n + 1) D^n / (n^n prod(x_i)) - 1 - - (as + np)D_n - D_(n+1) = ----------------------- - (a - 1)D_n + (n + 1)p - - a = An^n - s = sum(x_i) - p = (D_n)^(n + 1) / (n^n prod(x_i)) - */ - uint a = A * N; // An^n - - uint s; // x_0 + x_1 + ... + x_(n-1) - for (uint i; i < N; ++i) { - s += xp[i]; - } - - // Newton's method - // Initial guess, d <= s - uint d = s; - uint d_prev; - for (uint i; i < 255; ++i) { - // p = D^(n + 1) / (n^n * x_0 * ... * x_(n-1)) - uint p = d; - for (uint j; j < N; ++j) { - p = (p * d) / (N * xp[j]); - } - d_prev = d; - d = ((a * s + N * p) * d) / ((a - 1) * d + (N + 1) * p); - - if (Math.abs(d, d_prev) <= 1) { - return d; - } - } - revert("D didn't converge"); - } - - /** - * @notice Calculate the new balance of token j given the new balance of token i - * @param i Index of token in - * @param j Index of token out - * @param x New balance of token i - * @param xp Current precision-adjusted balances - */ - function _getY( - uint i, - uint j, - uint x, - uint[N] memory xp - ) private pure returns (uint) { - /* - Newton's method to compute y - ----------------------------- - y = x_j - - f(y) = y^2 + y(b - D) - c - - y_n^2 + c - y_(n+1) = -------------- - 2y_n + b - D - - where - s = sum(x_k), k != j - p = prod(x_k), k != j - b = s + D / (An^n) - c = D^(n + 1) / (n^n * p * An^n) - */ - uint a = A * N; - uint d = _getD(xp); - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k == i) { - _x = x; - } else if (k == j) { - continue; - } else { - _x = xp[k]; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - /** - * @notice Calculate the new balance of token i given precision-adjusted - * balances xp and liquidity d - * @dev Equation is calculate y is same as _getY - * @param i Index of token to calculate the new balance - * @param xp Precision-adjusted balances - * @param d Liquidity d - * @return New balance of token i - */ - function _getYD(uint i, uint[N] memory xp, uint d) private pure returns (uint) { - uint a = A * N; - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k != i) { - _x = xp[k]; - } else { - continue; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - // Estimate value of 1 share - // How many tokens is one share worth? - function getVirtualPrice() external view returns (uint) { - uint d = _getD(_xp()); - uint _totalSupply = totalSupply; - if (_totalSupply > 0) { - return (d * 10 ** DECIMALS) / _totalSupply; - } - return 0; - } - - /** - * @notice Swap dx amount of token i for token j - * @param i Index of token in - * @param j Index of token out - * @param dx Token in amount - * @param minDy Minimum token out - */ - function swap(uint i, uint j, uint dx, uint minDy) external returns (uint dy) { - require(i != j, "i = j"); - - IERC20(tokens[i]).transferFrom(msg.sender, address(this), dx); - - // Calculate dy - uint[N] memory xp = _xp(); - uint x = xp[i] + dx * multipliers[i]; - - uint y0 = xp[j]; - uint y1 = _getY(i, j, x, xp); - // y0 must be >= y1, since x has increased - // -1 to round down - dy = (y0 - y1 - 1) / multipliers[j]; - - // Subtract fee from dy - uint fee = (dy * SWAP_FEE) / FEE_DENOMINATOR; - dy -= fee; - require(dy >= minDy, "dy < min"); - - balances[i] += dx; - balances[j] -= dy; - - IERC20(tokens[j]).transfer(msg.sender, dy); - } - - function addLiquidity( - uint[N] calldata amounts, - uint minShares - ) external returns (uint shares) { - // calculate current liquidity d0 - uint _totalSupply = totalSupply; - uint d0; - uint[N] memory old_xs = _xp(); - if (_totalSupply > 0) { - d0 = _getD(old_xs); - } - - // Transfer tokens in - uint[N] memory new_xs; - for (uint i; i < N; ++i) { - uint amount = amounts[i]; - if (amount > 0) { - IERC20(tokens[i]).transferFrom(msg.sender, address(this), amount); - new_xs[i] = old_xs[i] + amount * multipliers[i]; - } else { - new_xs[i] = old_xs[i]; - } - } - - // Calculate new liquidity d1 - uint d1 = _getD(new_xs); - require(d1 > d0, "liquidity didn't increase"); - - // Reccalcuate D accounting for fee on imbalance - uint d2; - if (_totalSupply > 0) { - for (uint i; i < N; ++i) { - // TODO: why old_xs[i] * d1 / d0? why not d1 / N? - uint idealBalance = (old_xs[i] * d1) / d0; - uint diff = Math.abs(new_xs[i], idealBalance); - new_xs[i] -= (LIQUIDITY_FEE * diff) / FEE_DENOMINATOR; - } - - d2 = _getD(new_xs); - } else { - d2 = d1; - bug(); - } - - // Update balances - for (uint i; i < N; ++i) { - balances[i] += amounts[i]; - } - - // Shares to mint = (d2 - d0) / d0 * total supply - // d1 >= d2 >= d0 - if (_totalSupply > 0) { - shares = ((d2 - d0) * _totalSupply) / d0; - } else { - shares = d2; - } - require(shares >= minShares, "shares < min"); - _mint(msg.sender, shares); - } - - function removeLiquidity( - uint shares, - uint[N] calldata minAmountsOut - ) external returns (uint[N] memory amountsOut) { - uint _totalSupply = totalSupply; - - for (uint i; i < N; ++i) { - uint amountOut = (balances[i] * shares) / _totalSupply; - require(amountOut >= minAmountsOut[i], "out < min"); - - balances[i] -= amountOut; - amountsOut[i] = amountOut; - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } - - _burn(msg.sender, shares); - } - - /** - * @notice Calculate amount of token i to receive for shares - * @param shares Shares to burn - * @param i Index of token to withdraw - * @return dy Amount of token i to receive - * fee Fee for withdraw. Fee already included in dy - */ - function _calcWithdrawOneToken( - uint shares, - uint i - ) private view returns (uint dy, uint fee) { - uint _totalSupply = totalSupply; - uint[N] memory xp = _xp(); - - // Calculate d0 and d1 - uint d0 = _getD(xp); - uint d1 = d0 - (d0 * shares) / _totalSupply; - - // Calculate reduction in y if D = d1 - uint y0 = _getYD(i, xp, d1); - // d1 <= d0 so y must be <= xp[i] - uint dy0 = (xp[i] - y0) / multipliers[i]; - - // Calculate imbalance fee, update xp with fees - uint dx; - for (uint j; j < N; ++j) { - if (j == i) { - dx = (xp[j] * d1) / d0 - y0; - } else { - // d1 / d0 <= 1 - dx = xp[j] - (xp[j] * d1) / d0; - } - xp[j] -= (LIQUIDITY_FEE * dx) / FEE_DENOMINATOR; - } - - // Recalculate y with xp including imbalance fees - uint y1 = _getYD(i, xp, d1); - // - 1 to round down - dy = (xp[i] - y1 - 1) / multipliers[i]; - fee = dy0 - dy; - } - - function calcWithdrawOneToken( - uint shares, - uint i - ) external view returns (uint dy, uint fee) { - return _calcWithdrawOneToken(shares, i); - } - - /** - * @notice Withdraw liquidity in token i - * @param shares Shares to burn - * @param i Token to withdraw - * @param minAmountOut Minimum amount of token i that must be withdrawn - */ - function removeLiquidityOneToken( - uint shares, - uint i, - uint minAmountOut - ) external returns (uint amountOut) { - (amountOut, ) = _calcWithdrawOneToken(shares, i); - require(amountOut >= minAmountOut, "out < min"); - - balances[i] -= amountOut; - _burn(msg.sender, shares); - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint amount); - event Approval(address indexed owner, address indexed spender, uint amount); -} diff --git a/tests/evm/sbe@stable-swap-amm-addLiquidity6/StableSwap.sol b/tests/evm/sbe@stable-swap-amm-addLiquidity6/StableSwap.sol deleted file mode 100644 index 0cb3f73f1..000000000 --- a/tests/evm/sbe@stable-swap-amm-addLiquidity6/StableSwap.sol +++ /dev/null @@ -1,444 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8; -import "../../../solidity_utils/lib.sol"; - -/* -Invariant - price of trade and amount of liquidity are determined by this equation - -An^n sum(x_i) + D = ADn^n + D^(n + 1) / (n^n prod(x_i)) - -Topics -0. Newton's method x_(n + 1) = x_n - f(x_n) / f'(x_n) -1. Invariant -2. Swap - - Calculate Y - - Calculate D -3. Get virtual price -4. Add liquidity - - Imbalance fee -5. Remove liquidity -6. Remove liquidity one token - - Calculate withdraw one token - - getYD -TODO: test? -*/ - -library Math { - function abs(uint x, uint y) internal pure returns (uint) { - return x >= y ? x - y : y - x; - } -} - -contract StableSwap { - // Number of tokens - uint private constant N = 3; - // Amplification coefficient multiplied by N^(N - 1) - // Higher value makes the curve more flat - // Lower value makes the curve more like constant product AMM - uint private constant A = 1000 * (N ** (N - 1)); - // 0.03% - uint private constant SWAP_FEE = 300; - // Liquidity fee is derived from 2 constraints - // 1. Fee is 0 for adding / removing liquidity that results in a balanced pool - // 2. Swapping in a balanced pool is like adding and then removing liquidity - // from a balanced pool - // swap fee = add liquidity fee + remove liquidity fee - uint private constant LIQUIDITY_FEE = (SWAP_FEE * N) / (4 * (N - 1)); - uint private constant FEE_DENOMINATOR = 1e6; - - address[N] public tokens; - // Normalize each token to 18 decimals - // Example - DAI (18 decimals), USDC (6 decimals), USDT (6 decimals) - uint[N] private multipliers = [1, 1e12, 1e12]; - uint[N] public balances; - - // 1 share = 1e18, 18 decimals - uint private constant DECIMALS = 18; - uint public totalSupply; - mapping(address => uint) public balanceOf; - - function _mint(address _to, uint _amount) private { - balanceOf[_to] += _amount; - totalSupply += _amount; - } - - function _burn(address _from, uint _amount) private { - balanceOf[_from] -= _amount; - totalSupply -= _amount; - } - - // Return precision-adjusted balances, adjusted to 18 decimals - function _xp() private view returns (uint[N] memory xp) { - for (uint i; i < N; ++i) { - xp[i] = balances[i] * multipliers[i]; - } - } - - /** - * @notice Calculate D, sum of balances in a perfectly balanced pool - * If balances of x_0, x_1, ... x_(n-1) then sum(x_i) = D - * @param xp Precision-adjusted balances - * @return D - */ - function _getD(uint[N] memory xp) private pure returns (uint) { - /* - Newton's method to compute D - ----------------------------- - f(D) = ADn^n + D^(n + 1) / (n^n prod(x_i)) - An^n sum(x_i) - D - f'(D) = An^n + (n + 1) D^n / (n^n prod(x_i)) - 1 - - (as + np)D_n - D_(n+1) = ----------------------- - (a - 1)D_n + (n + 1)p - - a = An^n - s = sum(x_i) - p = (D_n)^(n + 1) / (n^n prod(x_i)) - */ - uint a = A * N; // An^n - - uint s; // x_0 + x_1 + ... + x_(n-1) - for (uint i; i < N; ++i) { - s += xp[i]; - } - - // Newton's method - // Initial guess, d <= s - uint d = s; - uint d_prev; - for (uint i; i < 255; ++i) { - // p = D^(n + 1) / (n^n * x_0 * ... * x_(n-1)) - uint p = d; - for (uint j; j < N; ++j) { - p = (p * d) / (N * xp[j]); - } - d_prev = d; - d = ((a * s + N * p) * d) / ((a - 1) * d + (N + 1) * p); - - if (Math.abs(d, d_prev) <= 1) { - return d; - } - } - revert("D didn't converge"); - } - - /** - * @notice Calculate the new balance of token j given the new balance of token i - * @param i Index of token in - * @param j Index of token out - * @param x New balance of token i - * @param xp Current precision-adjusted balances - */ - function _getY( - uint i, - uint j, - uint x, - uint[N] memory xp - ) private pure returns (uint) { - /* - Newton's method to compute y - ----------------------------- - y = x_j - - f(y) = y^2 + y(b - D) - c - - y_n^2 + c - y_(n+1) = -------------- - 2y_n + b - D - - where - s = sum(x_k), k != j - p = prod(x_k), k != j - b = s + D / (An^n) - c = D^(n + 1) / (n^n * p * An^n) - */ - uint a = A * N; - uint d = _getD(xp); - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k == i) { - _x = x; - } else if (k == j) { - continue; - } else { - _x = xp[k]; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - /** - * @notice Calculate the new balance of token i given precision-adjusted - * balances xp and liquidity d - * @dev Equation is calculate y is same as _getY - * @param i Index of token to calculate the new balance - * @param xp Precision-adjusted balances - * @param d Liquidity d - * @return New balance of token i - */ - function _getYD(uint i, uint[N] memory xp, uint d) private pure returns (uint) { - uint a = A * N; - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k != i) { - _x = xp[k]; - } else { - continue; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - // Estimate value of 1 share - // How many tokens is one share worth? - function getVirtualPrice() external view returns (uint) { - uint d = _getD(_xp()); - uint _totalSupply = totalSupply; - if (_totalSupply > 0) { - return (d * 10 ** DECIMALS) / _totalSupply; - } - return 0; - } - - /** - * @notice Swap dx amount of token i for token j - * @param i Index of token in - * @param j Index of token out - * @param dx Token in amount - * @param minDy Minimum token out - */ - function swap(uint i, uint j, uint dx, uint minDy) external returns (uint dy) { - require(i != j, "i = j"); - - IERC20(tokens[i]).transferFrom(msg.sender, address(this), dx); - - // Calculate dy - uint[N] memory xp = _xp(); - uint x = xp[i] + dx * multipliers[i]; - - uint y0 = xp[j]; - uint y1 = _getY(i, j, x, xp); - // y0 must be >= y1, since x has increased - // -1 to round down - dy = (y0 - y1 - 1) / multipliers[j]; - - // Subtract fee from dy - uint fee = (dy * SWAP_FEE) / FEE_DENOMINATOR; - dy -= fee; - require(dy >= minDy, "dy < min"); - - balances[i] += dx; - balances[j] -= dy; - - IERC20(tokens[j]).transfer(msg.sender, dy); - } - - function addLiquidity( - uint[N] calldata amounts, - uint minShares - ) external returns (uint shares) { - // calculate current liquidity d0 - uint _totalSupply = totalSupply; - uint d0; - uint[N] memory old_xs = _xp(); - if (_totalSupply > 0) { - d0 = _getD(old_xs); - } - - // Transfer tokens in - uint[N] memory new_xs; - for (uint i; i < N; ++i) { - uint amount = amounts[i]; - if (amount > 0) { - IERC20(tokens[i]).transferFrom(msg.sender, address(this), amount); - new_xs[i] = old_xs[i] + amount * multipliers[i]; - } else { - new_xs[i] = old_xs[i]; - } - } - - // Calculate new liquidity d1 - uint d1 = _getD(new_xs); - require(d1 > d0, "liquidity didn't increase"); - - // Reccalcuate D accounting for fee on imbalance - uint d2; - if (_totalSupply > 0) { - for (uint i; i < N; ++i) { - // TODO: why old_xs[i] * d1 / d0? why not d1 / N? - uint idealBalance = (old_xs[i] * d1) / d0; - uint diff = Math.abs(new_xs[i], idealBalance); - new_xs[i] -= (LIQUIDITY_FEE * diff) / FEE_DENOMINATOR; - } - - d2 = _getD(new_xs); - } else { - d2 = d1; - } - - // Update balances - for (uint i; i < N; ++i) { - balances[i] += amounts[i]; - bug(); - } - - // Shares to mint = (d2 - d0) / d0 * total supply - // d1 >= d2 >= d0 - if (_totalSupply > 0) { - shares = ((d2 - d0) * _totalSupply) / d0; - } else { - shares = d2; - } - require(shares >= minShares, "shares < min"); - _mint(msg.sender, shares); - } - - function removeLiquidity( - uint shares, - uint[N] calldata minAmountsOut - ) external returns (uint[N] memory amountsOut) { - uint _totalSupply = totalSupply; - - for (uint i; i < N; ++i) { - uint amountOut = (balances[i] * shares) / _totalSupply; - require(amountOut >= minAmountsOut[i], "out < min"); - - balances[i] -= amountOut; - amountsOut[i] = amountOut; - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } - - _burn(msg.sender, shares); - } - - /** - * @notice Calculate amount of token i to receive for shares - * @param shares Shares to burn - * @param i Index of token to withdraw - * @return dy Amount of token i to receive - * fee Fee for withdraw. Fee already included in dy - */ - function _calcWithdrawOneToken( - uint shares, - uint i - ) private view returns (uint dy, uint fee) { - uint _totalSupply = totalSupply; - uint[N] memory xp = _xp(); - - // Calculate d0 and d1 - uint d0 = _getD(xp); - uint d1 = d0 - (d0 * shares) / _totalSupply; - - // Calculate reduction in y if D = d1 - uint y0 = _getYD(i, xp, d1); - // d1 <= d0 so y must be <= xp[i] - uint dy0 = (xp[i] - y0) / multipliers[i]; - - // Calculate imbalance fee, update xp with fees - uint dx; - for (uint j; j < N; ++j) { - if (j == i) { - dx = (xp[j] * d1) / d0 - y0; - } else { - // d1 / d0 <= 1 - dx = xp[j] - (xp[j] * d1) / d0; - } - xp[j] -= (LIQUIDITY_FEE * dx) / FEE_DENOMINATOR; - } - - // Recalculate y with xp including imbalance fees - uint y1 = _getYD(i, xp, d1); - // - 1 to round down - dy = (xp[i] - y1 - 1) / multipliers[i]; - fee = dy0 - dy; - } - - function calcWithdrawOneToken( - uint shares, - uint i - ) external view returns (uint dy, uint fee) { - return _calcWithdrawOneToken(shares, i); - } - - /** - * @notice Withdraw liquidity in token i - * @param shares Shares to burn - * @param i Token to withdraw - * @param minAmountOut Minimum amount of token i that must be withdrawn - */ - function removeLiquidityOneToken( - uint shares, - uint i, - uint minAmountOut - ) external returns (uint amountOut) { - (amountOut, ) = _calcWithdrawOneToken(shares, i); - require(amountOut >= minAmountOut, "out < min"); - - balances[i] -= amountOut; - _burn(msg.sender, shares); - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint amount); - event Approval(address indexed owner, address indexed spender, uint amount); -} diff --git a/tests/evm/sbe@stable-swap-amm-addLiquidity7/StableSwap.sol b/tests/evm/sbe@stable-swap-amm-addLiquidity7/StableSwap.sol deleted file mode 100644 index 59d635352..000000000 --- a/tests/evm/sbe@stable-swap-amm-addLiquidity7/StableSwap.sol +++ /dev/null @@ -1,444 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8; -import "../../../solidity_utils/lib.sol"; - -/* -Invariant - price of trade and amount of liquidity are determined by this equation - -An^n sum(x_i) + D = ADn^n + D^(n + 1) / (n^n prod(x_i)) - -Topics -0. Newton's method x_(n + 1) = x_n - f(x_n) / f'(x_n) -1. Invariant -2. Swap - - Calculate Y - - Calculate D -3. Get virtual price -4. Add liquidity - - Imbalance fee -5. Remove liquidity -6. Remove liquidity one token - - Calculate withdraw one token - - getYD -TODO: test? -*/ - -library Math { - function abs(uint x, uint y) internal pure returns (uint) { - return x >= y ? x - y : y - x; - } -} - -contract StableSwap { - // Number of tokens - uint private constant N = 3; - // Amplification coefficient multiplied by N^(N - 1) - // Higher value makes the curve more flat - // Lower value makes the curve more like constant product AMM - uint private constant A = 1000 * (N ** (N - 1)); - // 0.03% - uint private constant SWAP_FEE = 300; - // Liquidity fee is derived from 2 constraints - // 1. Fee is 0 for adding / removing liquidity that results in a balanced pool - // 2. Swapping in a balanced pool is like adding and then removing liquidity - // from a balanced pool - // swap fee = add liquidity fee + remove liquidity fee - uint private constant LIQUIDITY_FEE = (SWAP_FEE * N) / (4 * (N - 1)); - uint private constant FEE_DENOMINATOR = 1e6; - - address[N] public tokens; - // Normalize each token to 18 decimals - // Example - DAI (18 decimals), USDC (6 decimals), USDT (6 decimals) - uint[N] private multipliers = [1, 1e12, 1e12]; - uint[N] public balances; - - // 1 share = 1e18, 18 decimals - uint private constant DECIMALS = 18; - uint public totalSupply; - mapping(address => uint) public balanceOf; - - function _mint(address _to, uint _amount) private { - balanceOf[_to] += _amount; - totalSupply += _amount; - } - - function _burn(address _from, uint _amount) private { - balanceOf[_from] -= _amount; - totalSupply -= _amount; - } - - // Return precision-adjusted balances, adjusted to 18 decimals - function _xp() private view returns (uint[N] memory xp) { - for (uint i; i < N; ++i) { - xp[i] = balances[i] * multipliers[i]; - } - } - - /** - * @notice Calculate D, sum of balances in a perfectly balanced pool - * If balances of x_0, x_1, ... x_(n-1) then sum(x_i) = D - * @param xp Precision-adjusted balances - * @return D - */ - function _getD(uint[N] memory xp) private pure returns (uint) { - /* - Newton's method to compute D - ----------------------------- - f(D) = ADn^n + D^(n + 1) / (n^n prod(x_i)) - An^n sum(x_i) - D - f'(D) = An^n + (n + 1) D^n / (n^n prod(x_i)) - 1 - - (as + np)D_n - D_(n+1) = ----------------------- - (a - 1)D_n + (n + 1)p - - a = An^n - s = sum(x_i) - p = (D_n)^(n + 1) / (n^n prod(x_i)) - */ - uint a = A * N; // An^n - - uint s; // x_0 + x_1 + ... + x_(n-1) - for (uint i; i < N; ++i) { - s += xp[i]; - } - - // Newton's method - // Initial guess, d <= s - uint d = s; - uint d_prev; - for (uint i; i < 255; ++i) { - // p = D^(n + 1) / (n^n * x_0 * ... * x_(n-1)) - uint p = d; - for (uint j; j < N; ++j) { - p = (p * d) / (N * xp[j]); - } - d_prev = d; - d = ((a * s + N * p) * d) / ((a - 1) * d + (N + 1) * p); - - if (Math.abs(d, d_prev) <= 1) { - return d; - } - } - revert("D didn't converge"); - } - - /** - * @notice Calculate the new balance of token j given the new balance of token i - * @param i Index of token in - * @param j Index of token out - * @param x New balance of token i - * @param xp Current precision-adjusted balances - */ - function _getY( - uint i, - uint j, - uint x, - uint[N] memory xp - ) private pure returns (uint) { - /* - Newton's method to compute y - ----------------------------- - y = x_j - - f(y) = y^2 + y(b - D) - c - - y_n^2 + c - y_(n+1) = -------------- - 2y_n + b - D - - where - s = sum(x_k), k != j - p = prod(x_k), k != j - b = s + D / (An^n) - c = D^(n + 1) / (n^n * p * An^n) - */ - uint a = A * N; - uint d = _getD(xp); - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k == i) { - _x = x; - } else if (k == j) { - continue; - } else { - _x = xp[k]; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - /** - * @notice Calculate the new balance of token i given precision-adjusted - * balances xp and liquidity d - * @dev Equation is calculate y is same as _getY - * @param i Index of token to calculate the new balance - * @param xp Precision-adjusted balances - * @param d Liquidity d - * @return New balance of token i - */ - function _getYD(uint i, uint[N] memory xp, uint d) private pure returns (uint) { - uint a = A * N; - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k != i) { - _x = xp[k]; - } else { - continue; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - // Estimate value of 1 share - // How many tokens is one share worth? - function getVirtualPrice() external view returns (uint) { - uint d = _getD(_xp()); - uint _totalSupply = totalSupply; - if (_totalSupply > 0) { - return (d * 10 ** DECIMALS) / _totalSupply; - } - return 0; - } - - /** - * @notice Swap dx amount of token i for token j - * @param i Index of token in - * @param j Index of token out - * @param dx Token in amount - * @param minDy Minimum token out - */ - function swap(uint i, uint j, uint dx, uint minDy) external returns (uint dy) { - require(i != j, "i = j"); - - IERC20(tokens[i]).transferFrom(msg.sender, address(this), dx); - - // Calculate dy - uint[N] memory xp = _xp(); - uint x = xp[i] + dx * multipliers[i]; - - uint y0 = xp[j]; - uint y1 = _getY(i, j, x, xp); - // y0 must be >= y1, since x has increased - // -1 to round down - dy = (y0 - y1 - 1) / multipliers[j]; - - // Subtract fee from dy - uint fee = (dy * SWAP_FEE) / FEE_DENOMINATOR; - dy -= fee; - require(dy >= minDy, "dy < min"); - - balances[i] += dx; - balances[j] -= dy; - - IERC20(tokens[j]).transfer(msg.sender, dy); - } - - function addLiquidity( - uint[N] calldata amounts, - uint minShares - ) external returns (uint shares) { - // calculate current liquidity d0 - uint _totalSupply = totalSupply; - uint d0; - uint[N] memory old_xs = _xp(); - if (_totalSupply > 0) { - d0 = _getD(old_xs); - } - - // Transfer tokens in - uint[N] memory new_xs; - for (uint i; i < N; ++i) { - uint amount = amounts[i]; - if (amount > 0) { - IERC20(tokens[i]).transferFrom(msg.sender, address(this), amount); - new_xs[i] = old_xs[i] + amount * multipliers[i]; - } else { - new_xs[i] = old_xs[i]; - } - } - - // Calculate new liquidity d1 - uint d1 = _getD(new_xs); - require(d1 > d0, "liquidity didn't increase"); - - // Reccalcuate D accounting for fee on imbalance - uint d2; - if (_totalSupply > 0) { - for (uint i; i < N; ++i) { - // TODO: why old_xs[i] * d1 / d0? why not d1 / N? - uint idealBalance = (old_xs[i] * d1) / d0; - uint diff = Math.abs(new_xs[i], idealBalance); - new_xs[i] -= (LIQUIDITY_FEE * diff) / FEE_DENOMINATOR; - } - - d2 = _getD(new_xs); - } else { - d2 = d1; - } - - // Update balances - for (uint i; i < N; ++i) { - balances[i] += amounts[i]; - } - - // Shares to mint = (d2 - d0) / d0 * total supply - // d1 >= d2 >= d0 - if (_totalSupply > 0) { - shares = ((d2 - d0) * _totalSupply) / d0; - bug(); - } else { - shares = d2; - } - require(shares >= minShares, "shares < min"); - _mint(msg.sender, shares); - } - - function removeLiquidity( - uint shares, - uint[N] calldata minAmountsOut - ) external returns (uint[N] memory amountsOut) { - uint _totalSupply = totalSupply; - - for (uint i; i < N; ++i) { - uint amountOut = (balances[i] * shares) / _totalSupply; - require(amountOut >= minAmountsOut[i], "out < min"); - - balances[i] -= amountOut; - amountsOut[i] = amountOut; - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } - - _burn(msg.sender, shares); - } - - /** - * @notice Calculate amount of token i to receive for shares - * @param shares Shares to burn - * @param i Index of token to withdraw - * @return dy Amount of token i to receive - * fee Fee for withdraw. Fee already included in dy - */ - function _calcWithdrawOneToken( - uint shares, - uint i - ) private view returns (uint dy, uint fee) { - uint _totalSupply = totalSupply; - uint[N] memory xp = _xp(); - - // Calculate d0 and d1 - uint d0 = _getD(xp); - uint d1 = d0 - (d0 * shares) / _totalSupply; - - // Calculate reduction in y if D = d1 - uint y0 = _getYD(i, xp, d1); - // d1 <= d0 so y must be <= xp[i] - uint dy0 = (xp[i] - y0) / multipliers[i]; - - // Calculate imbalance fee, update xp with fees - uint dx; - for (uint j; j < N; ++j) { - if (j == i) { - dx = (xp[j] * d1) / d0 - y0; - } else { - // d1 / d0 <= 1 - dx = xp[j] - (xp[j] * d1) / d0; - } - xp[j] -= (LIQUIDITY_FEE * dx) / FEE_DENOMINATOR; - } - - // Recalculate y with xp including imbalance fees - uint y1 = _getYD(i, xp, d1); - // - 1 to round down - dy = (xp[i] - y1 - 1) / multipliers[i]; - fee = dy0 - dy; - } - - function calcWithdrawOneToken( - uint shares, - uint i - ) external view returns (uint dy, uint fee) { - return _calcWithdrawOneToken(shares, i); - } - - /** - * @notice Withdraw liquidity in token i - * @param shares Shares to burn - * @param i Token to withdraw - * @param minAmountOut Minimum amount of token i that must be withdrawn - */ - function removeLiquidityOneToken( - uint shares, - uint i, - uint minAmountOut - ) external returns (uint amountOut) { - (amountOut, ) = _calcWithdrawOneToken(shares, i); - require(amountOut >= minAmountOut, "out < min"); - - balances[i] -= amountOut; - _burn(msg.sender, shares); - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint amount); - event Approval(address indexed owner, address indexed spender, uint amount); -} diff --git a/tests/evm/sbe@stable-swap-amm-addLiquidity8/StableSwap.sol b/tests/evm/sbe@stable-swap-amm-addLiquidity8/StableSwap.sol deleted file mode 100644 index 33b276f43..000000000 --- a/tests/evm/sbe@stable-swap-amm-addLiquidity8/StableSwap.sol +++ /dev/null @@ -1,444 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8; -import "../../../solidity_utils/lib.sol"; - -/* -Invariant - price of trade and amount of liquidity are determined by this equation - -An^n sum(x_i) + D = ADn^n + D^(n + 1) / (n^n prod(x_i)) - -Topics -0. Newton's method x_(n + 1) = x_n - f(x_n) / f'(x_n) -1. Invariant -2. Swap - - Calculate Y - - Calculate D -3. Get virtual price -4. Add liquidity - - Imbalance fee -5. Remove liquidity -6. Remove liquidity one token - - Calculate withdraw one token - - getYD -TODO: test? -*/ - -library Math { - function abs(uint x, uint y) internal pure returns (uint) { - return x >= y ? x - y : y - x; - } -} - -contract StableSwap { - // Number of tokens - uint private constant N = 3; - // Amplification coefficient multiplied by N^(N - 1) - // Higher value makes the curve more flat - // Lower value makes the curve more like constant product AMM - uint private constant A = 1000 * (N ** (N - 1)); - // 0.03% - uint private constant SWAP_FEE = 300; - // Liquidity fee is derived from 2 constraints - // 1. Fee is 0 for adding / removing liquidity that results in a balanced pool - // 2. Swapping in a balanced pool is like adding and then removing liquidity - // from a balanced pool - // swap fee = add liquidity fee + remove liquidity fee - uint private constant LIQUIDITY_FEE = (SWAP_FEE * N) / (4 * (N - 1)); - uint private constant FEE_DENOMINATOR = 1e6; - - address[N] public tokens; - // Normalize each token to 18 decimals - // Example - DAI (18 decimals), USDC (6 decimals), USDT (6 decimals) - uint[N] private multipliers = [1, 1e12, 1e12]; - uint[N] public balances; - - // 1 share = 1e18, 18 decimals - uint private constant DECIMALS = 18; - uint public totalSupply; - mapping(address => uint) public balanceOf; - - function _mint(address _to, uint _amount) private { - balanceOf[_to] += _amount; - totalSupply += _amount; - } - - function _burn(address _from, uint _amount) private { - balanceOf[_from] -= _amount; - totalSupply -= _amount; - } - - // Return precision-adjusted balances, adjusted to 18 decimals - function _xp() private view returns (uint[N] memory xp) { - for (uint i; i < N; ++i) { - xp[i] = balances[i] * multipliers[i]; - } - } - - /** - * @notice Calculate D, sum of balances in a perfectly balanced pool - * If balances of x_0, x_1, ... x_(n-1) then sum(x_i) = D - * @param xp Precision-adjusted balances - * @return D - */ - function _getD(uint[N] memory xp) private pure returns (uint) { - /* - Newton's method to compute D - ----------------------------- - f(D) = ADn^n + D^(n + 1) / (n^n prod(x_i)) - An^n sum(x_i) - D - f'(D) = An^n + (n + 1) D^n / (n^n prod(x_i)) - 1 - - (as + np)D_n - D_(n+1) = ----------------------- - (a - 1)D_n + (n + 1)p - - a = An^n - s = sum(x_i) - p = (D_n)^(n + 1) / (n^n prod(x_i)) - */ - uint a = A * N; // An^n - - uint s; // x_0 + x_1 + ... + x_(n-1) - for (uint i; i < N; ++i) { - s += xp[i]; - } - - // Newton's method - // Initial guess, d <= s - uint d = s; - uint d_prev; - for (uint i; i < 255; ++i) { - // p = D^(n + 1) / (n^n * x_0 * ... * x_(n-1)) - uint p = d; - for (uint j; j < N; ++j) { - p = (p * d) / (N * xp[j]); - } - d_prev = d; - d = ((a * s + N * p) * d) / ((a - 1) * d + (N + 1) * p); - - if (Math.abs(d, d_prev) <= 1) { - return d; - } - } - revert("D didn't converge"); - } - - /** - * @notice Calculate the new balance of token j given the new balance of token i - * @param i Index of token in - * @param j Index of token out - * @param x New balance of token i - * @param xp Current precision-adjusted balances - */ - function _getY( - uint i, - uint j, - uint x, - uint[N] memory xp - ) private pure returns (uint) { - /* - Newton's method to compute y - ----------------------------- - y = x_j - - f(y) = y^2 + y(b - D) - c - - y_n^2 + c - y_(n+1) = -------------- - 2y_n + b - D - - where - s = sum(x_k), k != j - p = prod(x_k), k != j - b = s + D / (An^n) - c = D^(n + 1) / (n^n * p * An^n) - */ - uint a = A * N; - uint d = _getD(xp); - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k == i) { - _x = x; - } else if (k == j) { - continue; - } else { - _x = xp[k]; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - /** - * @notice Calculate the new balance of token i given precision-adjusted - * balances xp and liquidity d - * @dev Equation is calculate y is same as _getY - * @param i Index of token to calculate the new balance - * @param xp Precision-adjusted balances - * @param d Liquidity d - * @return New balance of token i - */ - function _getYD(uint i, uint[N] memory xp, uint d) private pure returns (uint) { - uint a = A * N; - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k != i) { - _x = xp[k]; - } else { - continue; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - // Estimate value of 1 share - // How many tokens is one share worth? - function getVirtualPrice() external view returns (uint) { - uint d = _getD(_xp()); - uint _totalSupply = totalSupply; - if (_totalSupply > 0) { - return (d * 10 ** DECIMALS) / _totalSupply; - } - return 0; - } - - /** - * @notice Swap dx amount of token i for token j - * @param i Index of token in - * @param j Index of token out - * @param dx Token in amount - * @param minDy Minimum token out - */ - function swap(uint i, uint j, uint dx, uint minDy) external returns (uint dy) { - require(i != j, "i = j"); - - IERC20(tokens[i]).transferFrom(msg.sender, address(this), dx); - - // Calculate dy - uint[N] memory xp = _xp(); - uint x = xp[i] + dx * multipliers[i]; - - uint y0 = xp[j]; - uint y1 = _getY(i, j, x, xp); - // y0 must be >= y1, since x has increased - // -1 to round down - dy = (y0 - y1 - 1) / multipliers[j]; - - // Subtract fee from dy - uint fee = (dy * SWAP_FEE) / FEE_DENOMINATOR; - dy -= fee; - require(dy >= minDy, "dy < min"); - - balances[i] += dx; - balances[j] -= dy; - - IERC20(tokens[j]).transfer(msg.sender, dy); - } - - function addLiquidity( - uint[N] calldata amounts, - uint minShares - ) external returns (uint shares) { - // calculate current liquidity d0 - uint _totalSupply = totalSupply; - uint d0; - uint[N] memory old_xs = _xp(); - if (_totalSupply > 0) { - d0 = _getD(old_xs); - } - - // Transfer tokens in - uint[N] memory new_xs; - for (uint i; i < N; ++i) { - uint amount = amounts[i]; - if (amount > 0) { - IERC20(tokens[i]).transferFrom(msg.sender, address(this), amount); - new_xs[i] = old_xs[i] + amount * multipliers[i]; - } else { - new_xs[i] = old_xs[i]; - } - } - - // Calculate new liquidity d1 - uint d1 = _getD(new_xs); - require(d1 > d0, "liquidity didn't increase"); - - // Reccalcuate D accounting for fee on imbalance - uint d2; - if (_totalSupply > 0) { - for (uint i; i < N; ++i) { - // TODO: why old_xs[i] * d1 / d0? why not d1 / N? - uint idealBalance = (old_xs[i] * d1) / d0; - uint diff = Math.abs(new_xs[i], idealBalance); - new_xs[i] -= (LIQUIDITY_FEE * diff) / FEE_DENOMINATOR; - } - - d2 = _getD(new_xs); - } else { - d2 = d1; - } - - // Update balances - for (uint i; i < N; ++i) { - balances[i] += amounts[i]; - } - - // Shares to mint = (d2 - d0) / d0 * total supply - // d1 >= d2 >= d0 - if (_totalSupply > 0) { - shares = ((d2 - d0) * _totalSupply) / d0; - } else { - shares = d2; - bug(); - } - require(shares >= minShares, "shares < min"); - _mint(msg.sender, shares); - } - - function removeLiquidity( - uint shares, - uint[N] calldata minAmountsOut - ) external returns (uint[N] memory amountsOut) { - uint _totalSupply = totalSupply; - - for (uint i; i < N; ++i) { - uint amountOut = (balances[i] * shares) / _totalSupply; - require(amountOut >= minAmountsOut[i], "out < min"); - - balances[i] -= amountOut; - amountsOut[i] = amountOut; - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } - - _burn(msg.sender, shares); - } - - /** - * @notice Calculate amount of token i to receive for shares - * @param shares Shares to burn - * @param i Index of token to withdraw - * @return dy Amount of token i to receive - * fee Fee for withdraw. Fee already included in dy - */ - function _calcWithdrawOneToken( - uint shares, - uint i - ) private view returns (uint dy, uint fee) { - uint _totalSupply = totalSupply; - uint[N] memory xp = _xp(); - - // Calculate d0 and d1 - uint d0 = _getD(xp); - uint d1 = d0 - (d0 * shares) / _totalSupply; - - // Calculate reduction in y if D = d1 - uint y0 = _getYD(i, xp, d1); - // d1 <= d0 so y must be <= xp[i] - uint dy0 = (xp[i] - y0) / multipliers[i]; - - // Calculate imbalance fee, update xp with fees - uint dx; - for (uint j; j < N; ++j) { - if (j == i) { - dx = (xp[j] * d1) / d0 - y0; - } else { - // d1 / d0 <= 1 - dx = xp[j] - (xp[j] * d1) / d0; - } - xp[j] -= (LIQUIDITY_FEE * dx) / FEE_DENOMINATOR; - } - - // Recalculate y with xp including imbalance fees - uint y1 = _getYD(i, xp, d1); - // - 1 to round down - dy = (xp[i] - y1 - 1) / multipliers[i]; - fee = dy0 - dy; - } - - function calcWithdrawOneToken( - uint shares, - uint i - ) external view returns (uint dy, uint fee) { - return _calcWithdrawOneToken(shares, i); - } - - /** - * @notice Withdraw liquidity in token i - * @param shares Shares to burn - * @param i Token to withdraw - * @param minAmountOut Minimum amount of token i that must be withdrawn - */ - function removeLiquidityOneToken( - uint shares, - uint i, - uint minAmountOut - ) external returns (uint amountOut) { - (amountOut, ) = _calcWithdrawOneToken(shares, i); - require(amountOut >= minAmountOut, "out < min"); - - balances[i] -= amountOut; - _burn(msg.sender, shares); - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint amount); - event Approval(address indexed owner, address indexed spender, uint amount); -} diff --git a/tests/evm/sbe@stable-swap-amm-addLiquidity9/StableSwap.sol b/tests/evm/sbe@stable-swap-amm-addLiquidity9/StableSwap.sol deleted file mode 100644 index 5847197cd..000000000 --- a/tests/evm/sbe@stable-swap-amm-addLiquidity9/StableSwap.sol +++ /dev/null @@ -1,444 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8; -import "../../../solidity_utils/lib.sol"; - -/* -Invariant - price of trade and amount of liquidity are determined by this equation - -An^n sum(x_i) + D = ADn^n + D^(n + 1) / (n^n prod(x_i)) - -Topics -0. Newton's method x_(n + 1) = x_n - f(x_n) / f'(x_n) -1. Invariant -2. Swap - - Calculate Y - - Calculate D -3. Get virtual price -4. Add liquidity - - Imbalance fee -5. Remove liquidity -6. Remove liquidity one token - - Calculate withdraw one token - - getYD -TODO: test? -*/ - -library Math { - function abs(uint x, uint y) internal pure returns (uint) { - return x >= y ? x - y : y - x; - } -} - -contract StableSwap { - // Number of tokens - uint private constant N = 3; - // Amplification coefficient multiplied by N^(N - 1) - // Higher value makes the curve more flat - // Lower value makes the curve more like constant product AMM - uint private constant A = 1000 * (N ** (N - 1)); - // 0.03% - uint private constant SWAP_FEE = 300; - // Liquidity fee is derived from 2 constraints - // 1. Fee is 0 for adding / removing liquidity that results in a balanced pool - // 2. Swapping in a balanced pool is like adding and then removing liquidity - // from a balanced pool - // swap fee = add liquidity fee + remove liquidity fee - uint private constant LIQUIDITY_FEE = (SWAP_FEE * N) / (4 * (N - 1)); - uint private constant FEE_DENOMINATOR = 1e6; - - address[N] public tokens; - // Normalize each token to 18 decimals - // Example - DAI (18 decimals), USDC (6 decimals), USDT (6 decimals) - uint[N] private multipliers = [1, 1e12, 1e12]; - uint[N] public balances; - - // 1 share = 1e18, 18 decimals - uint private constant DECIMALS = 18; - uint public totalSupply; - mapping(address => uint) public balanceOf; - - function _mint(address _to, uint _amount) private { - balanceOf[_to] += _amount; - totalSupply += _amount; - } - - function _burn(address _from, uint _amount) private { - balanceOf[_from] -= _amount; - totalSupply -= _amount; - } - - // Return precision-adjusted balances, adjusted to 18 decimals - function _xp() private view returns (uint[N] memory xp) { - for (uint i; i < N; ++i) { - xp[i] = balances[i] * multipliers[i]; - } - } - - /** - * @notice Calculate D, sum of balances in a perfectly balanced pool - * If balances of x_0, x_1, ... x_(n-1) then sum(x_i) = D - * @param xp Precision-adjusted balances - * @return D - */ - function _getD(uint[N] memory xp) private pure returns (uint) { - /* - Newton's method to compute D - ----------------------------- - f(D) = ADn^n + D^(n + 1) / (n^n prod(x_i)) - An^n sum(x_i) - D - f'(D) = An^n + (n + 1) D^n / (n^n prod(x_i)) - 1 - - (as + np)D_n - D_(n+1) = ----------------------- - (a - 1)D_n + (n + 1)p - - a = An^n - s = sum(x_i) - p = (D_n)^(n + 1) / (n^n prod(x_i)) - */ - uint a = A * N; // An^n - - uint s; // x_0 + x_1 + ... + x_(n-1) - for (uint i; i < N; ++i) { - s += xp[i]; - } - - // Newton's method - // Initial guess, d <= s - uint d = s; - uint d_prev; - for (uint i; i < 255; ++i) { - // p = D^(n + 1) / (n^n * x_0 * ... * x_(n-1)) - uint p = d; - for (uint j; j < N; ++j) { - p = (p * d) / (N * xp[j]); - } - d_prev = d; - d = ((a * s + N * p) * d) / ((a - 1) * d + (N + 1) * p); - - if (Math.abs(d, d_prev) <= 1) { - return d; - } - } - revert("D didn't converge"); - } - - /** - * @notice Calculate the new balance of token j given the new balance of token i - * @param i Index of token in - * @param j Index of token out - * @param x New balance of token i - * @param xp Current precision-adjusted balances - */ - function _getY( - uint i, - uint j, - uint x, - uint[N] memory xp - ) private pure returns (uint) { - /* - Newton's method to compute y - ----------------------------- - y = x_j - - f(y) = y^2 + y(b - D) - c - - y_n^2 + c - y_(n+1) = -------------- - 2y_n + b - D - - where - s = sum(x_k), k != j - p = prod(x_k), k != j - b = s + D / (An^n) - c = D^(n + 1) / (n^n * p * An^n) - */ - uint a = A * N; - uint d = _getD(xp); - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k == i) { - _x = x; - } else if (k == j) { - continue; - } else { - _x = xp[k]; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - /** - * @notice Calculate the new balance of token i given precision-adjusted - * balances xp and liquidity d - * @dev Equation is calculate y is same as _getY - * @param i Index of token to calculate the new balance - * @param xp Precision-adjusted balances - * @param d Liquidity d - * @return New balance of token i - */ - function _getYD(uint i, uint[N] memory xp, uint d) private pure returns (uint) { - uint a = A * N; - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k != i) { - _x = xp[k]; - } else { - continue; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - // Estimate value of 1 share - // How many tokens is one share worth? - function getVirtualPrice() external view returns (uint) { - uint d = _getD(_xp()); - uint _totalSupply = totalSupply; - if (_totalSupply > 0) { - return (d * 10 ** DECIMALS) / _totalSupply; - } - return 0; - } - - /** - * @notice Swap dx amount of token i for token j - * @param i Index of token in - * @param j Index of token out - * @param dx Token in amount - * @param minDy Minimum token out - */ - function swap(uint i, uint j, uint dx, uint minDy) external returns (uint dy) { - require(i != j, "i = j"); - - IERC20(tokens[i]).transferFrom(msg.sender, address(this), dx); - - // Calculate dy - uint[N] memory xp = _xp(); - uint x = xp[i] + dx * multipliers[i]; - - uint y0 = xp[j]; - uint y1 = _getY(i, j, x, xp); - // y0 must be >= y1, since x has increased - // -1 to round down - dy = (y0 - y1 - 1) / multipliers[j]; - - // Subtract fee from dy - uint fee = (dy * SWAP_FEE) / FEE_DENOMINATOR; - dy -= fee; - require(dy >= minDy, "dy < min"); - - balances[i] += dx; - balances[j] -= dy; - - IERC20(tokens[j]).transfer(msg.sender, dy); - } - - function addLiquidity( - uint[N] calldata amounts, - uint minShares - ) external returns (uint shares) { - // calculate current liquidity d0 - uint _totalSupply = totalSupply; - uint d0; - uint[N] memory old_xs = _xp(); - if (_totalSupply > 0) { - d0 = _getD(old_xs); - } - - // Transfer tokens in - uint[N] memory new_xs; - for (uint i; i < N; ++i) { - uint amount = amounts[i]; - if (amount > 0) { - IERC20(tokens[i]).transferFrom(msg.sender, address(this), amount); - new_xs[i] = old_xs[i] + amount * multipliers[i]; - } else { - new_xs[i] = old_xs[i]; - } - } - - // Calculate new liquidity d1 - uint d1 = _getD(new_xs); - require(d1 > d0, "liquidity didn't increase"); - - // Reccalcuate D accounting for fee on imbalance - uint d2; - if (_totalSupply > 0) { - for (uint i; i < N; ++i) { - // TODO: why old_xs[i] * d1 / d0? why not d1 / N? - uint idealBalance = (old_xs[i] * d1) / d0; - uint diff = Math.abs(new_xs[i], idealBalance); - new_xs[i] -= (LIQUIDITY_FEE * diff) / FEE_DENOMINATOR; - } - - d2 = _getD(new_xs); - } else { - d2 = d1; - } - - // Update balances - for (uint i; i < N; ++i) { - balances[i] += amounts[i]; - } - - // Shares to mint = (d2 - d0) / d0 * total supply - // d1 >= d2 >= d0 - if (_totalSupply > 0) { - shares = ((d2 - d0) * _totalSupply) / d0; - } else { - shares = d2; - } - require(shares >= minShares, "shares < min"); - _mint(msg.sender, shares); - bug(); - } - - function removeLiquidity( - uint shares, - uint[N] calldata minAmountsOut - ) external returns (uint[N] memory amountsOut) { - uint _totalSupply = totalSupply; - - for (uint i; i < N; ++i) { - uint amountOut = (balances[i] * shares) / _totalSupply; - require(amountOut >= minAmountsOut[i], "out < min"); - - balances[i] -= amountOut; - amountsOut[i] = amountOut; - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } - - _burn(msg.sender, shares); - } - - /** - * @notice Calculate amount of token i to receive for shares - * @param shares Shares to burn - * @param i Index of token to withdraw - * @return dy Amount of token i to receive - * fee Fee for withdraw. Fee already included in dy - */ - function _calcWithdrawOneToken( - uint shares, - uint i - ) private view returns (uint dy, uint fee) { - uint _totalSupply = totalSupply; - uint[N] memory xp = _xp(); - - // Calculate d0 and d1 - uint d0 = _getD(xp); - uint d1 = d0 - (d0 * shares) / _totalSupply; - - // Calculate reduction in y if D = d1 - uint y0 = _getYD(i, xp, d1); - // d1 <= d0 so y must be <= xp[i] - uint dy0 = (xp[i] - y0) / multipliers[i]; - - // Calculate imbalance fee, update xp with fees - uint dx; - for (uint j; j < N; ++j) { - if (j == i) { - dx = (xp[j] * d1) / d0 - y0; - } else { - // d1 / d0 <= 1 - dx = xp[j] - (xp[j] * d1) / d0; - } - xp[j] -= (LIQUIDITY_FEE * dx) / FEE_DENOMINATOR; - } - - // Recalculate y with xp including imbalance fees - uint y1 = _getYD(i, xp, d1); - // - 1 to round down - dy = (xp[i] - y1 - 1) / multipliers[i]; - fee = dy0 - dy; - } - - function calcWithdrawOneToken( - uint shares, - uint i - ) external view returns (uint dy, uint fee) { - return _calcWithdrawOneToken(shares, i); - } - - /** - * @notice Withdraw liquidity in token i - * @param shares Shares to burn - * @param i Token to withdraw - * @param minAmountOut Minimum amount of token i that must be withdrawn - */ - function removeLiquidityOneToken( - uint shares, - uint i, - uint minAmountOut - ) external returns (uint amountOut) { - (amountOut, ) = _calcWithdrawOneToken(shares, i); - require(amountOut >= minAmountOut, "out < min"); - - balances[i] -= amountOut; - _burn(msg.sender, shares); - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint amount); - event Approval(address indexed owner, address indexed spender, uint amount); -} diff --git a/tests/evm/sbe@stable-swap-amm-burn/StableSwap.sol b/tests/evm/sbe@stable-swap-amm-burn/StableSwap.sol deleted file mode 100644 index 90ada787a..000000000 --- a/tests/evm/sbe@stable-swap-amm-burn/StableSwap.sol +++ /dev/null @@ -1,444 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8; -import "../../../solidity_utils/lib.sol"; - -/* -Invariant - price of trade and amount of liquidity are determined by this equation - -An^n sum(x_i) + D = ADn^n + D^(n + 1) / (n^n prod(x_i)) - -Topics -0. Newton's method x_(n + 1) = x_n - f(x_n) / f'(x_n) -1. Invariant -2. Swap - - Calculate Y - - Calculate D -3. Get virtual price -4. Add liquidity - - Imbalance fee -5. Remove liquidity -6. Remove liquidity one token - - Calculate withdraw one token - - getYD -TODO: test? -*/ - -library Math { - function abs(uint x, uint y) internal pure returns (uint) { - return x >= y ? x - y : y - x; - } -} - -contract StableSwap { - // Number of tokens - uint private constant N = 3; - // Amplification coefficient multiplied by N^(N - 1) - // Higher value makes the curve more flat - // Lower value makes the curve more like constant product AMM - uint private constant A = 1000 * (N ** (N - 1)); - // 0.03% - uint private constant SWAP_FEE = 300; - // Liquidity fee is derived from 2 constraints - // 1. Fee is 0 for adding / removing liquidity that results in a balanced pool - // 2. Swapping in a balanced pool is like adding and then removing liquidity - // from a balanced pool - // swap fee = add liquidity fee + remove liquidity fee - uint private constant LIQUIDITY_FEE = (SWAP_FEE * N) / (4 * (N - 1)); - uint private constant FEE_DENOMINATOR = 1e6; - - address[N] public tokens; - // Normalize each token to 18 decimals - // Example - DAI (18 decimals), USDC (6 decimals), USDT (6 decimals) - uint[N] private multipliers = [1, 1e12, 1e12]; - uint[N] public balances; - - // 1 share = 1e18, 18 decimals - uint private constant DECIMALS = 18; - uint public totalSupply; - mapping(address => uint) public balanceOf; - - function _mint(address _to, uint _amount) private { - balanceOf[_to] += _amount; - totalSupply += _amount; - } - - function _burn(address _from, uint _amount) private { - balanceOf[_from] -= _amount; - totalSupply -= _amount; - bug(); - } - - // Return precision-adjusted balances, adjusted to 18 decimals - function _xp() private view returns (uint[N] memory xp) { - for (uint i; i < N; ++i) { - xp[i] = balances[i] * multipliers[i]; - } - } - - /** - * @notice Calculate D, sum of balances in a perfectly balanced pool - * If balances of x_0, x_1, ... x_(n-1) then sum(x_i) = D - * @param xp Precision-adjusted balances - * @return D - */ - function _getD(uint[N] memory xp) private pure returns (uint) { - /* - Newton's method to compute D - ----------------------------- - f(D) = ADn^n + D^(n + 1) / (n^n prod(x_i)) - An^n sum(x_i) - D - f'(D) = An^n + (n + 1) D^n / (n^n prod(x_i)) - 1 - - (as + np)D_n - D_(n+1) = ----------------------- - (a - 1)D_n + (n + 1)p - - a = An^n - s = sum(x_i) - p = (D_n)^(n + 1) / (n^n prod(x_i)) - */ - uint a = A * N; // An^n - - uint s; // x_0 + x_1 + ... + x_(n-1) - for (uint i; i < N; ++i) { - s += xp[i]; - } - - // Newton's method - // Initial guess, d <= s - uint d = s; - uint d_prev; - for (uint i; i < 255; ++i) { - // p = D^(n + 1) / (n^n * x_0 * ... * x_(n-1)) - uint p = d; - for (uint j; j < N; ++j) { - p = (p * d) / (N * xp[j]); - } - d_prev = d; - d = ((a * s + N * p) * d) / ((a - 1) * d + (N + 1) * p); - - if (Math.abs(d, d_prev) <= 1) { - return d; - } - } - revert("D didn't converge"); - } - - /** - * @notice Calculate the new balance of token j given the new balance of token i - * @param i Index of token in - * @param j Index of token out - * @param x New balance of token i - * @param xp Current precision-adjusted balances - */ - function _getY( - uint i, - uint j, - uint x, - uint[N] memory xp - ) private pure returns (uint) { - /* - Newton's method to compute y - ----------------------------- - y = x_j - - f(y) = y^2 + y(b - D) - c - - y_n^2 + c - y_(n+1) = -------------- - 2y_n + b - D - - where - s = sum(x_k), k != j - p = prod(x_k), k != j - b = s + D / (An^n) - c = D^(n + 1) / (n^n * p * An^n) - */ - uint a = A * N; - uint d = _getD(xp); - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k == i) { - _x = x; - } else if (k == j) { - continue; - } else { - _x = xp[k]; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - /** - * @notice Calculate the new balance of token i given precision-adjusted - * balances xp and liquidity d - * @dev Equation is calculate y is same as _getY - * @param i Index of token to calculate the new balance - * @param xp Precision-adjusted balances - * @param d Liquidity d - * @return New balance of token i - */ - function _getYD(uint i, uint[N] memory xp, uint d) private pure returns (uint) { - uint a = A * N; - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k != i) { - _x = xp[k]; - } else { - continue; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - // Estimate value of 1 share - // How many tokens is one share worth? - function getVirtualPrice() external view returns (uint) { - uint d = _getD(_xp()); - uint _totalSupply = totalSupply; - if (_totalSupply > 0) { - return (d * 10 ** DECIMALS) / _totalSupply; - } - return 0; - } - - /** - * @notice Swap dx amount of token i for token j - * @param i Index of token in - * @param j Index of token out - * @param dx Token in amount - * @param minDy Minimum token out - */ - function swap(uint i, uint j, uint dx, uint minDy) external returns (uint dy) { - require(i != j, "i = j"); - - IERC20(tokens[i]).transferFrom(msg.sender, address(this), dx); - - // Calculate dy - uint[N] memory xp = _xp(); - uint x = xp[i] + dx * multipliers[i]; - - uint y0 = xp[j]; - uint y1 = _getY(i, j, x, xp); - // y0 must be >= y1, since x has increased - // -1 to round down - dy = (y0 - y1 - 1) / multipliers[j]; - - // Subtract fee from dy - uint fee = (dy * SWAP_FEE) / FEE_DENOMINATOR; - dy -= fee; - require(dy >= minDy, "dy < min"); - - balances[i] += dx; - balances[j] -= dy; - - IERC20(tokens[j]).transfer(msg.sender, dy); - } - - function addLiquidity( - uint[N] calldata amounts, - uint minShares - ) external returns (uint shares) { - // calculate current liquidity d0 - uint _totalSupply = totalSupply; - uint d0; - uint[N] memory old_xs = _xp(); - if (_totalSupply > 0) { - d0 = _getD(old_xs); - } - - // Transfer tokens in - uint[N] memory new_xs; - for (uint i; i < N; ++i) { - uint amount = amounts[i]; - if (amount > 0) { - IERC20(tokens[i]).transferFrom(msg.sender, address(this), amount); - new_xs[i] = old_xs[i] + amount * multipliers[i]; - } else { - new_xs[i] = old_xs[i]; - } - } - - // Calculate new liquidity d1 - uint d1 = _getD(new_xs); - require(d1 > d0, "liquidity didn't increase"); - - // Reccalcuate D accounting for fee on imbalance - uint d2; - if (_totalSupply > 0) { - for (uint i; i < N; ++i) { - // TODO: why old_xs[i] * d1 / d0? why not d1 / N? - uint idealBalance = (old_xs[i] * d1) / d0; - uint diff = Math.abs(new_xs[i], idealBalance); - new_xs[i] -= (LIQUIDITY_FEE * diff) / FEE_DENOMINATOR; - } - - d2 = _getD(new_xs); - } else { - d2 = d1; - } - - // Update balances - for (uint i; i < N; ++i) { - balances[i] += amounts[i]; - } - - // Shares to mint = (d2 - d0) / d0 * total supply - // d1 >= d2 >= d0 - if (_totalSupply > 0) { - shares = ((d2 - d0) * _totalSupply) / d0; - } else { - shares = d2; - } - require(shares >= minShares, "shares < min"); - _mint(msg.sender, shares); - } - - function removeLiquidity( - uint shares, - uint[N] calldata minAmountsOut - ) external returns (uint[N] memory amountsOut) { - uint _totalSupply = totalSupply; - - for (uint i; i < N; ++i) { - uint amountOut = (balances[i] * shares) / _totalSupply; - require(amountOut >= minAmountsOut[i], "out < min"); - - balances[i] -= amountOut; - amountsOut[i] = amountOut; - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } - - _burn(msg.sender, shares); - } - - /** - * @notice Calculate amount of token i to receive for shares - * @param shares Shares to burn - * @param i Index of token to withdraw - * @return dy Amount of token i to receive - * fee Fee for withdraw. Fee already included in dy - */ - function _calcWithdrawOneToken( - uint shares, - uint i - ) private view returns (uint dy, uint fee) { - uint _totalSupply = totalSupply; - uint[N] memory xp = _xp(); - - // Calculate d0 and d1 - uint d0 = _getD(xp); - uint d1 = d0 - (d0 * shares) / _totalSupply; - - // Calculate reduction in y if D = d1 - uint y0 = _getYD(i, xp, d1); - // d1 <= d0 so y must be <= xp[i] - uint dy0 = (xp[i] - y0) / multipliers[i]; - - // Calculate imbalance fee, update xp with fees - uint dx; - for (uint j; j < N; ++j) { - if (j == i) { - dx = (xp[j] * d1) / d0 - y0; - } else { - // d1 / d0 <= 1 - dx = xp[j] - (xp[j] * d1) / d0; - } - xp[j] -= (LIQUIDITY_FEE * dx) / FEE_DENOMINATOR; - } - - // Recalculate y with xp including imbalance fees - uint y1 = _getYD(i, xp, d1); - // - 1 to round down - dy = (xp[i] - y1 - 1) / multipliers[i]; - fee = dy0 - dy; - } - - function calcWithdrawOneToken( - uint shares, - uint i - ) external view returns (uint dy, uint fee) { - return _calcWithdrawOneToken(shares, i); - } - - /** - * @notice Withdraw liquidity in token i - * @param shares Shares to burn - * @param i Token to withdraw - * @param minAmountOut Minimum amount of token i that must be withdrawn - */ - function removeLiquidityOneToken( - uint shares, - uint i, - uint minAmountOut - ) external returns (uint amountOut) { - (amountOut, ) = _calcWithdrawOneToken(shares, i); - require(amountOut >= minAmountOut, "out < min"); - - balances[i] -= amountOut; - _burn(msg.sender, shares); - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint amount); - event Approval(address indexed owner, address indexed spender, uint amount); -} diff --git a/tests/evm/sbe@stable-swap-amm-mint/StableSwap.sol b/tests/evm/sbe@stable-swap-amm-mint/StableSwap.sol deleted file mode 100644 index 1f25abcd3..000000000 --- a/tests/evm/sbe@stable-swap-amm-mint/StableSwap.sol +++ /dev/null @@ -1,444 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8; -import "../../../solidity_utils/lib.sol"; - -/* -Invariant - price of trade and amount of liquidity are determined by this equation - -An^n sum(x_i) + D = ADn^n + D^(n + 1) / (n^n prod(x_i)) - -Topics -0. Newton's method x_(n + 1) = x_n - f(x_n) / f'(x_n) -1. Invariant -2. Swap - - Calculate Y - - Calculate D -3. Get virtual price -4. Add liquidity - - Imbalance fee -5. Remove liquidity -6. Remove liquidity one token - - Calculate withdraw one token - - getYD -TODO: test? -*/ - -library Math { - function abs(uint x, uint y) internal pure returns (uint) { - return x >= y ? x - y : y - x; - } -} - -contract StableSwap { - // Number of tokens - uint private constant N = 3; - // Amplification coefficient multiplied by N^(N - 1) - // Higher value makes the curve more flat - // Lower value makes the curve more like constant product AMM - uint private constant A = 1000 * (N ** (N - 1)); - // 0.03% - uint private constant SWAP_FEE = 300; - // Liquidity fee is derived from 2 constraints - // 1. Fee is 0 for adding / removing liquidity that results in a balanced pool - // 2. Swapping in a balanced pool is like adding and then removing liquidity - // from a balanced pool - // swap fee = add liquidity fee + remove liquidity fee - uint private constant LIQUIDITY_FEE = (SWAP_FEE * N) / (4 * (N - 1)); - uint private constant FEE_DENOMINATOR = 1e6; - - address[N] public tokens; - // Normalize each token to 18 decimals - // Example - DAI (18 decimals), USDC (6 decimals), USDT (6 decimals) - uint[N] private multipliers = [1, 1e12, 1e12]; - uint[N] public balances; - - // 1 share = 1e18, 18 decimals - uint private constant DECIMALS = 18; - uint public totalSupply; - mapping(address => uint) public balanceOf; - - function _mint(address _to, uint _amount) private { - balanceOf[_to] += _amount; - totalSupply += _amount; - bug(); - } - - function _burn(address _from, uint _amount) private { - balanceOf[_from] -= _amount; - totalSupply -= _amount; - } - - // Return precision-adjusted balances, adjusted to 18 decimals - function _xp() private view returns (uint[N] memory xp) { - for (uint i; i < N; ++i) { - xp[i] = balances[i] * multipliers[i]; - } - } - - /** - * @notice Calculate D, sum of balances in a perfectly balanced pool - * If balances of x_0, x_1, ... x_(n-1) then sum(x_i) = D - * @param xp Precision-adjusted balances - * @return D - */ - function _getD(uint[N] memory xp) private pure returns (uint) { - /* - Newton's method to compute D - ----------------------------- - f(D) = ADn^n + D^(n + 1) / (n^n prod(x_i)) - An^n sum(x_i) - D - f'(D) = An^n + (n + 1) D^n / (n^n prod(x_i)) - 1 - - (as + np)D_n - D_(n+1) = ----------------------- - (a - 1)D_n + (n + 1)p - - a = An^n - s = sum(x_i) - p = (D_n)^(n + 1) / (n^n prod(x_i)) - */ - uint a = A * N; // An^n - - uint s; // x_0 + x_1 + ... + x_(n-1) - for (uint i; i < N; ++i) { - s += xp[i]; - } - - // Newton's method - // Initial guess, d <= s - uint d = s; - uint d_prev; - for (uint i; i < 255; ++i) { - // p = D^(n + 1) / (n^n * x_0 * ... * x_(n-1)) - uint p = d; - for (uint j; j < N; ++j) { - p = (p * d) / (N * xp[j]); - } - d_prev = d; - d = ((a * s + N * p) * d) / ((a - 1) * d + (N + 1) * p); - - if (Math.abs(d, d_prev) <= 1) { - return d; - } - } - revert("D didn't converge"); - } - - /** - * @notice Calculate the new balance of token j given the new balance of token i - * @param i Index of token in - * @param j Index of token out - * @param x New balance of token i - * @param xp Current precision-adjusted balances - */ - function _getY( - uint i, - uint j, - uint x, - uint[N] memory xp - ) private pure returns (uint) { - /* - Newton's method to compute y - ----------------------------- - y = x_j - - f(y) = y^2 + y(b - D) - c - - y_n^2 + c - y_(n+1) = -------------- - 2y_n + b - D - - where - s = sum(x_k), k != j - p = prod(x_k), k != j - b = s + D / (An^n) - c = D^(n + 1) / (n^n * p * An^n) - */ - uint a = A * N; - uint d = _getD(xp); - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k == i) { - _x = x; - } else if (k == j) { - continue; - } else { - _x = xp[k]; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - /** - * @notice Calculate the new balance of token i given precision-adjusted - * balances xp and liquidity d - * @dev Equation is calculate y is same as _getY - * @param i Index of token to calculate the new balance - * @param xp Precision-adjusted balances - * @param d Liquidity d - * @return New balance of token i - */ - function _getYD(uint i, uint[N] memory xp, uint d) private pure returns (uint) { - uint a = A * N; - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k != i) { - _x = xp[k]; - } else { - continue; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - // Estimate value of 1 share - // How many tokens is one share worth? - function getVirtualPrice() external view returns (uint) { - uint d = _getD(_xp()); - uint _totalSupply = totalSupply; - if (_totalSupply > 0) { - return (d * 10 ** DECIMALS) / _totalSupply; - } - return 0; - } - - /** - * @notice Swap dx amount of token i for token j - * @param i Index of token in - * @param j Index of token out - * @param dx Token in amount - * @param minDy Minimum token out - */ - function swap(uint i, uint j, uint dx, uint minDy) external returns (uint dy) { - require(i != j, "i = j"); - - IERC20(tokens[i]).transferFrom(msg.sender, address(this), dx); - - // Calculate dy - uint[N] memory xp = _xp(); - uint x = xp[i] + dx * multipliers[i]; - - uint y0 = xp[j]; - uint y1 = _getY(i, j, x, xp); - // y0 must be >= y1, since x has increased - // -1 to round down - dy = (y0 - y1 - 1) / multipliers[j]; - - // Subtract fee from dy - uint fee = (dy * SWAP_FEE) / FEE_DENOMINATOR; - dy -= fee; - require(dy >= minDy, "dy < min"); - - balances[i] += dx; - balances[j] -= dy; - - IERC20(tokens[j]).transfer(msg.sender, dy); - } - - function addLiquidity( - uint[N] calldata amounts, - uint minShares - ) external returns (uint shares) { - // calculate current liquidity d0 - uint _totalSupply = totalSupply; - uint d0; - uint[N] memory old_xs = _xp(); - if (_totalSupply > 0) { - d0 = _getD(old_xs); - } - - // Transfer tokens in - uint[N] memory new_xs; - for (uint i; i < N; ++i) { - uint amount = amounts[i]; - if (amount > 0) { - IERC20(tokens[i]).transferFrom(msg.sender, address(this), amount); - new_xs[i] = old_xs[i] + amount * multipliers[i]; - } else { - new_xs[i] = old_xs[i]; - } - } - - // Calculate new liquidity d1 - uint d1 = _getD(new_xs); - require(d1 > d0, "liquidity didn't increase"); - - // Reccalcuate D accounting for fee on imbalance - uint d2; - if (_totalSupply > 0) { - for (uint i; i < N; ++i) { - // TODO: why old_xs[i] * d1 / d0? why not d1 / N? - uint idealBalance = (old_xs[i] * d1) / d0; - uint diff = Math.abs(new_xs[i], idealBalance); - new_xs[i] -= (LIQUIDITY_FEE * diff) / FEE_DENOMINATOR; - } - - d2 = _getD(new_xs); - } else { - d2 = d1; - } - - // Update balances - for (uint i; i < N; ++i) { - balances[i] += amounts[i]; - } - - // Shares to mint = (d2 - d0) / d0 * total supply - // d1 >= d2 >= d0 - if (_totalSupply > 0) { - shares = ((d2 - d0) * _totalSupply) / d0; - } else { - shares = d2; - } - require(shares >= minShares, "shares < min"); - _mint(msg.sender, shares); - } - - function removeLiquidity( - uint shares, - uint[N] calldata minAmountsOut - ) external returns (uint[N] memory amountsOut) { - uint _totalSupply = totalSupply; - - for (uint i; i < N; ++i) { - uint amountOut = (balances[i] * shares) / _totalSupply; - require(amountOut >= minAmountsOut[i], "out < min"); - - balances[i] -= amountOut; - amountsOut[i] = amountOut; - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } - - _burn(msg.sender, shares); - } - - /** - * @notice Calculate amount of token i to receive for shares - * @param shares Shares to burn - * @param i Index of token to withdraw - * @return dy Amount of token i to receive - * fee Fee for withdraw. Fee already included in dy - */ - function _calcWithdrawOneToken( - uint shares, - uint i - ) private view returns (uint dy, uint fee) { - uint _totalSupply = totalSupply; - uint[N] memory xp = _xp(); - - // Calculate d0 and d1 - uint d0 = _getD(xp); - uint d1 = d0 - (d0 * shares) / _totalSupply; - - // Calculate reduction in y if D = d1 - uint y0 = _getYD(i, xp, d1); - // d1 <= d0 so y must be <= xp[i] - uint dy0 = (xp[i] - y0) / multipliers[i]; - - // Calculate imbalance fee, update xp with fees - uint dx; - for (uint j; j < N; ++j) { - if (j == i) { - dx = (xp[j] * d1) / d0 - y0; - } else { - // d1 / d0 <= 1 - dx = xp[j] - (xp[j] * d1) / d0; - } - xp[j] -= (LIQUIDITY_FEE * dx) / FEE_DENOMINATOR; - } - - // Recalculate y with xp including imbalance fees - uint y1 = _getYD(i, xp, d1); - // - 1 to round down - dy = (xp[i] - y1 - 1) / multipliers[i]; - fee = dy0 - dy; - } - - function calcWithdrawOneToken( - uint shares, - uint i - ) external view returns (uint dy, uint fee) { - return _calcWithdrawOneToken(shares, i); - } - - /** - * @notice Withdraw liquidity in token i - * @param shares Shares to burn - * @param i Token to withdraw - * @param minAmountOut Minimum amount of token i that must be withdrawn - */ - function removeLiquidityOneToken( - uint shares, - uint i, - uint minAmountOut - ) external returns (uint amountOut) { - (amountOut, ) = _calcWithdrawOneToken(shares, i); - require(amountOut >= minAmountOut, "out < min"); - - balances[i] -= amountOut; - _burn(msg.sender, shares); - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint amount); - event Approval(address indexed owner, address indexed spender, uint amount); -} diff --git a/tests/evm/sbe@stable-swap-amm-removeLiquidity1/StableSwap.sol b/tests/evm/sbe@stable-swap-amm-removeLiquidity1/StableSwap.sol deleted file mode 100644 index 1df374da0..000000000 --- a/tests/evm/sbe@stable-swap-amm-removeLiquidity1/StableSwap.sol +++ /dev/null @@ -1,444 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8; -import "../../../solidity_utils/lib.sol"; - -/* -Invariant - price of trade and amount of liquidity are determined by this equation - -An^n sum(x_i) + D = ADn^n + D^(n + 1) / (n^n prod(x_i)) - -Topics -0. Newton's method x_(n + 1) = x_n - f(x_n) / f'(x_n) -1. Invariant -2. Swap - - Calculate Y - - Calculate D -3. Get virtual price -4. Add liquidity - - Imbalance fee -5. Remove liquidity -6. Remove liquidity one token - - Calculate withdraw one token - - getYD -TODO: test? -*/ - -library Math { - function abs(uint x, uint y) internal pure returns (uint) { - return x >= y ? x - y : y - x; - } -} - -contract StableSwap { - // Number of tokens - uint private constant N = 3; - // Amplification coefficient multiplied by N^(N - 1) - // Higher value makes the curve more flat - // Lower value makes the curve more like constant product AMM - uint private constant A = 1000 * (N ** (N - 1)); - // 0.03% - uint private constant SWAP_FEE = 300; - // Liquidity fee is derived from 2 constraints - // 1. Fee is 0 for adding / removing liquidity that results in a balanced pool - // 2. Swapping in a balanced pool is like adding and then removing liquidity - // from a balanced pool - // swap fee = add liquidity fee + remove liquidity fee - uint private constant LIQUIDITY_FEE = (SWAP_FEE * N) / (4 * (N - 1)); - uint private constant FEE_DENOMINATOR = 1e6; - - address[N] public tokens; - // Normalize each token to 18 decimals - // Example - DAI (18 decimals), USDC (6 decimals), USDT (6 decimals) - uint[N] private multipliers = [1, 1e12, 1e12]; - uint[N] public balances; - - // 1 share = 1e18, 18 decimals - uint private constant DECIMALS = 18; - uint public totalSupply; - mapping(address => uint) public balanceOf; - - function _mint(address _to, uint _amount) private { - balanceOf[_to] += _amount; - totalSupply += _amount; - } - - function _burn(address _from, uint _amount) private { - balanceOf[_from] -= _amount; - totalSupply -= _amount; - } - - // Return precision-adjusted balances, adjusted to 18 decimals - function _xp() private view returns (uint[N] memory xp) { - for (uint i; i < N; ++i) { - xp[i] = balances[i] * multipliers[i]; - } - } - - /** - * @notice Calculate D, sum of balances in a perfectly balanced pool - * If balances of x_0, x_1, ... x_(n-1) then sum(x_i) = D - * @param xp Precision-adjusted balances - * @return D - */ - function _getD(uint[N] memory xp) private pure returns (uint) { - /* - Newton's method to compute D - ----------------------------- - f(D) = ADn^n + D^(n + 1) / (n^n prod(x_i)) - An^n sum(x_i) - D - f'(D) = An^n + (n + 1) D^n / (n^n prod(x_i)) - 1 - - (as + np)D_n - D_(n+1) = ----------------------- - (a - 1)D_n + (n + 1)p - - a = An^n - s = sum(x_i) - p = (D_n)^(n + 1) / (n^n prod(x_i)) - */ - uint a = A * N; // An^n - - uint s; // x_0 + x_1 + ... + x_(n-1) - for (uint i; i < N; ++i) { - s += xp[i]; - } - - // Newton's method - // Initial guess, d <= s - uint d = s; - uint d_prev; - for (uint i; i < 255; ++i) { - // p = D^(n + 1) / (n^n * x_0 * ... * x_(n-1)) - uint p = d; - for (uint j; j < N; ++j) { - p = (p * d) / (N * xp[j]); - } - d_prev = d; - d = ((a * s + N * p) * d) / ((a - 1) * d + (N + 1) * p); - - if (Math.abs(d, d_prev) <= 1) { - return d; - } - } - revert("D didn't converge"); - } - - /** - * @notice Calculate the new balance of token j given the new balance of token i - * @param i Index of token in - * @param j Index of token out - * @param x New balance of token i - * @param xp Current precision-adjusted balances - */ - function _getY( - uint i, - uint j, - uint x, - uint[N] memory xp - ) private pure returns (uint) { - /* - Newton's method to compute y - ----------------------------- - y = x_j - - f(y) = y^2 + y(b - D) - c - - y_n^2 + c - y_(n+1) = -------------- - 2y_n + b - D - - where - s = sum(x_k), k != j - p = prod(x_k), k != j - b = s + D / (An^n) - c = D^(n + 1) / (n^n * p * An^n) - */ - uint a = A * N; - uint d = _getD(xp); - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k == i) { - _x = x; - } else if (k == j) { - continue; - } else { - _x = xp[k]; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - /** - * @notice Calculate the new balance of token i given precision-adjusted - * balances xp and liquidity d - * @dev Equation is calculate y is same as _getY - * @param i Index of token to calculate the new balance - * @param xp Precision-adjusted balances - * @param d Liquidity d - * @return New balance of token i - */ - function _getYD(uint i, uint[N] memory xp, uint d) private pure returns (uint) { - uint a = A * N; - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k != i) { - _x = xp[k]; - } else { - continue; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - // Estimate value of 1 share - // How many tokens is one share worth? - function getVirtualPrice() external view returns (uint) { - uint d = _getD(_xp()); - uint _totalSupply = totalSupply; - if (_totalSupply > 0) { - return (d * 10 ** DECIMALS) / _totalSupply; - } - return 0; - } - - /** - * @notice Swap dx amount of token i for token j - * @param i Index of token in - * @param j Index of token out - * @param dx Token in amount - * @param minDy Minimum token out - */ - function swap(uint i, uint j, uint dx, uint minDy) external returns (uint dy) { - require(i != j, "i = j"); - - IERC20(tokens[i]).transferFrom(msg.sender, address(this), dx); - - // Calculate dy - uint[N] memory xp = _xp(); - uint x = xp[i] + dx * multipliers[i]; - - uint y0 = xp[j]; - uint y1 = _getY(i, j, x, xp); - // y0 must be >= y1, since x has increased - // -1 to round down - dy = (y0 - y1 - 1) / multipliers[j]; - - // Subtract fee from dy - uint fee = (dy * SWAP_FEE) / FEE_DENOMINATOR; - dy -= fee; - require(dy >= minDy, "dy < min"); - - balances[i] += dx; - balances[j] -= dy; - - IERC20(tokens[j]).transfer(msg.sender, dy); - } - - function addLiquidity( - uint[N] calldata amounts, - uint minShares - ) external returns (uint shares) { - // calculate current liquidity d0 - uint _totalSupply = totalSupply; - uint d0; - uint[N] memory old_xs = _xp(); - if (_totalSupply > 0) { - d0 = _getD(old_xs); - } - - // Transfer tokens in - uint[N] memory new_xs; - for (uint i; i < N; ++i) { - uint amount = amounts[i]; - if (amount > 0) { - IERC20(tokens[i]).transferFrom(msg.sender, address(this), amount); - new_xs[i] = old_xs[i] + amount * multipliers[i]; - } else { - new_xs[i] = old_xs[i]; - } - } - - // Calculate new liquidity d1 - uint d1 = _getD(new_xs); - require(d1 > d0, "liquidity didn't increase"); - - // Reccalcuate D accounting for fee on imbalance - uint d2; - if (_totalSupply > 0) { - for (uint i; i < N; ++i) { - // TODO: why old_xs[i] * d1 / d0? why not d1 / N? - uint idealBalance = (old_xs[i] * d1) / d0; - uint diff = Math.abs(new_xs[i], idealBalance); - new_xs[i] -= (LIQUIDITY_FEE * diff) / FEE_DENOMINATOR; - } - - d2 = _getD(new_xs); - } else { - d2 = d1; - } - - // Update balances - for (uint i; i < N; ++i) { - balances[i] += amounts[i]; - } - - // Shares to mint = (d2 - d0) / d0 * total supply - // d1 >= d2 >= d0 - if (_totalSupply > 0) { - shares = ((d2 - d0) * _totalSupply) / d0; - } else { - shares = d2; - } - require(shares >= minShares, "shares < min"); - _mint(msg.sender, shares); - } - - function removeLiquidity( - uint shares, - uint[N] calldata minAmountsOut - ) external returns (uint[N] memory amountsOut) { - uint _totalSupply = totalSupply; - - for (uint i; i < N; ++i) { - uint amountOut = (balances[i] * shares) / _totalSupply; - require(amountOut >= minAmountsOut[i], "out < min"); - - balances[i] -= amountOut; - amountsOut[i] = amountOut; - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - bug(); - } - - _burn(msg.sender, shares); - } - - /** - * @notice Calculate amount of token i to receive for shares - * @param shares Shares to burn - * @param i Index of token to withdraw - * @return dy Amount of token i to receive - * fee Fee for withdraw. Fee already included in dy - */ - function _calcWithdrawOneToken( - uint shares, - uint i - ) private view returns (uint dy, uint fee) { - uint _totalSupply = totalSupply; - uint[N] memory xp = _xp(); - - // Calculate d0 and d1 - uint d0 = _getD(xp); - uint d1 = d0 - (d0 * shares) / _totalSupply; - - // Calculate reduction in y if D = d1 - uint y0 = _getYD(i, xp, d1); - // d1 <= d0 so y must be <= xp[i] - uint dy0 = (xp[i] - y0) / multipliers[i]; - - // Calculate imbalance fee, update xp with fees - uint dx; - for (uint j; j < N; ++j) { - if (j == i) { - dx = (xp[j] * d1) / d0 - y0; - } else { - // d1 / d0 <= 1 - dx = xp[j] - (xp[j] * d1) / d0; - } - xp[j] -= (LIQUIDITY_FEE * dx) / FEE_DENOMINATOR; - } - - // Recalculate y with xp including imbalance fees - uint y1 = _getYD(i, xp, d1); - // - 1 to round down - dy = (xp[i] - y1 - 1) / multipliers[i]; - fee = dy0 - dy; - } - - function calcWithdrawOneToken( - uint shares, - uint i - ) external view returns (uint dy, uint fee) { - return _calcWithdrawOneToken(shares, i); - } - - /** - * @notice Withdraw liquidity in token i - * @param shares Shares to burn - * @param i Token to withdraw - * @param minAmountOut Minimum amount of token i that must be withdrawn - */ - function removeLiquidityOneToken( - uint shares, - uint i, - uint minAmountOut - ) external returns (uint amountOut) { - (amountOut, ) = _calcWithdrawOneToken(shares, i); - require(amountOut >= minAmountOut, "out < min"); - - balances[i] -= amountOut; - _burn(msg.sender, shares); - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint amount); - event Approval(address indexed owner, address indexed spender, uint amount); -} diff --git a/tests/evm/sbe@stable-swap-amm-removeLiquidity2/StableSwap.sol b/tests/evm/sbe@stable-swap-amm-removeLiquidity2/StableSwap.sol deleted file mode 100644 index 9d16008dc..000000000 --- a/tests/evm/sbe@stable-swap-amm-removeLiquidity2/StableSwap.sol +++ /dev/null @@ -1,444 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8; -import "../../../solidity_utils/lib.sol"; - -/* -Invariant - price of trade and amount of liquidity are determined by this equation - -An^n sum(x_i) + D = ADn^n + D^(n + 1) / (n^n prod(x_i)) - -Topics -0. Newton's method x_(n + 1) = x_n - f(x_n) / f'(x_n) -1. Invariant -2. Swap - - Calculate Y - - Calculate D -3. Get virtual price -4. Add liquidity - - Imbalance fee -5. Remove liquidity -6. Remove liquidity one token - - Calculate withdraw one token - - getYD -TODO: test? -*/ - -library Math { - function abs(uint x, uint y) internal pure returns (uint) { - return x >= y ? x - y : y - x; - } -} - -contract StableSwap { - // Number of tokens - uint private constant N = 3; - // Amplification coefficient multiplied by N^(N - 1) - // Higher value makes the curve more flat - // Lower value makes the curve more like constant product AMM - uint private constant A = 1000 * (N ** (N - 1)); - // 0.03% - uint private constant SWAP_FEE = 300; - // Liquidity fee is derived from 2 constraints - // 1. Fee is 0 for adding / removing liquidity that results in a balanced pool - // 2. Swapping in a balanced pool is like adding and then removing liquidity - // from a balanced pool - // swap fee = add liquidity fee + remove liquidity fee - uint private constant LIQUIDITY_FEE = (SWAP_FEE * N) / (4 * (N - 1)); - uint private constant FEE_DENOMINATOR = 1e6; - - address[N] public tokens; - // Normalize each token to 18 decimals - // Example - DAI (18 decimals), USDC (6 decimals), USDT (6 decimals) - uint[N] private multipliers = [1, 1e12, 1e12]; - uint[N] public balances; - - // 1 share = 1e18, 18 decimals - uint private constant DECIMALS = 18; - uint public totalSupply; - mapping(address => uint) public balanceOf; - - function _mint(address _to, uint _amount) private { - balanceOf[_to] += _amount; - totalSupply += _amount; - } - - function _burn(address _from, uint _amount) private { - balanceOf[_from] -= _amount; - totalSupply -= _amount; - } - - // Return precision-adjusted balances, adjusted to 18 decimals - function _xp() private view returns (uint[N] memory xp) { - for (uint i; i < N; ++i) { - xp[i] = balances[i] * multipliers[i]; - } - } - - /** - * @notice Calculate D, sum of balances in a perfectly balanced pool - * If balances of x_0, x_1, ... x_(n-1) then sum(x_i) = D - * @param xp Precision-adjusted balances - * @return D - */ - function _getD(uint[N] memory xp) private pure returns (uint) { - /* - Newton's method to compute D - ----------------------------- - f(D) = ADn^n + D^(n + 1) / (n^n prod(x_i)) - An^n sum(x_i) - D - f'(D) = An^n + (n + 1) D^n / (n^n prod(x_i)) - 1 - - (as + np)D_n - D_(n+1) = ----------------------- - (a - 1)D_n + (n + 1)p - - a = An^n - s = sum(x_i) - p = (D_n)^(n + 1) / (n^n prod(x_i)) - */ - uint a = A * N; // An^n - - uint s; // x_0 + x_1 + ... + x_(n-1) - for (uint i; i < N; ++i) { - s += xp[i]; - } - - // Newton's method - // Initial guess, d <= s - uint d = s; - uint d_prev; - for (uint i; i < 255; ++i) { - // p = D^(n + 1) / (n^n * x_0 * ... * x_(n-1)) - uint p = d; - for (uint j; j < N; ++j) { - p = (p * d) / (N * xp[j]); - } - d_prev = d; - d = ((a * s + N * p) * d) / ((a - 1) * d + (N + 1) * p); - - if (Math.abs(d, d_prev) <= 1) { - return d; - } - } - revert("D didn't converge"); - } - - /** - * @notice Calculate the new balance of token j given the new balance of token i - * @param i Index of token in - * @param j Index of token out - * @param x New balance of token i - * @param xp Current precision-adjusted balances - */ - function _getY( - uint i, - uint j, - uint x, - uint[N] memory xp - ) private pure returns (uint) { - /* - Newton's method to compute y - ----------------------------- - y = x_j - - f(y) = y^2 + y(b - D) - c - - y_n^2 + c - y_(n+1) = -------------- - 2y_n + b - D - - where - s = sum(x_k), k != j - p = prod(x_k), k != j - b = s + D / (An^n) - c = D^(n + 1) / (n^n * p * An^n) - */ - uint a = A * N; - uint d = _getD(xp); - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k == i) { - _x = x; - } else if (k == j) { - continue; - } else { - _x = xp[k]; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - /** - * @notice Calculate the new balance of token i given precision-adjusted - * balances xp and liquidity d - * @dev Equation is calculate y is same as _getY - * @param i Index of token to calculate the new balance - * @param xp Precision-adjusted balances - * @param d Liquidity d - * @return New balance of token i - */ - function _getYD(uint i, uint[N] memory xp, uint d) private pure returns (uint) { - uint a = A * N; - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k != i) { - _x = xp[k]; - } else { - continue; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - // Estimate value of 1 share - // How many tokens is one share worth? - function getVirtualPrice() external view returns (uint) { - uint d = _getD(_xp()); - uint _totalSupply = totalSupply; - if (_totalSupply > 0) { - return (d * 10 ** DECIMALS) / _totalSupply; - } - return 0; - } - - /** - * @notice Swap dx amount of token i for token j - * @param i Index of token in - * @param j Index of token out - * @param dx Token in amount - * @param minDy Minimum token out - */ - function swap(uint i, uint j, uint dx, uint minDy) external returns (uint dy) { - require(i != j, "i = j"); - - IERC20(tokens[i]).transferFrom(msg.sender, address(this), dx); - - // Calculate dy - uint[N] memory xp = _xp(); - uint x = xp[i] + dx * multipliers[i]; - - uint y0 = xp[j]; - uint y1 = _getY(i, j, x, xp); - // y0 must be >= y1, since x has increased - // -1 to round down - dy = (y0 - y1 - 1) / multipliers[j]; - - // Subtract fee from dy - uint fee = (dy * SWAP_FEE) / FEE_DENOMINATOR; - dy -= fee; - require(dy >= minDy, "dy < min"); - - balances[i] += dx; - balances[j] -= dy; - - IERC20(tokens[j]).transfer(msg.sender, dy); - } - - function addLiquidity( - uint[N] calldata amounts, - uint minShares - ) external returns (uint shares) { - // calculate current liquidity d0 - uint _totalSupply = totalSupply; - uint d0; - uint[N] memory old_xs = _xp(); - if (_totalSupply > 0) { - d0 = _getD(old_xs); - } - - // Transfer tokens in - uint[N] memory new_xs; - for (uint i; i < N; ++i) { - uint amount = amounts[i]; - if (amount > 0) { - IERC20(tokens[i]).transferFrom(msg.sender, address(this), amount); - new_xs[i] = old_xs[i] + amount * multipliers[i]; - } else { - new_xs[i] = old_xs[i]; - } - } - - // Calculate new liquidity d1 - uint d1 = _getD(new_xs); - require(d1 > d0, "liquidity didn't increase"); - - // Reccalcuate D accounting for fee on imbalance - uint d2; - if (_totalSupply > 0) { - for (uint i; i < N; ++i) { - // TODO: why old_xs[i] * d1 / d0? why not d1 / N? - uint idealBalance = (old_xs[i] * d1) / d0; - uint diff = Math.abs(new_xs[i], idealBalance); - new_xs[i] -= (LIQUIDITY_FEE * diff) / FEE_DENOMINATOR; - } - - d2 = _getD(new_xs); - } else { - d2 = d1; - } - - // Update balances - for (uint i; i < N; ++i) { - balances[i] += amounts[i]; - } - - // Shares to mint = (d2 - d0) / d0 * total supply - // d1 >= d2 >= d0 - if (_totalSupply > 0) { - shares = ((d2 - d0) * _totalSupply) / d0; - } else { - shares = d2; - } - require(shares >= minShares, "shares < min"); - _mint(msg.sender, shares); - } - - function removeLiquidity( - uint shares, - uint[N] calldata minAmountsOut - ) external returns (uint[N] memory amountsOut) { - uint _totalSupply = totalSupply; - - for (uint i; i < N; ++i) { - uint amountOut = (balances[i] * shares) / _totalSupply; - require(amountOut >= minAmountsOut[i], "out < min"); - - balances[i] -= amountOut; - amountsOut[i] = amountOut; - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } - - _burn(msg.sender, shares); - bug(); - } - - /** - * @notice Calculate amount of token i to receive for shares - * @param shares Shares to burn - * @param i Index of token to withdraw - * @return dy Amount of token i to receive - * fee Fee for withdraw. Fee already included in dy - */ - function _calcWithdrawOneToken( - uint shares, - uint i - ) private view returns (uint dy, uint fee) { - uint _totalSupply = totalSupply; - uint[N] memory xp = _xp(); - - // Calculate d0 and d1 - uint d0 = _getD(xp); - uint d1 = d0 - (d0 * shares) / _totalSupply; - - // Calculate reduction in y if D = d1 - uint y0 = _getYD(i, xp, d1); - // d1 <= d0 so y must be <= xp[i] - uint dy0 = (xp[i] - y0) / multipliers[i]; - - // Calculate imbalance fee, update xp with fees - uint dx; - for (uint j; j < N; ++j) { - if (j == i) { - dx = (xp[j] * d1) / d0 - y0; - } else { - // d1 / d0 <= 1 - dx = xp[j] - (xp[j] * d1) / d0; - } - xp[j] -= (LIQUIDITY_FEE * dx) / FEE_DENOMINATOR; - } - - // Recalculate y with xp including imbalance fees - uint y1 = _getYD(i, xp, d1); - // - 1 to round down - dy = (xp[i] - y1 - 1) / multipliers[i]; - fee = dy0 - dy; - } - - function calcWithdrawOneToken( - uint shares, - uint i - ) external view returns (uint dy, uint fee) { - return _calcWithdrawOneToken(shares, i); - } - - /** - * @notice Withdraw liquidity in token i - * @param shares Shares to burn - * @param i Token to withdraw - * @param minAmountOut Minimum amount of token i that must be withdrawn - */ - function removeLiquidityOneToken( - uint shares, - uint i, - uint minAmountOut - ) external returns (uint amountOut) { - (amountOut, ) = _calcWithdrawOneToken(shares, i); - require(amountOut >= minAmountOut, "out < min"); - - balances[i] -= amountOut; - _burn(msg.sender, shares); - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint amount); - event Approval(address indexed owner, address indexed spender, uint amount); -} diff --git a/tests/evm/sbe@stable-swap-amm-removeLiquidityOneToken/StableSwap.sol b/tests/evm/sbe@stable-swap-amm-removeLiquidityOneToken/StableSwap.sol deleted file mode 100644 index 906fa06a8..000000000 --- a/tests/evm/sbe@stable-swap-amm-removeLiquidityOneToken/StableSwap.sol +++ /dev/null @@ -1,444 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8; -import "../../../solidity_utils/lib.sol"; - -/* -Invariant - price of trade and amount of liquidity are determined by this equation - -An^n sum(x_i) + D = ADn^n + D^(n + 1) / (n^n prod(x_i)) - -Topics -0. Newton's method x_(n + 1) = x_n - f(x_n) / f'(x_n) -1. Invariant -2. Swap - - Calculate Y - - Calculate D -3. Get virtual price -4. Add liquidity - - Imbalance fee -5. Remove liquidity -6. Remove liquidity one token - - Calculate withdraw one token - - getYD -TODO: test? -*/ - -library Math { - function abs(uint x, uint y) internal pure returns (uint) { - return x >= y ? x - y : y - x; - } -} - -contract StableSwap { - // Number of tokens - uint private constant N = 3; - // Amplification coefficient multiplied by N^(N - 1) - // Higher value makes the curve more flat - // Lower value makes the curve more like constant product AMM - uint private constant A = 1000 * (N ** (N - 1)); - // 0.03% - uint private constant SWAP_FEE = 300; - // Liquidity fee is derived from 2 constraints - // 1. Fee is 0 for adding / removing liquidity that results in a balanced pool - // 2. Swapping in a balanced pool is like adding and then removing liquidity - // from a balanced pool - // swap fee = add liquidity fee + remove liquidity fee - uint private constant LIQUIDITY_FEE = (SWAP_FEE * N) / (4 * (N - 1)); - uint private constant FEE_DENOMINATOR = 1e6; - - address[N] public tokens; - // Normalize each token to 18 decimals - // Example - DAI (18 decimals), USDC (6 decimals), USDT (6 decimals) - uint[N] private multipliers = [1, 1e12, 1e12]; - uint[N] public balances; - - // 1 share = 1e18, 18 decimals - uint private constant DECIMALS = 18; - uint public totalSupply; - mapping(address => uint) public balanceOf; - - function _mint(address _to, uint _amount) private { - balanceOf[_to] += _amount; - totalSupply += _amount; - } - - function _burn(address _from, uint _amount) private { - balanceOf[_from] -= _amount; - totalSupply -= _amount; - } - - // Return precision-adjusted balances, adjusted to 18 decimals - function _xp() private view returns (uint[N] memory xp) { - for (uint i; i < N; ++i) { - xp[i] = balances[i] * multipliers[i]; - } - } - - /** - * @notice Calculate D, sum of balances in a perfectly balanced pool - * If balances of x_0, x_1, ... x_(n-1) then sum(x_i) = D - * @param xp Precision-adjusted balances - * @return D - */ - function _getD(uint[N] memory xp) private pure returns (uint) { - /* - Newton's method to compute D - ----------------------------- - f(D) = ADn^n + D^(n + 1) / (n^n prod(x_i)) - An^n sum(x_i) - D - f'(D) = An^n + (n + 1) D^n / (n^n prod(x_i)) - 1 - - (as + np)D_n - D_(n+1) = ----------------------- - (a - 1)D_n + (n + 1)p - - a = An^n - s = sum(x_i) - p = (D_n)^(n + 1) / (n^n prod(x_i)) - */ - uint a = A * N; // An^n - - uint s; // x_0 + x_1 + ... + x_(n-1) - for (uint i; i < N; ++i) { - s += xp[i]; - } - - // Newton's method - // Initial guess, d <= s - uint d = s; - uint d_prev; - for (uint i; i < 255; ++i) { - // p = D^(n + 1) / (n^n * x_0 * ... * x_(n-1)) - uint p = d; - for (uint j; j < N; ++j) { - p = (p * d) / (N * xp[j]); - } - d_prev = d; - d = ((a * s + N * p) * d) / ((a - 1) * d + (N + 1) * p); - - if (Math.abs(d, d_prev) <= 1) { - return d; - } - } - revert("D didn't converge"); - } - - /** - * @notice Calculate the new balance of token j given the new balance of token i - * @param i Index of token in - * @param j Index of token out - * @param x New balance of token i - * @param xp Current precision-adjusted balances - */ - function _getY( - uint i, - uint j, - uint x, - uint[N] memory xp - ) private pure returns (uint) { - /* - Newton's method to compute y - ----------------------------- - y = x_j - - f(y) = y^2 + y(b - D) - c - - y_n^2 + c - y_(n+1) = -------------- - 2y_n + b - D - - where - s = sum(x_k), k != j - p = prod(x_k), k != j - b = s + D / (An^n) - c = D^(n + 1) / (n^n * p * An^n) - */ - uint a = A * N; - uint d = _getD(xp); - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k == i) { - _x = x; - } else if (k == j) { - continue; - } else { - _x = xp[k]; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - /** - * @notice Calculate the new balance of token i given precision-adjusted - * balances xp and liquidity d - * @dev Equation is calculate y is same as _getY - * @param i Index of token to calculate the new balance - * @param xp Precision-adjusted balances - * @param d Liquidity d - * @return New balance of token i - */ - function _getYD(uint i, uint[N] memory xp, uint d) private pure returns (uint) { - uint a = A * N; - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k != i) { - _x = xp[k]; - } else { - continue; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - // Estimate value of 1 share - // How many tokens is one share worth? - function getVirtualPrice() external view returns (uint) { - uint d = _getD(_xp()); - uint _totalSupply = totalSupply; - if (_totalSupply > 0) { - return (d * 10 ** DECIMALS) / _totalSupply; - } - return 0; - } - - /** - * @notice Swap dx amount of token i for token j - * @param i Index of token in - * @param j Index of token out - * @param dx Token in amount - * @param minDy Minimum token out - */ - function swap(uint i, uint j, uint dx, uint minDy) external returns (uint dy) { - require(i != j, "i = j"); - - IERC20(tokens[i]).transferFrom(msg.sender, address(this), dx); - - // Calculate dy - uint[N] memory xp = _xp(); - uint x = xp[i] + dx * multipliers[i]; - - uint y0 = xp[j]; - uint y1 = _getY(i, j, x, xp); - // y0 must be >= y1, since x has increased - // -1 to round down - dy = (y0 - y1 - 1) / multipliers[j]; - - // Subtract fee from dy - uint fee = (dy * SWAP_FEE) / FEE_DENOMINATOR; - dy -= fee; - require(dy >= minDy, "dy < min"); - - balances[i] += dx; - balances[j] -= dy; - - IERC20(tokens[j]).transfer(msg.sender, dy); - } - - function addLiquidity( - uint[N] calldata amounts, - uint minShares - ) external returns (uint shares) { - // calculate current liquidity d0 - uint _totalSupply = totalSupply; - uint d0; - uint[N] memory old_xs = _xp(); - if (_totalSupply > 0) { - d0 = _getD(old_xs); - } - - // Transfer tokens in - uint[N] memory new_xs; - for (uint i; i < N; ++i) { - uint amount = amounts[i]; - if (amount > 0) { - IERC20(tokens[i]).transferFrom(msg.sender, address(this), amount); - new_xs[i] = old_xs[i] + amount * multipliers[i]; - } else { - new_xs[i] = old_xs[i]; - } - } - - // Calculate new liquidity d1 - uint d1 = _getD(new_xs); - require(d1 > d0, "liquidity didn't increase"); - - // Reccalcuate D accounting for fee on imbalance - uint d2; - if (_totalSupply > 0) { - for (uint i; i < N; ++i) { - // TODO: why old_xs[i] * d1 / d0? why not d1 / N? - uint idealBalance = (old_xs[i] * d1) / d0; - uint diff = Math.abs(new_xs[i], idealBalance); - new_xs[i] -= (LIQUIDITY_FEE * diff) / FEE_DENOMINATOR; - } - - d2 = _getD(new_xs); - } else { - d2 = d1; - } - - // Update balances - for (uint i; i < N; ++i) { - balances[i] += amounts[i]; - } - - // Shares to mint = (d2 - d0) / d0 * total supply - // d1 >= d2 >= d0 - if (_totalSupply > 0) { - shares = ((d2 - d0) * _totalSupply) / d0; - } else { - shares = d2; - } - require(shares >= minShares, "shares < min"); - _mint(msg.sender, shares); - } - - function removeLiquidity( - uint shares, - uint[N] calldata minAmountsOut - ) external returns (uint[N] memory amountsOut) { - uint _totalSupply = totalSupply; - - for (uint i; i < N; ++i) { - uint amountOut = (balances[i] * shares) / _totalSupply; - require(amountOut >= minAmountsOut[i], "out < min"); - - balances[i] -= amountOut; - amountsOut[i] = amountOut; - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } - - _burn(msg.sender, shares); - } - - /** - * @notice Calculate amount of token i to receive for shares - * @param shares Shares to burn - * @param i Index of token to withdraw - * @return dy Amount of token i to receive - * fee Fee for withdraw. Fee already included in dy - */ - function _calcWithdrawOneToken( - uint shares, - uint i - ) private view returns (uint dy, uint fee) { - uint _totalSupply = totalSupply; - uint[N] memory xp = _xp(); - - // Calculate d0 and d1 - uint d0 = _getD(xp); - uint d1 = d0 - (d0 * shares) / _totalSupply; - - // Calculate reduction in y if D = d1 - uint y0 = _getYD(i, xp, d1); - // d1 <= d0 so y must be <= xp[i] - uint dy0 = (xp[i] - y0) / multipliers[i]; - - // Calculate imbalance fee, update xp with fees - uint dx; - for (uint j; j < N; ++j) { - if (j == i) { - dx = (xp[j] * d1) / d0 - y0; - } else { - // d1 / d0 <= 1 - dx = xp[j] - (xp[j] * d1) / d0; - } - xp[j] -= (LIQUIDITY_FEE * dx) / FEE_DENOMINATOR; - } - - // Recalculate y with xp including imbalance fees - uint y1 = _getYD(i, xp, d1); - // - 1 to round down - dy = (xp[i] - y1 - 1) / multipliers[i]; - fee = dy0 - dy; - } - - function calcWithdrawOneToken( - uint shares, - uint i - ) external view returns (uint dy, uint fee) { - return _calcWithdrawOneToken(shares, i); - } - - /** - * @notice Withdraw liquidity in token i - * @param shares Shares to burn - * @param i Token to withdraw - * @param minAmountOut Minimum amount of token i that must be withdrawn - */ - function removeLiquidityOneToken( - uint shares, - uint i, - uint minAmountOut - ) external returns (uint amountOut) { - (amountOut, ) = _calcWithdrawOneToken(shares, i); - require(amountOut >= minAmountOut, "out < min"); - - balances[i] -= amountOut; - _burn(msg.sender, shares); - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - bug(); - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint amount); - event Approval(address indexed owner, address indexed spender, uint amount); -} diff --git a/tests/evm/sbe@stable-swap-amm-swap/StableSwap.sol b/tests/evm/sbe@stable-swap-amm-swap/StableSwap.sol deleted file mode 100644 index e4d674a7e..000000000 --- a/tests/evm/sbe@stable-swap-amm-swap/StableSwap.sol +++ /dev/null @@ -1,444 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8; -import "../../../solidity_utils/lib.sol"; - -/* -Invariant - price of trade and amount of liquidity are determined by this equation - -An^n sum(x_i) + D = ADn^n + D^(n + 1) / (n^n prod(x_i)) - -Topics -0. Newton's method x_(n + 1) = x_n - f(x_n) / f'(x_n) -1. Invariant -2. Swap - - Calculate Y - - Calculate D -3. Get virtual price -4. Add liquidity - - Imbalance fee -5. Remove liquidity -6. Remove liquidity one token - - Calculate withdraw one token - - getYD -TODO: test? -*/ - -library Math { - function abs(uint x, uint y) internal pure returns (uint) { - return x >= y ? x - y : y - x; - } -} - -contract StableSwap { - // Number of tokens - uint private constant N = 3; - // Amplification coefficient multiplied by N^(N - 1) - // Higher value makes the curve more flat - // Lower value makes the curve more like constant product AMM - uint private constant A = 1000 * (N ** (N - 1)); - // 0.03% - uint private constant SWAP_FEE = 300; - // Liquidity fee is derived from 2 constraints - // 1. Fee is 0 for adding / removing liquidity that results in a balanced pool - // 2. Swapping in a balanced pool is like adding and then removing liquidity - // from a balanced pool - // swap fee = add liquidity fee + remove liquidity fee - uint private constant LIQUIDITY_FEE = (SWAP_FEE * N) / (4 * (N - 1)); - uint private constant FEE_DENOMINATOR = 1e6; - - address[N] public tokens; - // Normalize each token to 18 decimals - // Example - DAI (18 decimals), USDC (6 decimals), USDT (6 decimals) - uint[N] private multipliers = [1, 1e12, 1e12]; - uint[N] public balances; - - // 1 share = 1e18, 18 decimals - uint private constant DECIMALS = 18; - uint public totalSupply; - mapping(address => uint) public balanceOf; - - function _mint(address _to, uint _amount) private { - balanceOf[_to] += _amount; - totalSupply += _amount; - } - - function _burn(address _from, uint _amount) private { - balanceOf[_from] -= _amount; - totalSupply -= _amount; - } - - // Return precision-adjusted balances, adjusted to 18 decimals - function _xp() private view returns (uint[N] memory xp) { - for (uint i; i < N; ++i) { - xp[i] = balances[i] * multipliers[i]; - } - } - - /** - * @notice Calculate D, sum of balances in a perfectly balanced pool - * If balances of x_0, x_1, ... x_(n-1) then sum(x_i) = D - * @param xp Precision-adjusted balances - * @return D - */ - function _getD(uint[N] memory xp) private pure returns (uint) { - /* - Newton's method to compute D - ----------------------------- - f(D) = ADn^n + D^(n + 1) / (n^n prod(x_i)) - An^n sum(x_i) - D - f'(D) = An^n + (n + 1) D^n / (n^n prod(x_i)) - 1 - - (as + np)D_n - D_(n+1) = ----------------------- - (a - 1)D_n + (n + 1)p - - a = An^n - s = sum(x_i) - p = (D_n)^(n + 1) / (n^n prod(x_i)) - */ - uint a = A * N; // An^n - - uint s; // x_0 + x_1 + ... + x_(n-1) - for (uint i; i < N; ++i) { - s += xp[i]; - } - - // Newton's method - // Initial guess, d <= s - uint d = s; - uint d_prev; - for (uint i; i < 255; ++i) { - // p = D^(n + 1) / (n^n * x_0 * ... * x_(n-1)) - uint p = d; - for (uint j; j < N; ++j) { - p = (p * d) / (N * xp[j]); - } - d_prev = d; - d = ((a * s + N * p) * d) / ((a - 1) * d + (N + 1) * p); - - if (Math.abs(d, d_prev) <= 1) { - return d; - } - } - revert("D didn't converge"); - } - - /** - * @notice Calculate the new balance of token j given the new balance of token i - * @param i Index of token in - * @param j Index of token out - * @param x New balance of token i - * @param xp Current precision-adjusted balances - */ - function _getY( - uint i, - uint j, - uint x, - uint[N] memory xp - ) private pure returns (uint) { - /* - Newton's method to compute y - ----------------------------- - y = x_j - - f(y) = y^2 + y(b - D) - c - - y_n^2 + c - y_(n+1) = -------------- - 2y_n + b - D - - where - s = sum(x_k), k != j - p = prod(x_k), k != j - b = s + D / (An^n) - c = D^(n + 1) / (n^n * p * An^n) - */ - uint a = A * N; - uint d = _getD(xp); - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k == i) { - _x = x; - } else if (k == j) { - continue; - } else { - _x = xp[k]; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - /** - * @notice Calculate the new balance of token i given precision-adjusted - * balances xp and liquidity d - * @dev Equation is calculate y is same as _getY - * @param i Index of token to calculate the new balance - * @param xp Precision-adjusted balances - * @param d Liquidity d - * @return New balance of token i - */ - function _getYD(uint i, uint[N] memory xp, uint d) private pure returns (uint) { - uint a = A * N; - uint s; - uint c = d; - - uint _x; - for (uint k; k < N; ++k) { - if (k != i) { - _x = xp[k]; - } else { - continue; - } - - s += _x; - c = (c * d) / (N * _x); - } - c = (c * d) / (N * a); - uint b = s + d / a; - - // Newton's method - uint y_prev; - // Initial guess, y <= d - uint y = d; - for (uint _i; _i < 255; ++_i) { - y_prev = y; - y = (y * y + c) / (2 * y + b - d); - if (Math.abs(y, y_prev) <= 1) { - return y; - } - } - revert("y didn't converge"); - } - - // Estimate value of 1 share - // How many tokens is one share worth? - function getVirtualPrice() external view returns (uint) { - uint d = _getD(_xp()); - uint _totalSupply = totalSupply; - if (_totalSupply > 0) { - return (d * 10 ** DECIMALS) / _totalSupply; - } - return 0; - } - - /** - * @notice Swap dx amount of token i for token j - * @param i Index of token in - * @param j Index of token out - * @param dx Token in amount - * @param minDy Minimum token out - */ - function swap(uint i, uint j, uint dx, uint minDy) external returns (uint dy) { - require(i != j, "i = j"); - - IERC20(tokens[i]).transferFrom(msg.sender, address(this), dx); - - // Calculate dy - uint[N] memory xp = _xp(); - uint x = xp[i] + dx * multipliers[i]; - - uint y0 = xp[j]; - uint y1 = _getY(i, j, x, xp); - // y0 must be >= y1, since x has increased - // -1 to round down - dy = (y0 - y1 - 1) / multipliers[j]; - - // Subtract fee from dy - uint fee = (dy * SWAP_FEE) / FEE_DENOMINATOR; - dy -= fee; - require(dy >= minDy, "dy < min"); - - balances[i] += dx; - balances[j] -= dy; - - IERC20(tokens[j]).transfer(msg.sender, dy); - bug(); - } - - function addLiquidity( - uint[N] calldata amounts, - uint minShares - ) external returns (uint shares) { - // calculate current liquidity d0 - uint _totalSupply = totalSupply; - uint d0; - uint[N] memory old_xs = _xp(); - if (_totalSupply > 0) { - d0 = _getD(old_xs); - } - - // Transfer tokens in - uint[N] memory new_xs; - for (uint i; i < N; ++i) { - uint amount = amounts[i]; - if (amount > 0) { - IERC20(tokens[i]).transferFrom(msg.sender, address(this), amount); - new_xs[i] = old_xs[i] + amount * multipliers[i]; - } else { - new_xs[i] = old_xs[i]; - } - } - - // Calculate new liquidity d1 - uint d1 = _getD(new_xs); - require(d1 > d0, "liquidity didn't increase"); - - // Reccalcuate D accounting for fee on imbalance - uint d2; - if (_totalSupply > 0) { - for (uint i; i < N; ++i) { - // TODO: why old_xs[i] * d1 / d0? why not d1 / N? - uint idealBalance = (old_xs[i] * d1) / d0; - uint diff = Math.abs(new_xs[i], idealBalance); - new_xs[i] -= (LIQUIDITY_FEE * diff) / FEE_DENOMINATOR; - } - - d2 = _getD(new_xs); - } else { - d2 = d1; - } - - // Update balances - for (uint i; i < N; ++i) { - balances[i] += amounts[i]; - } - - // Shares to mint = (d2 - d0) / d0 * total supply - // d1 >= d2 >= d0 - if (_totalSupply > 0) { - shares = ((d2 - d0) * _totalSupply) / d0; - } else { - shares = d2; - } - require(shares >= minShares, "shares < min"); - _mint(msg.sender, shares); - } - - function removeLiquidity( - uint shares, - uint[N] calldata minAmountsOut - ) external returns (uint[N] memory amountsOut) { - uint _totalSupply = totalSupply; - - for (uint i; i < N; ++i) { - uint amountOut = (balances[i] * shares) / _totalSupply; - require(amountOut >= minAmountsOut[i], "out < min"); - - balances[i] -= amountOut; - amountsOut[i] = amountOut; - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } - - _burn(msg.sender, shares); - } - - /** - * @notice Calculate amount of token i to receive for shares - * @param shares Shares to burn - * @param i Index of token to withdraw - * @return dy Amount of token i to receive - * fee Fee for withdraw. Fee already included in dy - */ - function _calcWithdrawOneToken( - uint shares, - uint i - ) private view returns (uint dy, uint fee) { - uint _totalSupply = totalSupply; - uint[N] memory xp = _xp(); - - // Calculate d0 and d1 - uint d0 = _getD(xp); - uint d1 = d0 - (d0 * shares) / _totalSupply; - - // Calculate reduction in y if D = d1 - uint y0 = _getYD(i, xp, d1); - // d1 <= d0 so y must be <= xp[i] - uint dy0 = (xp[i] - y0) / multipliers[i]; - - // Calculate imbalance fee, update xp with fees - uint dx; - for (uint j; j < N; ++j) { - if (j == i) { - dx = (xp[j] * d1) / d0 - y0; - } else { - // d1 / d0 <= 1 - dx = xp[j] - (xp[j] * d1) / d0; - } - xp[j] -= (LIQUIDITY_FEE * dx) / FEE_DENOMINATOR; - } - - // Recalculate y with xp including imbalance fees - uint y1 = _getYD(i, xp, d1); - // - 1 to round down - dy = (xp[i] - y1 - 1) / multipliers[i]; - fee = dy0 - dy; - } - - function calcWithdrawOneToken( - uint shares, - uint i - ) external view returns (uint dy, uint fee) { - return _calcWithdrawOneToken(shares, i); - } - - /** - * @notice Withdraw liquidity in token i - * @param shares Shares to burn - * @param i Token to withdraw - * @param minAmountOut Minimum amount of token i that must be withdrawn - */ - function removeLiquidityOneToken( - uint shares, - uint i, - uint minAmountOut - ) external returns (uint amountOut) { - (amountOut, ) = _calcWithdrawOneToken(shares, i); - require(amountOut >= minAmountOut, "out < min"); - - balances[i] -= amountOut; - _burn(msg.sender, shares); - - IERC20(tokens[i]).transfer(msg.sender, amountOut); - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint amount); - event Approval(address indexed owner, address indexed spender, uint amount); -} diff --git a/tests/evm/sbe@staking-rewards-stake/StakingRewards.sol b/tests/evm/sbe@staking-rewards-stake/StakingRewards.sol deleted file mode 100644 index 6f55beacf..000000000 --- a/tests/evm/sbe@staking-rewards-stake/StakingRewards.sol +++ /dev/null @@ -1,148 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8; -import "../../../solidity_utils/lib.sol"; - -contract StakingRewards { - IERC20 public immutable stakingToken; - IERC20 public immutable rewardsToken; - - address public owner; - - // Duration of rewards to be paid out (in seconds) - uint public duration; - // Timestamp of when the rewards finish - uint public finishAt; - // Minimum of last updated time and reward finish time - uint public updatedAt; - // Reward to be paid out per second - uint public rewardRate; - // Sum of (reward rate * dt * 1e18 / total supply) - uint public rewardPerTokenStored; - // User address => rewardPerTokenStored - mapping(address => uint) public userRewardPerTokenPaid; - // User address => rewards to be claimed - mapping(address => uint) public rewards; - - // Total staked - uint public totalSupply; - // User address => staked amount - mapping(address => uint) public balanceOf; - - constructor(address _stakingToken, address _rewardToken) { - owner = msg.sender; - stakingToken = IERC20(_stakingToken); - rewardsToken = IERC20(_rewardToken); - } - - modifier onlyOwner() { - require(msg.sender == owner, "not authorized"); - _; - } - - modifier updateReward(address _account) { - rewardPerTokenStored = rewardPerToken(); - updatedAt = lastTimeRewardApplicable(); - - if (_account != address(0)) { - rewards[_account] = earned(_account); - userRewardPerTokenPaid[_account] = rewardPerTokenStored; - } - - _; - } - - function lastTimeRewardApplicable() public view returns (uint) { - return _min(finishAt, block.timestamp); - } - - function rewardPerToken() public view returns (uint) { - if (totalSupply == 0) { - return rewardPerTokenStored; - } - - return - rewardPerTokenStored + - (rewardRate * (lastTimeRewardApplicable() - updatedAt) * 1e18) / - totalSupply; - } - - function stake(uint _amount) external updateReward(msg.sender) { - require(_amount > 0, "amount = 0"); - stakingToken.transferFrom(msg.sender, address(this), _amount); - balanceOf[msg.sender] += _amount; - totalSupply += _amount; - bug(); - } - - function withdraw(uint _amount) external updateReward(msg.sender) { - require(_amount > 0, "amount = 0"); - balanceOf[msg.sender] -= _amount; - totalSupply -= _amount; - stakingToken.transfer(msg.sender, _amount); - } - - function earned(address _account) public view returns (uint) { - return - ((balanceOf[_account] * - (rewardPerToken() - userRewardPerTokenPaid[_account])) / 1e18) + - rewards[_account]; - } - - function getReward() external updateReward(msg.sender) { - uint reward = rewards[msg.sender]; - if (reward > 0) { - rewards[msg.sender] = 0; - rewardsToken.transfer(msg.sender, reward); - } - } - - function setRewardsDuration(uint _duration) external onlyOwner { - require(finishAt < block.timestamp, "reward duration not finished"); - duration = _duration; - } - - function notifyRewardAmount( - uint _amount - ) external onlyOwner updateReward(address(0)) { - if (block.timestamp >= finishAt) { - rewardRate = _amount / duration; - } else { - uint remainingRewards = (finishAt - block.timestamp) * rewardRate; - rewardRate = (_amount + remainingRewards) / duration; - } - - require(rewardRate > 0, "reward rate = 0"); - require( - rewardRate * duration <= rewardsToken.balanceOf(address(this)), - "reward amount > balance" - ); - - finishAt = block.timestamp + duration; - updatedAt = block.timestamp; - } - - function _min(uint x, uint y) private pure returns (uint) { - return x <= y ? x : y; - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint value); - event Approval(address indexed owner, address indexed spender, uint value); -} diff --git a/tests/evm/sbe@vault-_burn/Vault.sol b/tests/evm/sbe@vault-_burn/Vault.sol deleted file mode 100644 index dcb26023d..000000000 --- a/tests/evm/sbe@vault-_burn/Vault.sol +++ /dev/null @@ -1,84 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.13; -import "../../../solidity_utils/lib.sol"; - -contract Vault { - IERC20 public immutable token; - - uint public totalSupply; - mapping(address => uint) public balanceOf; - - constructor(address _token) { - token = IERC20(_token); - } - - function _mint(address _to, uint _shares) private { - totalSupply += _shares; - balanceOf[_to] += _shares; - } - - function _burn(address _from, uint _shares) private { - totalSupply -= _shares; - balanceOf[_from] -= _shares; - bug(); - } - - function deposit(uint _amount) external { - /* - a = amount - B = balance of token before deposit - T = total supply - s = shares to mint - - (T + s) / T = (a + B) / B - - s = aT / B - */ - uint shares; - if (totalSupply == 0) { - shares = _amount; - } else { - shares = (_amount * totalSupply) / token.balanceOf(address(this)); - } - - _mint(msg.sender, shares); - token.transferFrom(msg.sender, address(this), _amount); - } - - function withdraw(uint _shares) external { - /* - a = amount - B = balance of token before withdraw - T = total supply - s = shares to burn - - (T - s) / T = (B - a) / B - - a = sB / T - */ - uint amount = (_shares * token.balanceOf(address(this))) / totalSupply; - _burn(msg.sender, _shares); - token.transfer(msg.sender, amount); - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint amount); - event Approval(address indexed owner, address indexed spender, uint amount); -} diff --git a/tests/evm/sbe@vault-deposit/Vault.sol b/tests/evm/sbe@vault-deposit/Vault.sol deleted file mode 100644 index 8c0e4a6f7..000000000 --- a/tests/evm/sbe@vault-deposit/Vault.sol +++ /dev/null @@ -1,84 +0,0 @@ -// SPDX-License-Identifier: MIT -pragma solidity ^0.8.13; -import "../../../solidity_utils/lib.sol"; - -contract Vault { - IERC20 public immutable token; - - uint public totalSupply; - mapping(address => uint) public balanceOf; - - constructor(address _token) { - token = IERC20(_token); - } - - function _mint(address _to, uint _shares) private { - totalSupply += _shares; - balanceOf[_to] += _shares; - } - - function _burn(address _from, uint _shares) private { - totalSupply -= _shares; - balanceOf[_from] -= _shares; - } - - function deposit(uint _amount) external { - /* - a = amount - B = balance of token before deposit - T = total supply - s = shares to mint - - (T + s) / T = (a + B) / B - - s = aT / B - */ - uint shares; - if (totalSupply == 0) { - shares = _amount; - } else { - shares = (_amount * totalSupply) / token.balanceOf(address(this)); - bug(); - } - - _mint(msg.sender, shares); - token.transferFrom(msg.sender, address(this), _amount); - } - - function withdraw(uint _shares) external { - /* - a = amount - B = balance of token before withdraw - T = total supply - s = shares to burn - - (T - s) / T = (B - a) / B - - a = sB / T - */ - uint amount = (_shares * token.balanceOf(address(this))) / totalSupply; - _burn(msg.sender, _shares); - token.transfer(msg.sender, amount); - } -} - -interface IERC20 { - function totalSupply() external view returns (uint); - - function balanceOf(address account) external view returns (uint); - - function transfer(address recipient, uint amount) external returns (bool); - - function allowance(address owner, address spender) external view returns (uint); - - function approve(address spender, uint amount) external returns (bool); - - function transferFrom( - address sender, - address recipient, - uint amount - ) external returns (bool); - - event Transfer(address indexed from, address indexed to, uint amount); - event Approval(address indexed owner, address indexed spender, uint amount); -} diff --git a/tests/evm/out-of-memory/test.sol b/tests/evm_never_pass/out-of-memory/test.sol similarity index 100% rename from tests/evm/out-of-memory/test.sol rename to tests/evm_never_pass/out-of-memory/test.sol