diff --git a/src/logic/editable-storage-logic.ts b/src/logic/editable-storage-logic.ts index e21f202..6922078 100644 --- a/src/logic/editable-storage-logic.ts +++ b/src/logic/editable-storage-logic.ts @@ -1,5 +1,5 @@ import { SetVariablesType, SmockVMManager } from '../types'; -import { fromHexString, toFancyAddress } from '../utils'; +import { fromHexString, remove0x, toFancyAddress, toHexString } from '../utils'; import { computeStorageSlots, SolidityStorageLayout } from '../utils/storage'; export class EditableStorageLogic { @@ -19,7 +19,14 @@ export class EditableStorageLogic { const slots = computeStorageSlots(this.storageLayout, { [variableName]: value }); for (const slot of slots) { - await this.vmManager.putContractStorage(toFancyAddress(this.contractAddress), fromHexString(slot.key), fromHexString(slot.val)); + let prevStorageValue = await this.vmManager.getContractStorage(toFancyAddress(this.contractAddress), fromHexString(slot.key)); + let stringValue = remove0x(toHexString(prevStorageValue)); + if (stringValue && (slot.type === 'address' || slot.type === 'bool' || slot.type.startsWith('contract'))) { + let lastResult = slot.val.slice(0, slot.val.length - stringValue.length).concat(stringValue); + await this.vmManager.putContractStorage(toFancyAddress(this.contractAddress), fromHexString(slot.key), fromHexString(lastResult)); + } else { + await this.vmManager.putContractStorage(toFancyAddress(this.contractAddress), fromHexString(slot.key), fromHexString(slot.val)); + } } } diff --git a/src/utils/storage.ts b/src/utils/storage.ts index 489f61e..b322035 100644 --- a/src/utils/storage.ts +++ b/src/utils/storage.ts @@ -40,6 +40,7 @@ export interface SolidityStorageLayout { interface StorageSlotPair { key: string; val: string; + type: string; } /** @@ -139,6 +140,7 @@ export function computeStorageSlots(storageLayout: SolidityStorageLayout, variab prevSlots.push({ key: slot.key, val: mergedVal, + type: slot.type, }); } @@ -228,6 +230,7 @@ function encodeVariable( { key: slotKey, val: padNumHexSlotValue(variable, storageObj.offset), + type: variableType.label, }, ]; } else if (variableType.label === 'bool') { @@ -249,6 +252,7 @@ function encodeVariable( { key: slotKey, val: padNumHexSlotValue(variable ? '1' : '0', storageObj.offset), + type: variableType.label, }, ]; } else if (variableType.label.startsWith('bytes')) { @@ -260,6 +264,7 @@ function encodeVariable( { key: slotKey, val: padBytesHexSlotValue(remove0x(variable).padEnd(parseInt(variableType.numberOfBytes, 10) * 2, '0'), storageObj.offset), + type: variableType.label, }, ]; } else if (variableType.label.startsWith('uint') || variableType.label.startsWith('int')) { @@ -271,6 +276,7 @@ function encodeVariable( { key: slotKey, val: padNumHexSlotValue(variable, storageObj.offset), + type: variableType.label, }, ]; } else if (variableType.label.startsWith('struct')) { @@ -314,6 +320,7 @@ function encodeVariable( ethers.BigNumber.from(bytes.length * 2).toHexString(), ]) ), + type: variableType.label, }, ]; } else { @@ -324,6 +331,7 @@ function encodeVariable( slots = slots.concat({ key: slotKey, val: padNumHexSlotValue(bytes.length * 2 + 1, 0), + type: variableType.label, }); // Each storage slot has 32 bytes so we make sure to slice the large bytes into 32bytes chunks @@ -335,6 +343,7 @@ function encodeVariable( slots = slots.concat({ key: key, val: ethers.utils.hexlify(ethers.utils.concat([bytes.slice(i * 32, i * 32 + 32), ethers.constants.HashZero]).slice(0, 32)), + type: variableType.label, }); } diff --git a/test/contracts/mock/StorageGetter.sol b/test/contracts/mock/StorageGetter.sol index 4e0b148..c547ec9 100644 --- a/test/contracts/mock/StorageGetter.sol +++ b/test/contracts/mock/StorageGetter.sol @@ -33,6 +33,10 @@ contract StorageGetter { bool internal _packedA; address internal _packedB; + // Testing slot-overwrite + address public _slotA; + bool public _slotB; + constructor(uint256 _inA) { _constructorUint256 = _inA; } diff --git a/test/unit/mock/editable-storage-logic.spec.ts b/test/unit/mock/editable-storage-logic.spec.ts index b6d16ae..046b876 100644 --- a/test/unit/mock/editable-storage-logic.spec.ts +++ b/test/unit/mock/editable-storage-logic.spec.ts @@ -62,6 +62,41 @@ describe('Mock: Editable storage logic', () => { expect(await mock.getAddress()).to.equal(ADDRESS_EXAMPLE); }); + it('should not be able to overwrite slot', async () => { + await mock.setVariable('_slotA', ADDRESS_EXAMPLE); + await mock.setVariable('_slotB', true); + await mock.setVariable('_bytes', BYTES_EXAMPLE); + const value = utils.parseUnits('123'); + await mock.setVariable('_uint256', value); + await mock.setVariable('_bytes32', BYTES32_EXAMPLE); + const struct = { + packedA: true, + packedB: ADDRESS_EXAMPLE, + }; + await mock.setVariable('_packedStruct', struct); + const mapKey = 1234; + const mapValue = 5678; + await mock.setVariable('_uint256Map', { [mapKey]: mapValue }); + const mapKeyA = 1234; + const mapKeyB = 4321; + const mapVal = 5678; + + await mock.setVariable('_uint256NestedMap', { + [mapKeyA]: { + [mapKeyB]: mapVal, + }, + }); + + expect(await mock._slotA()).to.equal(ADDRESS_EXAMPLE); + expect(await mock._slotB()).to.equal(true); + expect(await mock.getBytes()).to.equal(BYTES_EXAMPLE); + expect(await mock.getUint256()).to.equal(value); + expect(await mock.getBytes32()).to.equal(BYTES32_EXAMPLE); + expect(convertStructToPojo(await mock.getPackedStruct())).to.deep.equal(struct); + expect(await mock.getUint256MapValue(mapKey)).to.equal(mapValue); + expect(await mock.getNestedUint256MapValue(mapKeyA, mapKeyB)).to.equal(mapVal); + }); + it('should be able to set an address in a packed storage slot', async () => { await mock.setVariable('_packedB', ADDRESS_EXAMPLE);