diff --git a/domain/mocks/wasm_client.go b/domain/mocks/wasm_client.go new file mode 100644 index 00000000..e14eaefb --- /dev/null +++ b/domain/mocks/wasm_client.go @@ -0,0 +1,116 @@ +package mocks + +import ( + "context" + + wasmtypes "github.com/CosmWasm/wasmd/x/wasm/types" + + "google.golang.org/grpc" +) + +type WasmClient struct { + ContractInfoFunc func(ctx context.Context, in *wasmtypes.QueryContractInfoRequest, opts ...grpc.CallOption) (*wasmtypes.QueryContractInfoResponse, error) + ContractHistoryFunc func(ctx context.Context, in *wasmtypes.QueryContractHistoryRequest, opts ...grpc.CallOption) (*wasmtypes.QueryContractHistoryResponse, error) + ContractsByCodeFunc func(ctx context.Context, in *wasmtypes.QueryContractsByCodeRequest, opts ...grpc.CallOption) (*wasmtypes.QueryContractsByCodeResponse, error) + AllContractStateFunc func(ctx context.Context, in *wasmtypes.QueryAllContractStateRequest, opts ...grpc.CallOption) (*wasmtypes.QueryAllContractStateResponse, error) + RawContractStateFunc func(ctx context.Context, in *wasmtypes.QueryRawContractStateRequest, opts ...grpc.CallOption) (*wasmtypes.QueryRawContractStateResponse, error) + SmartContractStateFunc func(ctx context.Context, in *wasmtypes.QuerySmartContractStateRequest, opts ...grpc.CallOption) (*wasmtypes.QuerySmartContractStateResponse, error) + CodeFunc func(ctx context.Context, in *wasmtypes.QueryCodeRequest, opts ...grpc.CallOption) (*wasmtypes.QueryCodeResponse, error) + CodesFunc func(ctx context.Context, in *wasmtypes.QueryCodesRequest, opts ...grpc.CallOption) (*wasmtypes.QueryCodesResponse, error) + PinnedCodesFunc func(ctx context.Context, in *wasmtypes.QueryPinnedCodesRequest, opts ...grpc.CallOption) (*wasmtypes.QueryPinnedCodesResponse, error) + ParamsFunc func(ctx context.Context, in *wasmtypes.QueryParamsRequest, opts ...grpc.CallOption) (*wasmtypes.QueryParamsResponse, error) + ContractsByCreatorFunc func(ctx context.Context, in *wasmtypes.QueryContractsByCreatorRequest, opts ...grpc.CallOption) (*wasmtypes.QueryContractsByCreatorResponse, error) + BuildAddressFunc func(ctx context.Context, in *wasmtypes.QueryBuildAddressRequest, opts ...grpc.CallOption) (*wasmtypes.QueryBuildAddressResponse, error) +} + +func (m *WasmClient) ContractInfo(ctx context.Context, in *wasmtypes.QueryContractInfoRequest, opts ...grpc.CallOption) (*wasmtypes.QueryContractInfoResponse, error) { + if m.ContractInfoFunc != nil { + return m.ContractInfoFunc(ctx, in, opts...) + } + panic("MockQueryClient.ContractInfo unimplemented") +} + +func (m *WasmClient) ContractHistory(ctx context.Context, in *wasmtypes.QueryContractHistoryRequest, opts ...grpc.CallOption) (*wasmtypes.QueryContractHistoryResponse, error) { + if m.ContractHistoryFunc != nil { + return m.ContractHistoryFunc(ctx, in, opts...) + } + panic("MockQueryClient.ContractHistory unimplemented") +} + +func (m *WasmClient) ContractsByCode(ctx context.Context, in *wasmtypes.QueryContractsByCodeRequest, opts ...grpc.CallOption) (*wasmtypes.QueryContractsByCodeResponse, error) { + if m.ContractsByCodeFunc != nil { + return m.ContractsByCodeFunc(ctx, in, opts...) + } + panic("MockQueryClient.ContractsByCode unimplemented") +} + +func (m *WasmClient) AllContractState(ctx context.Context, in *wasmtypes.QueryAllContractStateRequest, opts ...grpc.CallOption) (*wasmtypes.QueryAllContractStateResponse, error) { + if m.AllContractStateFunc != nil { + return m.AllContractStateFunc(ctx, in, opts...) + } + panic("MockQueryClient.AllContractState unimplemented") +} + +func (m *WasmClient) RawContractState(ctx context.Context, in *wasmtypes.QueryRawContractStateRequest, opts ...grpc.CallOption) (*wasmtypes.QueryRawContractStateResponse, error) { + if m.RawContractStateFunc != nil { + return m.RawContractStateFunc(ctx, in, opts...) + } + panic("MockQueryClient.RawContractState unimplemented") +} + +func (m *WasmClient) SmartContractState(ctx context.Context, in *wasmtypes.QuerySmartContractStateRequest, opts ...grpc.CallOption) (*wasmtypes.QuerySmartContractStateResponse, error) { + if m.SmartContractStateFunc != nil { + return m.SmartContractStateFunc(ctx, in, opts...) + } + panic("MockQueryClient.SmartContractState unimplemented") +} + +func (m *WasmClient) WithSmartContractState(data wasmtypes.RawContractMessage, err error) { + m.SmartContractStateFunc = func(ctx context.Context, in *wasmtypes.QuerySmartContractStateRequest, opts ...grpc.CallOption) (*wasmtypes.QuerySmartContractStateResponse, error) { + return &wasmtypes.QuerySmartContractStateResponse{ + Data: data, + }, err + } +} + +func (m *WasmClient) Code(ctx context.Context, in *wasmtypes.QueryCodeRequest, opts ...grpc.CallOption) (*wasmtypes.QueryCodeResponse, error) { + if m.CodeFunc != nil { + return m.CodeFunc(ctx, in, opts...) + } + panic("MockQueryClient.Code unimplemented") +} + +func (m *WasmClient) Codes(ctx context.Context, in *wasmtypes.QueryCodesRequest, opts ...grpc.CallOption) (*wasmtypes.QueryCodesResponse, error) { + if m.CodesFunc != nil { + return m.CodesFunc(ctx, in, opts...) + } + panic("MockQueryClient.Codes unimplemented") +} + +func (m *WasmClient) PinnedCodes(ctx context.Context, in *wasmtypes.QueryPinnedCodesRequest, opts ...grpc.CallOption) (*wasmtypes.QueryPinnedCodesResponse, error) { + if m.PinnedCodesFunc != nil { + return m.PinnedCodesFunc(ctx, in, opts...) + } + panic("MockQueryClient.PinnedCodes unimplemented") +} + +func (m *WasmClient) Params(ctx context.Context, in *wasmtypes.QueryParamsRequest, opts ...grpc.CallOption) (*wasmtypes.QueryParamsResponse, error) { + if m.ParamsFunc != nil { + return m.ParamsFunc(ctx, in, opts...) + } + panic("MockQueryClient.Params unimplemented") +} + +func (m *WasmClient) ContractsByCreator(ctx context.Context, in *wasmtypes.QueryContractsByCreatorRequest, opts ...grpc.CallOption) (*wasmtypes.QueryContractsByCreatorResponse, error) { + if m.ContractsByCreatorFunc != nil { + return m.ContractsByCreatorFunc(ctx, in, opts...) + } + panic("MockQueryClient.ContractsByCreator unimplemented") +} + +func (m *WasmClient) BuildAddress(ctx context.Context, in *wasmtypes.QueryBuildAddressRequest, opts ...grpc.CallOption) (*wasmtypes.QueryBuildAddressResponse, error) { + if m.BuildAddressFunc != nil { + return m.BuildAddressFunc(ctx, in, opts...) + } + panic("MockQueryClient.BuildAddress unimplemented") +} diff --git a/router/usecase/pools/routable_concentrated_pool.go b/router/usecase/pools/routable_concentrated_pool.go index a9543184..cc867c38 100644 --- a/router/usecase/pools/routable_concentrated_pool.go +++ b/router/usecase/pools/routable_concentrated_pool.go @@ -339,9 +339,9 @@ func (r *routableConcentratedPoolImpl) ChargeTakerFeeExactIn(tokenIn sdk.Coin) ( // ChargeTakerFeeExactOut implements domain.RoutablePool. // Charges the taker fee for the given token out and returns the token out after the fee has been charged. -func (r *routableConcentratedPoolImpl) ChargeTakerFeeExactOut(tokenOut sdk.Coin) (tokenOutAfterFee sdk.Coin) { - tokenOutAfterTakerFee, _ := poolmanager.CalcTakerFeeExactOut(tokenOut, r.GetTakerFee()) - return tokenOutAfterTakerFee +func (r *routableConcentratedPoolImpl) ChargeTakerFeeExactOut(tokenIn sdk.Coin) (tokenOutAfterFee sdk.Coin) { + tokenInAfterTakerFee, _ := poolmanager.CalcTakerFeeExactOut(tokenIn, r.GetTakerFee()) + return tokenInAfterTakerFee } // SetTokenInDenom implements domain.RoutablePool. diff --git a/router/usecase/pools/routable_cw_pool.go b/router/usecase/pools/routable_cw_pool.go index 19a0eac3..7b99a417 100644 --- a/router/usecase/pools/routable_cw_pool.go +++ b/router/usecase/pools/routable_cw_pool.go @@ -2,7 +2,6 @@ package pools import ( "context" - "errors" "fmt" "cosmossdk.io/math" @@ -86,8 +85,35 @@ func (r *routableCosmWasmPoolImpl) GetSpreadFactor() math.LegacyDec { } // CalculateTokenInByTokenOut implements domain.RoutablePool. +// It calculates the amount of token in given the amount of token out for a transmuter pool. +// Transmuter pool allows no slippage swaps. It just returns the same amount of token in as token out +// Returns error if: +// - the underlying chain pool set on the routable pool is not of transmuter type +// - the token out amount is greater than the balance of the token out +// - the token out amount is greater than the balance of the token in func (r *routableCosmWasmPoolImpl) CalculateTokenInByTokenOut(ctx context.Context, tokenOut sdk.Coin) (sdk.Coin, error) { - return sdk.Coin{}, errors.New("not implemented") + return r.calculateTokenInByTokenOut(ctx, tokenOut, r.TokenInDenom) +} + +func (r *routableCosmWasmPoolImpl) calculateTokenInByTokenOut(ctx context.Context, tokenOut sdk.Coin, tokenInDenom string) (sdk.Coin, error) { + poolType := r.GetType() + + // Ensure that the pool is cosmwasm + if poolType != poolmanagertypes.CosmWasm { + return sdk.Coin{}, domain.InvalidPoolTypeError{PoolType: int32(poolType)} + } + + // Configure the calc query message + calcMessage := msg.NewCalcInAmtGivenOutRequest(tokenInDenom, tokenOut, r.SpreadFactor) + + calcInAmtGivenOutResponse := msg.CalcInAmtGivenOutResponse{} + if err := cosmwasmdomain.QueryCosmwasmContract(ctx, r.wasmClient, r.ChainPool.ContractAddress, &calcMessage, &calcInAmtGivenOutResponse); err != nil { + return sdk.Coin{}, err + } + + // No slippage swaps - just return the same amount of token out as token in + // as long as there is enough liquidity in the pool. + return calcInAmtGivenOutResponse.TokenIn, nil } // CalculateTokenOutByTokenIn implements domain.RoutablePool. @@ -155,9 +181,10 @@ func (r *routableCosmWasmPoolImpl) ChargeTakerFeeExactIn(tokenIn sdk.Coin) (inAm } // ChargeTakerFeeExactOut implements domain.RoutablePool. -// Returns tokenOutAmount and does not charge any fee for transmuter pools. -func (r *routableCosmWasmPoolImpl) ChargeTakerFeeExactOut(tokenOut sdk.Coin) (outAmountAfterFee sdk.Coin) { - return sdk.Coin{} +// Returns tokenInAmount and does not charge any fee for transmuter pools. +func (r *routableCosmWasmPoolImpl) ChargeTakerFeeExactOut(tokenIn sdk.Coin) (inAmountAfterFee sdk.Coin) { + tokenInAfterTakerFee, _ := poolmanager.CalcTakerFeeExactOut(tokenIn, r.GetTakerFee()) + return tokenInAfterTakerFee } // GetTakerFee implements domain.RoutablePool. diff --git a/router/usecase/pools/routable_cw_pool_test.go b/router/usecase/pools/routable_cw_pool_test.go new file mode 100644 index 00000000..3d732760 --- /dev/null +++ b/router/usecase/pools/routable_cw_pool_test.go @@ -0,0 +1,269 @@ +package pools_test + +import ( + "context" + "fmt" + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/osmosis-labs/sqs/domain" + + cosmwasmdomain "github.com/osmosis-labs/sqs/domain/cosmwasm" + "github.com/osmosis-labs/sqs/domain/mocks" + "github.com/osmosis-labs/sqs/router/usecase/pools" + + "github.com/osmosis-labs/osmosis/osmomath" + poolmanagertypes "github.com/osmosis-labs/osmosis/v28/x/poolmanager/types" + + "github.com/osmosis-labs/osmosis/v28/app/apptesting" + + "github.com/stretchr/testify/suite" +) + +type CosmWasmPoolSuite struct { + apptesting.KeeperTestHelper +} + +func TestCosmWasmPoolSuite(t *testing.T) { + suite.Run(t, new(CosmWasmPoolSuite)) +} + +func (s *CosmWasmPoolSuite) SetupTest() { + s.Setup() +} + +func (s *CosmWasmPoolSuite) newPool(method domain.TokenSwapMethod, coin sdk.Coin, denom string, isInvalidPoolType bool, takerFee osmomath.Dec, err error) domain.RoutablePool { + cosmwasmPool := s.PrepareCustomTransmuterPoolCustomProject(s.TestAccs[0], []string{coin.Denom, denom}, "sqs", "scripts") + + mock := &mocks.MockRoutablePool{ChainPoolModel: cosmwasmPool.AsSerializablePool(), PoolType: poolmanagertypes.CosmWasm} + wasmclient := &mocks.WasmClient{} + + token := "token_out" + if method == domain.TokenSwapMethodExactOut { + token = "token_in" + } + wasmclient.WithSmartContractState( + []byte(fmt.Sprintf(`{ "%s": { "denom" : "%s", "amount" : "%s" } }`, token, ETH, coin.Amount.String())), + err, + ) + + cosmWasmPoolsParams := cosmwasmdomain.CosmWasmPoolsParams{ + Config: domain.CosmWasmPoolRouterConfig{ + GeneralCosmWasmCodeIDs: map[uint64]struct{}{ + cosmwasmPool.GetCodeId(): {}, + }, + }, + WasmClient: wasmclient, + ScalingFactorGetterCb: domain.UnsetScalingFactorGetterCb, + } + + routablePool, err := pools.NewRoutablePool(mock, coin.Denom, denom, takerFee, cosmWasmPoolsParams) + s.Require().NoError(err) + + // Overwrite pool type for edge case testing + if isInvalidPoolType { + mock.PoolType = poolmanagertypes.Concentrated + } + + return routablePool +} + +func (s *CosmWasmPoolSuite) TestCalculateTokenOutByTokenIn() { + defaultAmount := DefaultAmt0 + defaultBalances := sdk.NewCoins(sdk.NewCoin(USDC, defaultAmount), sdk.NewCoin(ETH, defaultAmount)) + + tests := map[string]struct { + tokenIn sdk.Coin + tokenOutDenom string + balances sdk.Coins + isInvalidPoolType bool + expectError error + }{ + "valid CosmWasm quote": { + tokenIn: sdk.NewCoin(USDC, defaultAmount), + tokenOutDenom: ETH, + balances: defaultBalances, + }, + "no error: token in is larger than balance of token in": { + tokenIn: sdk.NewCoin(USDC, defaultAmount), + tokenOutDenom: ETH, + // Make token in amount 1 smaller than the default amount + balances: sdk.NewCoins(sdk.NewCoin(USDC, defaultAmount.Sub(osmomath.OneInt())), sdk.NewCoin(ETH, defaultAmount)), + }, + "error: token in is larger than balance of token out": { + tokenIn: sdk.NewCoin(USDC, defaultAmount), + tokenOutDenom: ETH, + + // Make token out amount 1 smaller than the default amount + balances: sdk.NewCoins(sdk.NewCoin(USDC, defaultAmount), sdk.NewCoin(ETH, defaultAmount.Sub(osmomath.OneInt()))), + + expectError: domain.TransmuterInsufficientBalanceError{ + Denom: ETH, + BalanceAmount: defaultAmount.Sub(osmomath.OneInt()).String(), + Amount: defaultAmount.String(), + }, + }, + } + + for name, tc := range tests { + s.Run(name, func() { + s.Setup() + + routablePool := s.newPool(domain.TokenSwapMethodExactIn, tc.tokenIn, tc.tokenOutDenom, tc.isInvalidPoolType, noTakerFee, tc.expectError) + + tokenOut, err := routablePool.CalculateTokenOutByTokenIn(context.TODO(), tc.tokenIn) + + if tc.expectError != nil { + s.Require().Error(err) + s.Require().ErrorIs(err, tc.expectError) + return + } + s.Require().NoError(err) + + // No slippage swaps on success + s.Require().Equal(tc.tokenIn.Amount, tokenOut.Amount) + }) + } +} + +func (s *CosmWasmPoolSuite) TestChargeTakerFeeExactIn() { + defaultAmount := DefaultAmt0 + defaultBalances := sdk.NewCoins(sdk.NewCoin(USDC, defaultAmount), sdk.NewCoin(ETH, defaultAmount)) + + tests := map[string]struct { + poolType poolmanagertypes.PoolType + tokenIn sdk.Coin + takerFee osmomath.Dec + balances sdk.Coins + expectedToken sdk.Coin + }{ + "no taker fee": { + tokenIn: sdk.NewCoin(USDC, osmomath.NewInt(100)), + balances: defaultBalances, + takerFee: osmomath.NewDec(0), + expectedToken: sdk.NewCoin(USDC, osmomath.NewInt(100)), + }, + "small taker fee": { + tokenIn: sdk.NewCoin(USDT, osmomath.NewInt(100)), + takerFee: osmomath.NewDecWithPrec(1, 2), // 1% + expectedToken: sdk.NewCoin(USDT, osmomath.NewInt(99)), // 100 - 1 = 99 + }, + "large taker fee": { + tokenIn: sdk.NewCoin(USDC, osmomath.NewInt(100)), + takerFee: osmomath.NewDecWithPrec(5, 1), // 50% + expectedToken: sdk.NewCoin(USDC, osmomath.NewInt(50)), // 100 - 50 = 50 + }, + } + + for name, tc := range tests { + s.Run(name, func() { + s.Setup() + + routablePool := s.newPool(domain.TokenSwapMethodExactIn, tc.tokenIn, "", false, tc.takerFee, nil) + + tokenAfterFee := routablePool.ChargeTakerFeeExactIn(tc.tokenIn) + + s.Require().Equal(tc.expectedToken, tokenAfterFee) + }) + } +} + +func (s *CosmWasmPoolSuite) TestCalculateTokenInByTokenOut() { + defaultAmount := DefaultAmt0 + defaultBalances := sdk.NewCoins(sdk.NewCoin(USDC, defaultAmount), sdk.NewCoin(ETH, defaultAmount)) + + tests := map[string]struct { + tokenOut sdk.Coin + tokenInDenom string + balances sdk.Coins + isInvalidPoolType bool + expectError error + }{ + "valid CosmWasm quote": { + tokenOut: sdk.NewCoin(USDC, defaultAmount), + tokenInDenom: ETH, + balances: defaultBalances, + }, + "no error: token in is larger than balance of token in": { + tokenOut: sdk.NewCoin(USDC, defaultAmount), + tokenInDenom: ETH, + // Make token in amount 1 smaller than the default amount + balances: sdk.NewCoins(sdk.NewCoin(USDC, defaultAmount.Sub(osmomath.OneInt())), sdk.NewCoin(ETH, defaultAmount)), + }, + "error: token in is larger than balance of token out": { + tokenOut: sdk.NewCoin(USDC, defaultAmount), + tokenInDenom: ETH, + + // Make token out amount 1 smaller than the default amount + balances: sdk.NewCoins(sdk.NewCoin(USDC, defaultAmount), sdk.NewCoin(ETH, defaultAmount.Sub(osmomath.OneInt()))), + + expectError: domain.TransmuterInsufficientBalanceError{ + Denom: ETH, + BalanceAmount: defaultAmount.Sub(osmomath.OneInt()).String(), + Amount: defaultAmount.String(), + }, + }, + } + + for name, tc := range tests { + s.Run(name, func() { + s.Setup() + + routablePool := s.newPool(domain.TokenSwapMethodExactOut, tc.tokenOut, tc.tokenInDenom, tc.isInvalidPoolType, noTakerFee, tc.expectError) + + tokenIn, err := routablePool.CalculateTokenInByTokenOut(context.TODO(), tc.tokenOut) + + if tc.expectError != nil { + s.Require().Error(err) + s.Require().ErrorIs(err, tc.expectError) + return + } + s.Require().NoError(err) + + // No slippage swaps on success + s.Require().Equal(tc.tokenOut.Amount, tokenIn.Amount) + }) + } +} + +func (s *CosmWasmPoolSuite) TestChargeTakerFeeExactOut() { + defaultAmount := DefaultAmt0 + defaultBalances := sdk.NewCoins(sdk.NewCoin(USDC, defaultAmount), sdk.NewCoin(ETH, defaultAmount)) + + tests := map[string]struct { + poolType poolmanagertypes.PoolType + tokenIn sdk.Coin + takerFee osmomath.Dec + balances sdk.Coins + expectedToken sdk.Coin + }{ + "no taker fee": { + tokenIn: sdk.NewCoin(USDC, osmomath.NewInt(100)), + balances: defaultBalances, + takerFee: osmomath.NewDec(0), + expectedToken: sdk.NewCoin(USDC, osmomath.NewInt(100)), + }, + "small taker fee": { + tokenIn: sdk.NewCoin(USDT, osmomath.NewInt(100)), + takerFee: osmomath.NewDecWithPrec(1, 2), // 1% + expectedToken: sdk.NewCoin(USDT, osmomath.NewInt(102)), // 100 + 1 = 101.01 = 102 (round up) + }, + "large taker fee": { + tokenIn: sdk.NewCoin(USDC, osmomath.NewInt(100)), + takerFee: osmomath.NewDecWithPrec(5, 1), // 50% + expectedToken: sdk.NewCoin(USDC, osmomath.NewInt(200)), // 100 + 100 = 200 + }, + } + + for name, tc := range tests { + s.Run(name, func() { + s.Setup() + + routablePool := s.newPool(domain.TokenSwapMethodExactOut, tc.tokenIn, "", false, tc.takerFee, nil) + + tokenAfterFee := routablePool.ChargeTakerFeeExactOut(tc.tokenIn) + + s.Require().Equal(tc.expectedToken, tokenAfterFee) + }) + } +}