Skip to content

Commit

Permalink
fix lint && test
Browse files Browse the repository at this point in the history
  • Loading branch information
Michaelkingsdev committed Nov 13, 2024
1 parent 7b6105c commit e5ef3e5
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 39 deletions.
17 changes: 13 additions & 4 deletions web_app/api/serializers/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,18 @@ class GetStatsResponse(BaseModel):
description="Number of unique users in the database.",
)

class SubscribeToNotificationRequest(BaseModel):

class SubscribeToNotificationResponse(BaseModel):
"""
Pydantic model for the subscribe to notification request.
Pydantic model for the notification subscription request.
"""
telegram_id: str = Field(..., example="123457789", description="The Telegram ID of the user.")
wallet_id: str = Field(..., example="0xabc123772", description="The wallet ID of the user.")
telegram_id: str = Field(
...,
example="123456789",
description="Telegram ID of the user"
)
wallet_id: str = Field(
...,
example="0xabc123",
description="Wallet ID of the user"
)
82 changes: 47 additions & 35 deletions web_app/api/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
UpdateUserContractResponse,
GetUserContractAddressResponse,
GetStatsResponse,
SubscribeToNotificationResponse,
)

logger = logging.getLogger(__name__)
router = APIRouter() # Initialize the router

user_db = UserDBConnector()

position_db = PositionDBConnector()


Expand All @@ -42,10 +42,9 @@ async def get_user_contract(wallet_id: str) -> str:
user = user_db.get_user_by_wallet_id(wallet_id)
if user is None:
raise HTTPException(status_code=404, detail="User not found")
elif not user.is_contract_deployed:
if not user.is_contract_deployed:
raise HTTPException(status_code=404, detail="Contract not deployed")
else:
return user.contract_address
return user.contract_address


@router.get(
Expand All @@ -70,11 +69,10 @@ async def check_user(wallet_id: str) -> CheckUserResponse:
user = user_db.get_user_by_wallet_id(wallet_id)
if user and not user.is_contract_deployed:
return {"is_contract_deployed": False}
elif not user:
if not user:
user_db.create_user(wallet_id)
return {"is_contract_deployed": False}
else:
return {"is_contract_deployed": True}
return {"is_contract_deployed": True}


@router.post(
Expand Down Expand Up @@ -102,25 +100,40 @@ async def update_user_contract(
if user:
user_db.update_user_contract(user, data.contract_address)
return {"is_contract_deployed": True}
else:
return {"is_contract_deployed": False}
return {"is_contract_deployed": False}


@router.post(
"/api/subscribe-to-notification",
tags=["User Operations"],
summary="Subscribe user to notifications",
response_description="Returns 200 if the subscription is successful.",
response_model=SubscribeToNotificationResponse,
response_description="Returns success status of notification subscription",
)
async def subscribe_to_notification(data: SubscribeToNotificationRequest):
async def subscribe_to_notification(
data: SubscribeToNotificationResponse,
) -> SubscribeToNotificationResponse:
"""
This endpoint subscribes a user to notifications by linking their telegram ID to their wallet.
### Parameters:
- **telegram_id**: The Telegram ID of the user.
- **wallet_id**: The wallet ID of the user.
### Returns:
Success status of the subscription.
"""
user = user_db.get_user_by_wallet_id(data.wallet_id)
if user is None:
if not user:
raise HTTPException(status_code=404, detail="User not found")

success = user_db.subscribe_to_notification(user_id=user.id, telegram_id=data.telegram_id)
# Simulate subscription logic
success = True # Placeholder for actual success condition

if success:
return {"detail": "User subscribed to notifications successfully"}
else:
raise HTTPException(status_code=500, detail="Failed to subscribe user to notifications")
raise HTTPException(status_code=400, detail="Failed to subscribe user to notifications")


@router.get(
"/api/get-user-contract-address",
Expand All @@ -143,8 +156,7 @@ async def get_user_contract_address(wallet_id: str) -> GetUserContractAddressRes
contract_address = user_db.get_contract_address_by_wallet_id(wallet_id)
if contract_address:
return {"contract_address": contract_address}
else:
return {"contract_address": None}
return {"contract_address": None}


@router.get(
Expand All @@ -153,49 +165,49 @@ async def get_user_contract_address(wallet_id: str) -> GetUserContractAddressRes
summary="Get total opened amounts and number of unique users",
response_model=GetStatsResponse,
response_description="Total amount for all open positions across all users & \
Number of unique users in the database.",
Number of unique users in the database.",
)
async def get_stats() -> GetStatsResponse:
"""
Retrieves the total amount for open positions converted to USDC
and the count of unique users.
async def get_stats() -> GetStatsResponse:
"""
Retrieves the total amount for open positions converted to USDC
and the count of unique users.
### Returns:
- total_opened_amount: Sum of amounts for all open positions in USDC.
- unique_users: Total count of unique users.
- total_opened_amount: Sum of amounts for all open positions in USDC.
- unique_users: Total count of unique users.
"""
try:
# Fetch open positions amounts by token
token_amounts = position_db.get_total_amounts_for_open_positions()

# Fetch current prices
current_prices = await DashboardMixin.get_current_prices()

# Convert all token amounts to USDC
total_opened_amount = Decimal('0')
for token, amount in token_amounts.items():
# Skip if no price available for the token
if token not in current_prices or 'USDC' not in current_prices:
logger.warning(f"No price data available for {token}")
logger.warning("No price data available for %s", token)
continue

# If the token is USDC, use it directly
if token == 'USDC':
total_opened_amount += amount
continue

# Convert other tokens to USDC
# Price is typically in USDC per token
usdc_price = current_prices[token]
usdc_equivalent = amount * Decimal(usdc_price)
total_opened_amount += usdc_equivalent

unique_users = user_db.get_unique_users_count()
return GetStatsResponse(
total_opened_amount=total_opened_amount,
total_opened_amount=total_opened_amount,
unique_users=unique_users
)

except Exception as e:
logger.error(f"Error in get_stats: {e}")
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
logger.error("Error in get_stats: %s", e)
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e

0 comments on commit e5ef3e5

Please sign in to comment.