Skip to content

Commit

Permalink
Adding in update to NativeOFTWithFee (#131)
Browse files Browse the repository at this point in the history
Refactoring fee logic and adding in outboundAmount
  • Loading branch information
sirarthurmoney authored Dec 19, 2023
1 parent 30d6896 commit 9b1a8d5
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 17 deletions.
4 changes: 4 additions & 0 deletions contracts/token/oft/v2/NativeOFTV2.sol
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import "@openzeppelin/contracts/security/ReentrancyGuard.sol";
import "./OFTV2.sol";

contract NativeOFTV2 is OFTV2, ReentrancyGuard {
uint public outboundAmount;

event Deposit(address indexed _dst, uint _amount);
event Withdrawal(address indexed _src, uint _amount);

Expand Down Expand Up @@ -110,6 +112,7 @@ contract NativeOFTV2 is OFTV2, ReentrancyGuard {
}

function _debitFromNative(address _from, uint _amount) internal returns (uint messageFee) {
outboundAmount += _amount;
messageFee = msg.sender == _from ? _debitMsgSender(_amount) : _debitMsgFrom(_from, _amount);
}

Expand Down Expand Up @@ -165,6 +168,7 @@ contract NativeOFTV2 is OFTV2, ReentrancyGuard {
address _toAddress,
uint _amount
) internal override returns (uint) {
outboundAmount -= _amount;
_burn(address(this), _amount);
(bool success, ) = _toAddress.call{value: _amount}("");
require(success, "NativeOFTV2: failed to _creditTo");
Expand Down
65 changes: 49 additions & 16 deletions contracts/token/oft/v2/fee/NativeOFTWithFee.sol
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import "@openzeppelin/contracts/security/ReentrancyGuard.sol";
import "./OFTWithFee.sol";

contract NativeOFTWithFee is OFTWithFee, ReentrancyGuard {
uint public outboundAmount;

event Deposit(address indexed _dst, uint _amount);
event Withdrawal(address indexed _src, uint _amount);
Expand All @@ -25,12 +26,24 @@ contract NativeOFTWithFee is OFTWithFee, ReentrancyGuard {
emit Withdrawal(msg.sender, _amount);
}

/************************************************************************
* public functions
************************************************************************/
function sendFrom(address _from, uint16 _dstChainId, bytes32 _toAddress, uint _amount, uint _minAmount, LzCallParams calldata _callParams) public payable virtual override {
_amount = _send(_from, _dstChainId, _toAddress, _amount, _callParams.refundAddress, _callParams.zroPaymentAddress, _callParams.adapterParams);
require(_amount >= _minAmount, "NativeOFTWithFee: amount is less than minAmount");
}

function sendAndCall(address _from, uint16 _dstChainId, bytes32 _toAddress, uint _amount, uint _minAmount, bytes calldata _payload, uint64 _dstGasForCall, LzCallParams calldata _callParams) public payable virtual override {
_amount = _sendAndCall(_from, _dstChainId, _toAddress, _amount, _payload, _dstGasForCall, _callParams.refundAddress, _callParams.zroPaymentAddress, _callParams.adapterParams);
require(_amount >= _minAmount, "NativeOFTWithFee: amount is less than minAmount");
}

function _send(address _from, uint16 _dstChainId, bytes32 _toAddress, uint _amount, address payable _refundAddress, address _zroPaymentAddress, bytes memory _adapterParams) internal virtual override returns (uint amount) {
_checkGasLimit(_dstChainId, PT_SEND, _adapterParams, NO_EXTRA_GAS);

(amount,) = _removeDust(_amount);
require(amount > 0, "NativeOFTWithFee: amount too small");
uint messageFee = _debitFromNative(_from, amount);
uint messageFee;
(messageFee, amount) = _debitFromNative(_from, _amount, _dstChainId);

bytes memory lzPayload = _encodeSendPayload(_toAddress, _ld2sd(amount));
_lzSend(_dstChainId, lzPayload, _refundAddress, _zroPaymentAddress, _adapterParams, messageFee);
Expand All @@ -41,9 +54,8 @@ contract NativeOFTWithFee is OFTWithFee, ReentrancyGuard {
function _sendAndCall(address _from, uint16 _dstChainId, bytes32 _toAddress, uint _amount, bytes memory _payload, uint64 _dstGasForCall, address payable _refundAddress, address _zroPaymentAddress, bytes memory _adapterParams) internal virtual override returns (uint amount) {
_checkGasLimit(_dstChainId, PT_SEND_AND_CALL, _adapterParams, _dstGasForCall);

(amount,) = _removeDust(_amount);
require(amount > 0, "NativeOFTWithFee: amount too small");
uint messageFee = _debitFromNative(_from, amount);
uint messageFee;
(messageFee, amount) = _debitFromNative(_from, _amount, _dstChainId);

// encode the msg.sender into the payload instead of _from
bytes memory lzPayload = _encodeSendAndCallPayload(msg.sender, _toAddress, _ld2sd(amount), _payload, _dstGasForCall);
Expand All @@ -52,35 +64,55 @@ contract NativeOFTWithFee is OFTWithFee, ReentrancyGuard {
emit SendToChain(_dstChainId, _from, _toAddress, amount);
}

function _debitFromNative(address _from, uint _amount) internal returns (uint messageFee) {
messageFee = msg.sender == _from ? _debitMsgSender(_amount) : _debitMsgFrom(_from, _amount);
function _debitFromNative(address _from, uint _amount, uint16 _dstChainId) internal returns (uint messageFee, uint amount) {
uint fee = quoteOFTFee(_dstChainId, _amount);
uint newMsgValue = msg.value;

if(fee > 0) {
// subtract fee from _amount
_amount -= fee;

// pay fee and update newMsgValue
if(balanceOf(_from) >= fee) {
_transferFrom(_from, feeOwner, fee);
} else {
_mint(feeOwner, fee);
newMsgValue -= fee;
}
}

(amount,) = _removeDust(_amount);
require(amount > 0, "NativeOFTWithFee: amount too small");
outboundAmount += amount;
messageFee = msg.sender == _from ? _debitMsgSender(amount, newMsgValue) : _debitMsgFrom(_from, amount, newMsgValue);
}

function _debitMsgSender(uint _amount) internal returns (uint messageFee) {
function _debitMsgSender(uint _amount, uint currentMsgValue) internal returns (uint messageFee) {
uint msgSenderBalance = balanceOf(msg.sender);

if (msgSenderBalance < _amount) {
require(msgSenderBalance + msg.value >= _amount, "NativeOFTWithFee: Insufficient msg.value");
require(msgSenderBalance + currentMsgValue >= _amount, "NativeOFTWithFee: Insufficient msg.value");

// user can cover difference with additional msg.value ie. wrapping
uint mintAmount = _amount - msgSenderBalance;

_mint(address(msg.sender), mintAmount);

// update the messageFee to take out mintAmount
messageFee = msg.value - mintAmount;
messageFee = currentMsgValue - mintAmount;
} else {
messageFee = msg.value;
messageFee = currentMsgValue;
}

_transfer(msg.sender, address(this), _amount);
return messageFee;
}

function _debitMsgFrom(address _from, uint _amount) internal returns (uint messageFee) {
function _debitMsgFrom(address _from, uint _amount, uint currentMsgValue) internal returns (uint messageFee) {
uint msgFromBalance = balanceOf(_from);

if (msgFromBalance < _amount) {
require(msgFromBalance + msg.value >= _amount, "NativeOFTWithFee: Insufficient msg.value");
require(msgFromBalance + currentMsgValue >= _amount, "NativeOFTWithFee: Insufficient msg.value");

// user can cover difference with additional msg.value ie. wrapping
uint mintAmount = _amount - msgFromBalance;
Expand All @@ -93,9 +125,9 @@ contract NativeOFTWithFee is OFTWithFee, ReentrancyGuard {
_amount = msgFromBalance;

// update the messageFee to take out mintAmount
messageFee = msg.value - mintAmount;
messageFee = currentMsgValue - mintAmount;
} else {
messageFee = msg.value;
messageFee = currentMsgValue;
}

_spendAllowance(_from, msg.sender, _amount);
Expand All @@ -104,6 +136,7 @@ contract NativeOFTWithFee is OFTWithFee, ReentrancyGuard {
}

function _creditTo(uint16, address _toAddress, uint _amount) internal override returns(uint) {
outboundAmount -= _amount;
_burn(address(this), _amount);
(bool success, ) = _toAddress.call{value: _amount}("");
require(success, "NativeOFTWithFee: failed to _creditTo");
Expand Down
11 changes: 11 additions & 0 deletions test/oft/v2/NativeOFTV2.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ describe("NativeOFTV2: ", function () {
expect(await nativeOFTV2.balanceOf(nativeOFTV2.address)).to.be.equal(totalAmount)
expect(await nativeOFTV2.balanceOf(owner.address)).to.be.equal(leftOverAmount)
expect(await remoteOFTV2.balanceOf(owner.address)).to.be.equal(totalAmount)
expect(await nativeOFTV2.outboundAmount()).to.be.equal(totalAmount)
expect(await remoteOFTV2.totalSupply()).to.be.equal(totalAmount)


let ownerBalance2 = await ethers.provider.getBalance(owner.address)

Expand All @@ -106,6 +109,8 @@ describe("NativeOFTV2: ", function () {
expect(await ethers.provider.getBalance(owner.address)).to.be.equal(ownerBalance2.sub(nativeFee).sub(transFee))
expect(await nativeOFTV2.balanceOf(owner.address)).to.equal(leftOverAmount)
expect(await remoteOFTV2.balanceOf(owner.address)).to.equal(0)
expect(await remoteOFTV2.totalSupply()).to.be.equal(leftOverAmount)
expect(await nativeOFTV2.outboundAmount()).to.be.equal(leftOverAmount)
})

it("sendFrom() - with enough native", async function () {
Expand Down Expand Up @@ -142,6 +147,8 @@ describe("NativeOFTV2: ", function () {
expect(await nativeOFTV2.balanceOf(nativeOFTV2.address)).to.be.equal(totalAmountMinusDust)
expect(await nativeOFTV2.balanceOf(owner.address)).to.be.equal(leftOverAmount)
expect(await remoteOFTV2.balanceOf(owner.address)).to.be.equal(totalAmountMinusDust)
expect(await nativeOFTV2.outboundAmount()).to.be.equal(totalAmountMinusDust)
expect(await remoteOFTV2.totalSupply()).to.be.equal(totalAmountMinusDust)
})

it("sendFrom() - from != sender with addition msg.value", async function () {
Expand Down Expand Up @@ -180,6 +187,8 @@ describe("NativeOFTV2: ", function () {
expect(await nativeOFTV2.balanceOf(nativeOFTV2.address)).to.be.equal(totalAmount)
expect(await nativeOFTV2.balanceOf(owner.address)).to.be.equal(leftOverAmount)
expect(await remoteOFTV2.balanceOf(owner.address)).to.be.equal(totalAmount)
expect(await nativeOFTV2.outboundAmount()).to.be.equal(totalAmount)
expect(await remoteOFTV2.totalSupply()).to.be.equal(totalAmount)
})

it("sendFrom() - from != sender with not enough native", async function () {
Expand Down Expand Up @@ -286,6 +295,8 @@ describe("NativeOFTV2: ", function () {
// verify tokens burned on source chain and minted on destination chain
expect(await nativeOFTV2.balanceOf(nativeOFTV2.address)).to.be.equal(amount)
expect(await remoteOFTV2.balanceOf(owner.address)).to.be.equal(amount)
expect(await nativeOFTV2.outboundAmount()).to.be.equal(amount)
expect(await remoteOFTV2.totalSupply()).to.be.equal(amount)
})

it("setMinDstGas() - when type is not set on destination chain", async function () {
Expand Down
Loading

0 comments on commit 9b1a8d5

Please sign in to comment.