From 5d50cc8e2e32115267211c0ff92a2e61cb990e8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?lz=2Esir=CE=94rthurmoney=28=29?= Date: Tue, 7 Nov 2023 10:13:24 -0800 Subject: [PATCH] updating logic with fee for native --- .../token/oft/v2/fee/NativeOFTWithFee.sol | 26 +++++++++++++------ test/oft/v2/NativeOFTWithFee.test.js | 3 +-- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/contracts/token/oft/v2/fee/NativeOFTWithFee.sol b/contracts/token/oft/v2/fee/NativeOFTWithFee.sol index c38abf0c..fc5f9f1a 100644 --- a/contracts/token/oft/v2/fee/NativeOFTWithFee.sol +++ b/contracts/token/oft/v2/fee/NativeOFTWithFee.sol @@ -30,22 +30,26 @@ contract NativeOFTWithFee is OFTWithFee, ReentrancyGuard { ************************************************************************/ function sendFrom(address _from, uint16 _dstChainId, bytes32 _toAddress, uint _amount, uint _minAmount, LzCallParams calldata _callParams) public payable virtual override { _amount = _send(_from, _dstChainId, _toAddress, _amount, _callParams.refundAddress, _callParams.zroPaymentAddress, _callParams.adapterParams); - (_amount,) = _payOFTFee(address(this), _dstChainId, _amount); require(_amount >= _minAmount, "BaseOFTWithFee: amount is less than minAmount"); } function sendAndCall(address _from, uint16 _dstChainId, bytes32 _toAddress, uint _amount, uint _minAmount, bytes calldata _payload, uint64 _dstGasForCall, LzCallParams calldata _callParams) public payable virtual override { _amount = _sendAndCall(_from, _dstChainId, _toAddress, _amount, _payload, _dstGasForCall, _callParams.refundAddress, _callParams.zroPaymentAddress, _callParams.adapterParams); - (_amount,) = _payOFTFee(address(this), _dstChainId, _amount); require(_amount >= _minAmount, "BaseOFTWithFee: amount is less than minAmount"); } function _send(address _from, uint16 _dstChainId, bytes32 _toAddress, uint _amount, address payable _refundAddress, address _zroPaymentAddress, bytes memory _adapterParams) internal virtual override returns (uint amount) { _checkGasLimit(_dstChainId, PT_SEND, _adapterParams, NO_EXTRA_GAS); - (amount,) = _removeDust(_amount); - require(amount > 0, "NativeOFTWithFee: amount too small"); - uint messageFee = _debitFromNative(_from, amount); + require(_amount > 0, "NativeOFTWithFee: amount too small"); + uint messageFee = _debitFromNative(_from, _amount); + (_amount,) = _payOFTFee(address(this), _dstChainId, _amount); + + uint dust; + (amount, dust) = _removeDust(_amount); + if(dust > 0) { + _transferFrom(address(this), _from, dust); + } bytes memory lzPayload = _encodeSendPayload(_toAddress, _ld2sd(amount)); _lzSend(_dstChainId, lzPayload, _refundAddress, _zroPaymentAddress, _adapterParams, messageFee); @@ -56,9 +60,15 @@ contract NativeOFTWithFee is OFTWithFee, ReentrancyGuard { function _sendAndCall(address _from, uint16 _dstChainId, bytes32 _toAddress, uint _amount, bytes memory _payload, uint64 _dstGasForCall, address payable _refundAddress, address _zroPaymentAddress, bytes memory _adapterParams) internal virtual override returns (uint amount) { _checkGasLimit(_dstChainId, PT_SEND_AND_CALL, _adapterParams, _dstGasForCall); - (amount,) = _removeDust(_amount); - require(amount > 0, "NativeOFTWithFee: amount too small"); - uint messageFee = _debitFromNative(_from, amount); + require(_amount > 0, "NativeOFTWithFee: amount too small"); + uint messageFee = _debitFromNative(_from, _amount); + (_amount,) = _payOFTFee(address(this), _dstChainId, _amount); + + uint dust; + (amount, dust) = _removeDust(_amount); + if(dust > 0) { + _transferFrom(address(this), _from, dust); + } // encode the msg.sender into the payload instead of _from bytes memory lzPayload = _encodeSendAndCallPayload(msg.sender, _toAddress, _ld2sd(amount), _payload, _dstGasForCall); diff --git a/test/oft/v2/NativeOFTWithFee.test.js b/test/oft/v2/NativeOFTWithFee.test.js index 156c8654..107ff4b8 100644 --- a/test/oft/v2/NativeOFTWithFee.test.js +++ b/test/oft/v2/NativeOFTWithFee.test.js @@ -125,7 +125,7 @@ describe("NativeOFTWithFee: ", function () { expect(await ethers.provider.getBalance(localEndpoint.address)).to.be.equal(ethers.utils.parseEther("0")) // set default fee to 50% - await nativeOFTWithFee.setDefaultFeeBp(5000) + await nativeOFTWithFee.setDefaultFeeBp(1) await nativeOFTWithFee.setFeeOwner(bob.address) // ensure they're both allocated initial amounts @@ -155,7 +155,6 @@ describe("NativeOFTWithFee: ", function () { [owner.address, ethers.constants.AddressZero, defaultAdapterParams], { value: nativeFee.add(totalAmount) } // pass a msg.value to pay the LayerZero message fee ) - expect(await ethers.provider.getBalance(nativeOFTWithFee.address)).to.be.equal(totalAmount) expect(await ethers.provider.getBalance(localEndpoint.address)).to.be.equal(nativeFee) // collects expect(await nativeOFTWithFee.balanceOf(owner.address)).to.be.equal(leftOverAmount) expect(await nativeOFTWithFee.balanceOf(alice.address)).to.be.equal(leftOverAmount)