Skip to content

Commit

Permalink
Feat: Add ERC721Enumerable component
Browse files Browse the repository at this point in the history
  • Loading branch information
gianalarcon committed Aug 25, 2024
1 parent 958b0c8 commit 25970fe
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 55 deletions.
1 change: 0 additions & 1 deletion packages/snfoundry/contracts/src/Counter.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ pub trait ICounter<T> {
#[starknet::component]
pub mod CounterComponent {
use starknet::ContractAddress;
use starknet::get_caller_address;
use super::{ICounter};

#[storage]
Expand Down
42 changes: 42 additions & 0 deletions packages/snfoundry/contracts/src/ERC721Enumerable.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use starknet::ContractAddress;

#[starknet::interface]
pub trait IERC721Enumerable<T> {
fn token_of_owner_by_index(self: @T, owner: ContractAddress, index: u256) -> u256;
fn total_supply(self: @T) -> u256;
}

#[starknet::component]
pub mod ERC721EnumerableComponent {
use super::{IERC721Enumerable, ContractAddress};

#[storage]
struct Storage {
// Mapping from owner to list of owned token IDs
owned_tokens: LegacyMap<(ContractAddress, u256), u256>,
// Mapping from token ID to index of the owner tokens list
owned_tokens_index: LegacyMap<u256, u256>,
// Mapping with all token ids,
all_tokens: LegacyMap<u256, u256>,
// Helper to get the length of `all_tokens`
all_tokens_length: u256,
// Mapping from token id to position in the allTokens array
all_tokens_index: LegacyMap<u256, u256>
}

#[embeddable_as(ERC721EnumerableImpl)]
impl ERC721Enumerable<
TContractState, +HasComponent<TContractState>
> of IERC721Enumerable<ComponentState<TContractState>> {
fn token_of_owner_by_index(
self: @ComponentState<TContractState>, owner: ContractAddress, index: u256
) -> u256 {
// TODO: Add this check back in
//assert(index < self.erc721.balance_of(owner), 'Owner index out of bounds');
self.owned_tokens.read((owner, index))
}
fn total_supply(self: @ComponentState<TContractState>) -> u256 {
self.all_tokens_length.read()
}
}
}
82 changes: 30 additions & 52 deletions packages/snfoundry/contracts/src/YourCollectible.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,6 @@ pub trait IYourCollectible<T> {
fn mint_item(ref self: T, recipient: ContractAddress, uri: ByteArray) -> u256;
}

#[starknet::interface]
pub trait IERC721Enumerable<T> {
fn token_of_owner_by_index(self: @T, owner: ContractAddress, index: u256) -> u256;
fn total_supply(self: @T) -> u256;
}

#[starknet::interface]
pub trait IERC721Metadata<T> {
// Define custom `token_uri` function
Expand All @@ -20,6 +14,7 @@ pub trait IERC721Metadata<T> {
#[starknet::contract]
mod YourCollectible {
use contracts::Counter::CounterComponent;
use contracts::ERC721Enumerable::ERC721EnumerableComponent;
use core::num::traits::zero::Zero;
use openzeppelin::access::ownable::OwnableComponent;

Expand All @@ -34,20 +29,24 @@ mod YourCollectible {
};

use starknet::get_caller_address;
use super::{IYourCollectible, ContractAddress, IERC721Enumerable};
use super::{IYourCollectible, ContractAddress};

component!(path: ERC721Component, storage: erc721, event: ERC721Event);
component!(path: SRC5Component, storage: src5, event: SRC5Event);
component!(path: OwnableComponent, storage: ownable, event: OwnableEvent);
component!(path: CounterComponent, storage: token_id_counter, event: CounterEvent);
component!(path: ERC721ReceiverComponent, storage: erc721_receiver, event: ERC721ReceiverEvent);
component!(path: ERC721EnumerableComponent, storage: enumerable, event: EnumerableEvent);

#[abi(embed_v0)]
impl OwnableImpl = OwnableComponent::OwnableImpl<ContractState>;
#[abi(embed_v0)]
impl CounterImpl = CounterComponent::CounterImpl<ContractState>;
#[abi(embed_v0)]
impl ERC721Impl = ERC721Component::ERC721Impl<ContractState>;
#[abi(embed_v0)]
impl ERC721EnumerableImpl =
ERC721EnumerableComponent::ERC721EnumerableImpl<ContractState>;

impl ERC721InternalImpl = ERC721Component::InternalImpl<ContractState>;
impl OwnableInternalImpl = OwnableComponent::InternalImpl<ContractState>;
Expand All @@ -66,20 +65,11 @@ mod YourCollectible {
ownable: OwnableComponent::Storage,
#[substorage(v0)]
token_id_counter: CounterComponent::Storage,
#[substorage(v0)]
enumerable: ERC721EnumerableComponent::Storage,
// ERC721URIStorage variables
// Mapping for token URIs
token_uris: LegacyMap<u256, ByteArray>,
// IERC721Enumerable variables
// Mapping from owner to list of owned token IDs
owned_tokens: LegacyMap<(ContractAddress, u256), u256>,
// Mapping from token ID to index of the owner tokens list
owned_tokens_index: LegacyMap<u256, u256>,
// Mapping with all token ids,
all_tokens: LegacyMap<u256, u256>,
// Helper to get the length of `all_tokens`
all_tokens_length: u256,
// Mapping from token id to position in the allTokens array
all_tokens_index: LegacyMap<u256, u256>
}

#[event]
Expand All @@ -93,7 +83,8 @@ mod YourCollectible {
SRC5Event: SRC5Component::Event,
#[flat]
OwnableEvent: OwnableComponent::Event,
CounterEvent: CounterComponent::Event
CounterEvent: CounterComponent::Event,
EnumerableEvent: ERC721EnumerableComponent::Event,
}

#[constructor]
Expand Down Expand Up @@ -132,20 +123,6 @@ mod YourCollectible {
}
}


#[abi(embed_v0)]
impl IERC721EnumerableImpl of IERC721Enumerable<ContractState> {
fn token_of_owner_by_index(
self: @ContractState, owner: ContractAddress, index: u256
) -> u256 {
assert(index < self.erc721.balance_of(owner), 'Owner index out of bounds');
self.owned_tokens.read((owner, index))
}
fn total_supply(self: @ContractState) -> u256 {
self.all_tokens_length.read()
}
}

#[generate_trait]
impl InternalImpl of InternalTrait {
// token_uri custom implementation
Expand Down Expand Up @@ -175,47 +152,48 @@ mod YourCollectible {
) {
let mut contract_state = ERC721Component::HasComponent::get_contract_mut(ref self);
let token_id_counter = contract_state.token_id_counter.current();
let mut enumerable = contract_state.enumerable;
if (token_id == token_id_counter) { // Mint Token case: self._add_token_to_all_tokens_enumeration(first_token_id);
let length = contract_state.all_tokens_length.read();
contract_state.all_tokens_index.write(token_id, length);
contract_state.all_tokens.write(length, token_id);
let length = enumerable.all_tokens_length.read();
enumerable.all_tokens_index.write(token_id, length);
enumerable.all_tokens.write(length, token_id);
} else if (token_id < token_id_counter
+ 1) { // Transfer Token Case: self._remove_token_from_owner_enumeration(auth, first_token_id);
// To prevent a gap in from's tokens array, we store the last token in the index of the token to delete, and
// then delete the last slot (swap and pop).
let owner = self.owner_of(token_id);
let last_token_index = self.balance_of(owner) - 1;
let token_index = contract_state.owned_tokens_index.read(token_id);
let token_index = enumerable.owned_tokens_index.read(token_id);

// When the token to delete is the last token, the swap operation is unnecessary
if (token_index != last_token_index) {
let last_token_id = contract_state.owned_tokens.read((owner, last_token_index));
let last_token_id = enumerable.owned_tokens.read((owner, last_token_index));
// Move the last token to the slot of the to-delete token
contract_state.owned_tokens.write((owner, token_index), last_token_id);
enumerable.owned_tokens.write((owner, token_index), last_token_id);
// Update the moved token's index
contract_state.owned_tokens_index.write(last_token_id, token_index);
enumerable.owned_tokens_index.write(last_token_id, token_index);
}

// Clear the last slot
contract_state.owned_tokens.write((owner, last_token_index), 0);
contract_state.owned_tokens_index.write(token_id, 0);
enumerable.owned_tokens.write((owner, last_token_index), 0);
enumerable.owned_tokens_index.write(token_id, 0);
}
if (to == Zero::zero()) { // Burn Token case: self._remove_token_from_all_tokens_enumeration(first_token_id);
let last_token_index = contract_state.all_tokens_length.read() - 1;
let token_index = contract_state.all_tokens_index.read(token_id);
let last_token_index = enumerable.all_tokens_length.read() - 1;
let token_index = enumerable.all_tokens_index.read(token_id);

let last_token_id = contract_state.all_tokens.read(last_token_index);
let last_token_id = enumerable.all_tokens.read(last_token_index);

contract_state.all_tokens.write(token_index, last_token_id);
contract_state.all_tokens_index.write(last_token_id, token_index);
enumerable.all_tokens.write(token_index, last_token_id);
enumerable.all_tokens_index.write(last_token_id, token_index);

contract_state.all_tokens_index.write(token_id, 0);
contract_state.all_tokens.write(last_token_index, 0);
contract_state.all_tokens_length.write(last_token_index);
enumerable.all_tokens_index.write(token_id, 0);
enumerable.all_tokens.write(last_token_index, 0);
enumerable.all_tokens_length.write(last_token_index);
} else if (to != auth) { //self._add_token_to_owner_enumeration(to, first_token_id);
let length = self.balance_of(to);
contract_state.owned_tokens.write((to, length), token_id);
contract_state.owned_tokens_index.write(token_id, length);
enumerable.owned_tokens.write((to, length), token_id);
enumerable.owned_tokens_index.write(token_id, length);
}
}

Expand Down
1 change: 1 addition & 0 deletions packages/snfoundry/contracts/src/lib.cairo
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod Counter;
mod ERC721Enumerable;
mod YourCollectible;
mod mock_contracts {
pub mod Receiver;
Expand Down
5 changes: 3 additions & 2 deletions packages/snfoundry/contracts/src/test/TestContract.cairo
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use contracts::YourCollectible::{
IYourCollectibleDispatcher, IYourCollectibleDispatcherTrait, IERC721EnumerableDispatcher,
IERC721EnumerableDispatcherTrait, IERC721MetadataDispatcher, IERC721MetadataDispatcherTrait
IYourCollectibleDispatcher, IYourCollectibleDispatcherTrait, IERC721MetadataDispatcher,
IERC721MetadataDispatcherTrait
};

use contracts::mock_contracts::Receiver;
use contracts::ERC721Enumerable::{IERC721EnumerableDispatcher, IERC721EnumerableDispatcherTrait};
use core::clone::Clone;
use openzeppelin::token::erc721::interface::{IERC721Dispatcher, IERC721DispatcherTrait};
use openzeppelin::utils::serde::SerializedAppend;
Expand Down

0 comments on commit 25970fe

Please sign in to comment.