diff --git a/alembic/versions/c14c4b6f36b3_create_table_conditions.py b/alembic/versions/c14c4b6f36b3_create_table_conditions.py new file mode 100644 index 0000000..3365001 --- /dev/null +++ b/alembic/versions/c14c4b6f36b3_create_table_conditions.py @@ -0,0 +1,87 @@ +"""create-table-conditions + +Revision ID: c14c4b6f36b3 +Revises: 0894f3022876 +Create Date: 2024-01-11 16:02:26.575786 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "c14c4b6f36b3" +down_revision: Union[str, None] = "0894f3022876" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "conditions", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column( + "type", + sa.Enum( + "MAX_CALLS_PER_ENTRYPOINT", + "MAX_CALLS_PER_SPONSEE", + name="conditiontype", + ), + nullable=True, + ), + sa.Column( + "sponsee_address", + sa.String(), + sa.CheckConstraint( + "(type = 'MAX_CALLS_PER_SPONSEE') = (sponsee_address IS NOT NULL)", + name="sponsee_address_not_null_constraint", + ), + nullable=True, + ), + sa.Column( + "contract_id", + sa.UUID(), + sa.CheckConstraint( + "(type = 'MAX_CALLS_PER_ENTRYPOINT') = (contract_id IS NOT NULL)", + name="contract_id_not_null_constraint", + ), + nullable=True, + ), + sa.Column( + "entrypoint_id", + sa.UUID(), + sa.CheckConstraint( + "(type = 'MAX_CALLS_PER_ENTRYPOINT') = (entrypoint_id IS NOT NULL)", + name="entrypoint_id_not_null_constraint", + ), + nullable=True, + ), + sa.Column("vault_id", sa.UUID(), nullable=False), + sa.Column("max", sa.Integer(), nullable=False), + sa.Column("current", sa.Integer(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.ForeignKeyConstraint( + ["contract_id"], + ["contracts.id"], + ), + sa.ForeignKeyConstraint( + ["entrypoint_id"], + ["entrypoints.id"], + ), + sa.ForeignKeyConstraint( + ["vault_id"], + ["credits.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("conditions") + op.execute("DROP type conditiontype") + # ### end Alembic commands ### diff --git a/demo/staking-contract.mligo b/demo/staking-contract.mligo index 8bbea34..c8555da 100644 --- a/demo/staking-contract.mligo +++ b/demo/staking-contract.mligo @@ -1,9 +1,10 @@ #import "../permit-cameligo/src/main.mligo" "FA2" -type storage = { - nft_address: address; - staked: (address, nat) big_map; -} +type storage = + { + nft_address : address; + staked : (address, nat) big_map + } (* We need to provide the address of the NFT's owner so that the transfer can be done by someone * else (we don't rely on Tezos.get_sender ()) *) diff --git a/src/crud.py b/src/crud.py index 87a5813..3b7d98c 100644 --- a/src/crud.py +++ b/src/crud.py @@ -5,6 +5,7 @@ from sqlalchemy.orm import Session from .utils import ( + ConditionAlreadyExists, ContractAlreadyRegistered, ContractNotFound, CreditNotFound, @@ -308,3 +309,149 @@ def check_calls_per_month(db, contract_id): return True nb_operations_already_made = get_operations_by_contracts_per_month(db, contract_id) return max_calls >= len(nb_operations_already_made) + + +def create_max_calls_per_sponsee_condition( + db: Session, condition: schemas.CreateMaxCallsPerSponseeCondition +): + # If a condition still exists, do not create a new one + existing_condition = ( + db.query(models.Condition) + .filter(models.Condition.sponsee_address == condition.sponsee_address) + .filter(models.Condition.vault_id == condition.vault_id) + .filter(models.Condition.current < models.Condition.max) + .one_or_none() + ) + if existing_condition is not None: + raise ConditionAlreadyExists( + "A condition with maximum calls per sponsee already exists and the maximum is not reached. Cannot create a new one." + ) + db_condition = models.Condition( + **{ + "type": schemas.ConditionType.MAX_CALLS_PER_SPONSEE, + "sponsee_address": condition.sponsee_address, + "vault_id": condition.vault_id, + "max": condition.max, + "current": 0, + } + ) + db.add(db_condition) + db.commit() + db.refresh(db_condition) + return schemas.MaxCallsPerSponseeCondition( + sponsee_address=db_condition.sponsee_address, + vault_id=db_condition.vault_id, + max=db_condition.max, + current=db_condition.current, + type=db_condition.type, + created_at=db_condition.created_at, + id=db_condition.id, + ) + + +def create_max_calls_per_entrypoint_condition( + db: Session, condition: schemas.CreateMaxCallsPerEntrypointCondition +): + # If a condition still exists, do not create a new one + existing_condition = ( + db.query(models.Condition) + .filter(models.Condition.entrypoint_id == condition.entrypoint_id) + .filter(models.Condition.contract_id == condition.contract_id) + .filter(models.Condition.vault_id == condition.vault_id) + .filter(models.Condition.current < models.Condition.max) + .one_or_none() + ) + if existing_condition is not None: + raise ConditionAlreadyExists( + "A condition with maximum calls per entrypoint already exists and the maximum is not reached. Cannot create a new one." + ) + db_condition = models.Condition( + **{ + "type": schemas.ConditionType.MAX_CALLS_PER_ENTRYPOINT, + "contract_id": condition.contract_id, + "entrypoint_id": condition.entrypoint_id, + "vault_id": condition.vault_id, + "max": condition.max, + "current": 0, + } + ) + db.add(db_condition) + db.commit() + db.refresh(db_condition) + return schemas.MaxCallsPerEntrypointCondition( + contract_id=db_condition.contract_id, + entrypoint_id=db_condition.entrypoint_id, + vault_id=db_condition.vault_id, + max=db_condition.max, + current=db_condition.current, + type=db_condition.type, + created_at=db_condition.created_at, + id=db_condition.id, + ) + + +def check_max_calls_per_sponsee(db: Session, sponsee_address: str, vault_id: UUID4): + return ( + db.query(models.Condition) + .filter(models.Condition.type == schemas.ConditionType.MAX_CALLS_PER_SPONSEE) + .filter(models.Condition.sponsee_address == sponsee_address) + .filter(models.Condition.vault_id == vault_id) + .one_or_none() + ) + + +def check_max_calls_per_entrypoint( + db: Session, contract_id: UUID4, entrypoint_id: UUID4, vault_id: UUID4 +): + return ( + db.query(models.Condition) + .filter(models.Condition.type == schemas.ConditionType.MAX_CALLS_PER_ENTRYPOINT) + .filter(models.Condition.contract_id == contract_id) + .filter(models.Condition.entrypoint_id == entrypoint_id) + .filter(models.Condition.vault_id == vault_id) + .one_or_none() + ) + + +def check_conditions(db: Session, datas: schemas.CheckConditions): + print(datas) + sponsee_condition = check_max_calls_per_sponsee( + db, datas.sponsee_address, datas.vault_id + ) + entrypoint_condition = check_max_calls_per_entrypoint( + db, datas.contract_id, datas.entrypoint_id, datas.vault_id + ) + + # No condition registered + if sponsee_condition is None and entrypoint_condition is None: + return True + # One of condition is excedeed + if ( + sponsee_condition is not None + and (sponsee_condition.current >= sponsee_condition.max) + ) or ( + entrypoint_condition is not None + and (entrypoint_condition.current >= entrypoint_condition.max) + ): + return False + + # Update conditions + # TODO - Rewrite with list + + if sponsee_condition: + update_condition(db, sponsee_condition) + if entrypoint_condition: + update_condition(db, entrypoint_condition) + return True + + +def update_condition(db: Session, condition: models.Condition): + db.query(models.Condition).filter(models.Condition.id == condition.id).update( + {"current": condition.current + 1} + ) + + +def get_conditions_by_vault(db: Session, vault_id: str): + return ( + db.query(models.Condition).filter(models.Condition.vault_id == vault_id).all() + ) diff --git a/src/models.py b/src/models.py index 43e454a..b0a7220 100644 --- a/src/models.py +++ b/src/models.py @@ -1,7 +1,18 @@ -from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String +from sqlalchemy import ( + Boolean, + CheckConstraint, + Column, + DateTime, + Enum, + ForeignKey, + Integer, + String, +) from sqlalchemy.orm import relationship from sqlalchemy.dialects.postgresql import UUID import uuid + +from .schemas import ConditionType from .database import Base import datetime @@ -46,6 +57,7 @@ def __repr__(self): entrypoints = relationship("Entrypoint", back_populates="contract") credit = relationship("Credit", back_populates="contracts") operations = relationship("Operation", back_populates="contract") + conditions = relationship("Condition", back_populates="contract") # ------- ENTRYPOINT ------- # @@ -69,6 +81,7 @@ def __repr__(self): contract = relationship("Contract", back_populates="entrypoints") operations = relationship("Operation", back_populates="entrypoint") + conditions = relationship("Condition", back_populates="entrypoint") # ------- CREDITS ------- # @@ -88,6 +101,7 @@ def __repr__(self): owner = relationship("User", back_populates="credits") contracts = relationship("Contract", back_populates="credit") + conditions = relationship("Condition", back_populates="vault") # ------- OPERATIONS ------- # @@ -107,3 +121,49 @@ class Operation(Base): contract = relationship("Contract", back_populates="operations") entrypoint = relationship("Entrypoint", back_populates="operations") + + +# ------- CONDITIONS ------- # + + +class Condition(Base): + __tablename__ = "conditions" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + type = Column(Enum(ConditionType)) + sponsee_address = Column( + String, + CheckConstraint( + "(type = 'MAX_CALLS_PER_SPONSEE') = (sponsee_address IS NOT NULL)", + name="sponsee_address_not_null_constraint", + ), + nullable=True, + ) + contract_id = Column( + UUID(as_uuid=True), + CheckConstraint( + "(type = 'MAX_CALLS_PER_ENTRYPOINT') = (contract_id IS NOT NULL)", + name="contract_id_not_null_constraint", + ), + ForeignKey("contracts.id"), + nullable=True, + ) + entrypoint_id = Column( + UUID(as_uuid=True), + CheckConstraint( + "(type = 'MAX_CALLS_PER_ENTRYPOINT') = (entrypoint_id IS NOT NULL)", + name="entrypoint_id_not_null_constraint", + ), + ForeignKey("entrypoints.id"), + nullable=True, + ) + vault_id = Column(UUID(as_uuid=True), ForeignKey("credits.id"), nullable=False) + max = Column(Integer, nullable=False) + current = Column(Integer, nullable=False) + created_at = Column( + DateTime(timezone=True), default=datetime.datetime.utcnow(), nullable=False + ) + + contract = relationship("Contract", back_populates="conditions") + entrypoint = relationship("Entrypoint", back_populates="conditions") + vault = relationship("Credit", back_populates="conditions") diff --git a/src/routes.py b/src/routes.py index 3ff875c..1522229 100644 --- a/src/routes.py +++ b/src/routes.py @@ -6,6 +6,8 @@ from pytezos.rpc.errors import MichelsonError from pytezos.crypto.encoding import is_address from .utils import ( + ConditionAlreadyExists, + ConditionExceed, ContractAlreadyRegistered, ContractNotFound, CreditNotFound, @@ -17,6 +19,7 @@ OperationNotFound, ) from .config import logging +from .schemas import ConditionType router = APIRouter() @@ -299,6 +302,17 @@ async def post_operation( entrypoint = crud.get_entrypoint(db, str(contract.address), entrypoint_name) if not entrypoint.is_enabled: raise EntrypointDisabled() + + if not crud.check_conditions( + db, + schemas.CheckConditions( + sponsee_address=call_data.sender_address, + contract_id=contract.id, + entrypoint_id=entrypoint.id, + vault_id=contract.credit_id, + ), + ): + raise ConditionExceed() except EntrypointNotFound: logging.warning(f"Entrypoint {entrypoint_name} is not found") raise HTTPException( @@ -311,6 +325,12 @@ async def post_operation( status_code=status.HTTP_403_FORBIDDEN, detail=f"Entrypoint {entrypoint_name} is disabled.", ) + except ConditionExceed: + logging.warning(f"A condition exceed the maximum defined.") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"A condition exceed the maximum defined.", + ) try: # Simulate the operation alone without sending it @@ -402,3 +422,55 @@ async def update_max_calls( status_code=status.HTTP_400_BAD_REQUEST, detail="Max calls cannot be < -1" ) return crud.update_max_calls_per_month_condition(db, body.max_calls, contract_id) + + +@router.post("/condition") +async def create_condition( + body: schemas.CreateCondition, db: Session = Depends(database.get_db) +): + try: + if ( + body.type == ConditionType.MAX_CALLS_PER_ENTRYPOINT + and body.contract_id is not None + and body.entrypoint_id is not None + ): + return crud.create_max_calls_per_entrypoint_condition( + db, + schemas.CreateMaxCallsPerEntrypointCondition( + contract_id=body.contract_id, + vault_id=body.vault_id, + max=body.max, + entrypoint_id=body.entrypoint_id, + ), + ) + elif ( + body.type == ConditionType.MAX_CALLS_PER_SPONSEE + and body.sponsee_address is not None + ): + return crud.create_max_calls_per_sponsee_condition( + db, + schemas.CreateMaxCallsPerSponseeCondition( + sponsee_address=body.sponsee_address, + vault_id=body.vault_id, + max=body.max, + ), + ) + else: + logging.error("Unknown condition or missing parameters.") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Unknown condition or missing parameters.", + ) + except ConditionAlreadyExists as e: + logging.warning(e) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=str(e), + ) + + +@router.get("/condition/{vault_id}") +async def get_conditions_by_vault( + vault_id: str, db: Session = Depends(database.get_db) +): + return crud.get_conditions_by_vault(db, vault_id) diff --git a/src/schemas.py b/src/schemas.py index 8b6ae86..c7db896 100644 --- a/src/schemas.py +++ b/src/schemas.py @@ -1,5 +1,13 @@ +import datetime +import enum from pydantic import BaseModel, UUID4 -from typing import List, Any +from typing import List, Any, Optional + + +# -- UTILITY TYPES -- +class ConditionType(enum.Enum): + MAX_CALLS_PER_ENTRYPOINT = "MAX_CALLS_PER_ENTRYPOINT" + MAX_CALLS_PER_SPONSEE = "MAX_CALLS_PER_SPONSEE" # Users @@ -124,5 +132,53 @@ class CreateOperation(BaseModel): status: str +# Conditions class UpdateMaxCallsPerMonth(BaseModel): max_calls: int + + +class CreateCondition(BaseModel): + type: ConditionType + sponsee_address: Optional[str] = None + contract_id: Optional[UUID4] = None + entrypoint_id: Optional[UUID4] = None + vault_id: UUID4 + max: int + + +class CreateMaxCallsPerEntrypointCondition(BaseModel): + contract_id: UUID4 + entrypoint_id: UUID4 + vault_id: UUID4 + max: int + + +class CreateMaxCallsPerSponseeCondition(BaseModel): + sponsee_address: str + vault_id: UUID4 + max: int + + +class CheckConditions(BaseModel): + sponsee_address: str + contract_id: UUID4 + entrypoint_id: UUID4 + vault_id: UUID4 + + +class ConditionBase(BaseModel): + vault_id: UUID4 + max: int + current: int + type: ConditionType + id: UUID4 + created_at: datetime.datetime + + +class MaxCallsPerEntrypointCondition(ConditionBase): + contract_id: UUID4 + entrypoint_id: UUID4 + + +class MaxCallsPerSponseeCondition(ConditionBase): + sponsee_address: str diff --git a/src/utils.py b/src/utils.py index d2924ee..4192b48 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,6 +1,9 @@ # -- EXCEPTIONS -- +import enum + + class UserNotFound(Exception): pass @@ -39,3 +42,11 @@ class NotEnoughFunds(Exception): class TooManyCallsForThisMonth(Exception): pass + + +class ConditionAlreadyExists(Exception): + pass + + +class ConditionExceed(Exception): + pass