Skip to content

Commit

Permalink
Merge pull request #297 from hotosm/fix/role-update-my-info
Browse files Browse the repository at this point in the history
feat: update user profile role to array & add project centroid endpoints
  • Loading branch information
nrjadkry authored Oct 18, 2024
2 parents ad102a4 + 2286ed8 commit da594b0
Show file tree
Hide file tree
Showing 24 changed files with 344 additions and 105 deletions.
2 changes: 1 addition & 1 deletion src/backend/app/db/db_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ class GroundControlPoint(Base):
class DbUserProfile(Base):
__tablename__ = "user_profile"
user_id = cast(str, Column(String, ForeignKey("users.id"), primary_key=True))
role = cast(UserRole, Column(Enum(UserRole), default=UserRole.DRONE_PILOT))
role = cast(list, Column(ARRAY(Enum(UserRole))))
phone_number = cast(str, Column(String))
country = cast(str, Column(String))
city = cast(str, Column(String))
Expand Down
63 changes: 63 additions & 0 deletions src/backend/app/migrations/versions/b36a13183a83_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from typing import Sequence, Union
from alembic import op
from app.models.enums import UserRole
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.sql import text


# revision identifiers, used by Alembic.
revision: str = "b36a13183a83"
down_revision: Union[str, None] = "5235ef4afa9c"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# Check if the enum type 'userrole' already exists
conn = op.get_bind()
result = conn.execute(
text("SELECT 1 FROM pg_type WHERE typname = 'userrole';")
).scalar()

if not result:
# Create a new enum type for user roles if it doesn't exist
userrole_enum = sa.Enum(UserRole, name="userrole")
userrole_enum.create(op.get_bind())

# Change the column from a single enum value to an array of enums
# We need to cast each enum value to text and back to an array of enums
op.alter_column(
"user_profile",
"role",
existing_type=sa.Enum(UserRole, name="userrole"),
type_=sa.ARRAY(
postgresql.ENUM("PROJECT_CREATOR", "DRONE_PILOT", name="userrole")
),
postgresql_using="ARRAY[role]::userrole[]", # Convert the single enum to an array
nullable=True,
)


def downgrade() -> None:
# Change the column back from an array to a single enum value
op.alter_column(
"user_profile",
"role",
existing_type=sa.ARRAY(
postgresql.ENUM("PROJECT_CREATOR", "DRONE_PILOT", name="userrole")
),
type_=sa.Enum(UserRole, name="userrole"),
postgresql_using="role[1]",
nullable=True,
)

# Drop the enum type only if it exists
conn = op.get_bind()
result = conn.execute(
text("SELECT 1 FROM pg_type WHERE typname = 'userrole';")
).scalar()

if result:
userrole_enum = sa.Enum(UserRole, name="userrole")
userrole_enum.drop(op.get_bind())
32 changes: 32 additions & 0 deletions src/backend/app/projects/project_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,38 @@
from app.projects.image_processing import DroneImageProcessor
from app.projects import project_schemas
from minio import S3Error
from psycopg.rows import dict_row


async def get_centroids(db: Connection):
try:
async with db.cursor(row_factory=dict_row) as cur:
await cur.execute("""
SELECT
p.id,
p.slug,
p.name,
ST_AsGeoJSON(p.centroid)::jsonb AS centroid,
COUNT(t.id) AS total_task_count,
COUNT(CASE WHEN te.state IN ('LOCKED_FOR_MAPPING', 'REQUEST_FOR_MAPPING', 'IMAGE_UPLOADED', 'UNFLYABLE_TASK') THEN 1 END) AS ongoing_task_count,
COUNT(CASE WHEN te.state = 'IMAGE_PROCESSED' THEN 1 END) AS completed_task_count
FROM
projects p
LEFT JOIN
tasks t ON p.id = t.project_id
LEFT JOIN
task_events te ON t.id = te.task_id
GROUP BY
p.id, p.slug, p.name, p.centroid;
""")
centroids = await cur.fetchall()

if not centroids:
raise HTTPException(status_code=404, detail="No centroids found.")

return centroids
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


async def upload_file_to_s3(
Expand Down
39 changes: 36 additions & 3 deletions src/backend/app/projects/project_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,28 @@
)


@router.get(
"/centroids", tags=["Projects"], response_model=list[project_schemas.CentroidOut]
)
async def read_project_centroids(
db: Annotated[Connection, Depends(database.get_db)],
user_data: Annotated[AuthUser, Depends(login_required)],
):
"""
Get all project centroids.
"""
try:
centroids = await project_logic.get_centroids(
db,
)
if not centroids:
return []

return centroids
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


@router.get("/{project_id}/download-boundaries", tags=["Projects"])
async def download_boundaries(
project_id: Annotated[
Expand Down Expand Up @@ -448,7 +470,6 @@ async def odm_webhook(

task_id = payload.get("uuid")
status = payload.get("status")

if not task_id or not status:
raise HTTPException(status_code=400, detail="Invalid webhook payload")

Expand All @@ -457,6 +478,8 @@ async def odm_webhook(
# If status is 'success', download and upload assets to S3.
# 40 is the status code for success in odm
if status["code"] == 40:
log.info(f"Task ID: {task_id}, Status: going for download......")

# Call function to download assets from ODM and upload to S3
background_tasks.add_task(
image_processing.download_and_upload_assets_from_odm_to_s3,
Expand All @@ -468,6 +491,16 @@ async def odm_webhook(
dtm_user_id,
)
elif status["code"] == 30:
# failed task
log.error(f'ODM task {task_id} failed: {status["errorMessage"]}')
background_tasks.add_task(
image_processing.download_and_upload_assets_from_odm_to_s3,
db,
settings.NODE_ODM_URL,
task_id,
dtm_project_id,
dtm_task_id,
dtm_user_id,
)

log.info(f"Task ID: {task_id}, Status: Webhook received")

return {"message": "Webhook received", "task_id": task_id}
27 changes: 27 additions & 0 deletions src/backend/app/projects/project_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,33 @@
from app.s3 import get_presigned_url


class CentroidOut(BaseModel):
id: uuid.UUID
slug: str
name: str
centroid: dict
total_task_count: int
ongoing_task_count: int
completed_task_count: int
status: str = None

@model_validator(mode="after")
def calculate_status(cls, values):
"""Set the project status based on task counts."""
ongoing_task_count = values.ongoing_task_count
completed_task_count = values.completed_task_count
total_task_count = values.total_task_count

if completed_task_count == 0 and ongoing_task_count == 0:
values.status = "not-started"
elif completed_task_count == total_task_count:
values.status = "completed"
else:
values.status = "ongoing"

return values


class AssetsInfo(BaseModel):
project_id: str
task_id: str
Expand Down
65 changes: 21 additions & 44 deletions src/backend/app/tasks/task_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,30 +73,12 @@ async def get_task_stats(
):
"Retrieve statistics related to tasks for the authenticated user."
user_id = user_data.id

try:
async with db.cursor(row_factory=dict_row) as cur:
# Check if the user profile exists
await cur.execute(
"""SELECT role FROM user_profile WHERE user_id = %(user_id)s""",
{"user_id": user_id},
)
records = await cur.fetchall()

if not records:
raise HTTPException(status_code=404, detail="User profile not found")
roles = [record["role"] for record in records]

if UserRole.PROJECT_CREATOR.name in roles:
role = "PROJECT_CREATOR"
else:
role = "DRONE_PILOT"

# Query for task statistics
raw_sql = """
SELECT
COUNT(CASE WHEN te.state = 'REQUEST_FOR_MAPPING' THEN 1 END) AS request_logs,
COUNT(CASE WHEN te.state IN ('LOCKED_FOR_MAPPING', 'REQUEST_FOR_MAPPING', 'IMAGE_UPLOADED', 'UNFLYABLE_TASK') THEN 1 END) AS ongoing_tasks,
COUNT(CASE WHEN te.state IN ('LOCKED_FOR_MAPPING', 'IMAGE_UPLOADED') THEN 1 END) AS ongoing_tasks,
COUNT(CASE WHEN te.state = 'IMAGE_PROCESSED' THEN 1 END) AS completed_tasks,
COUNT(CASE WHEN te.state = 'UNFLYABLE_TASK' THEN 1 END) AS unflyable_tasks
FROM (
Expand All @@ -106,17 +88,21 @@ async def get_task_stats(
te.created_at
FROM task_events te
WHERE
(%(role)s = 'DRONE_PILOT' AND te.user_id = %(user_id)s)
(
%(role)s = 'DRONE_PILOT'
AND te.user_id = %(user_id)s
)
OR
(%(role)s != 'DRONE_PILOT' AND te.task_id IN (
SELECT t.id
FROM tasks t
WHERE t.project_id IN (SELECT id FROM projects WHERE author_id = %(user_id)s)
(%(role)s = 'PROJECT_CREATOR' AND te.project_id IN (
SELECT p.id
FROM projects p
WHERE p.author_id = %(user_id)s
))
ORDER BY te.task_id, te.created_at DESC
) AS te;
"""
await cur.execute(raw_sql, {"user_id": user_id, "role": role})

await cur.execute(raw_sql, {"user_id": user_id, "role": user_data.role})
db_counts = await cur.fetchone()

return db_counts
Expand All @@ -137,25 +123,7 @@ async def list_tasks(
):
"""Get all tasks for a all user."""
user_id = user_data.id

async with db.cursor(row_factory=dict_row) as cur:
# Check if the user profile exists
await cur.execute(
"""SELECT role FROM user_profile WHERE user_id = %(user_id)s""",
{"user_id": user_id},
)
records = await cur.fetchall()

if not records:
raise HTTPException(status_code=404, detail="User profile not found")

roles = [record["role"] for record in records]

if UserRole.PROJECT_CREATOR.name in roles:
role = "PROJECT_CREATOR"
else:
role = "DRONE_PILOT"

role = user_data.role
return await task_schemas.UserTasksStatsOut.get_tasks_by_user(
db, user_id, role, skip, limit
)
Expand Down Expand Up @@ -183,10 +151,19 @@ async def new_event(
):
user_id = user_data.id
project = project.model_dump()
user_role = user_data.role

match detail.event:
case EventType.REQUESTS:
# Determine the appropriate state and message
is_author = project["author_id"] == user_id

if user_role != UserRole.DRONE_PILOT and not is_author:
raise HTTPException(
status_code=403,
detail="Only the project author or drone operators can request tasks for this project.",
)

requires_approval = project["requires_approval_from_manager_for_locking"]

if is_author or not requires_approval:
Expand Down
26 changes: 21 additions & 5 deletions src/backend/app/tasks/task_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,8 @@ async def get_tasks_by_user(
):
async with db.cursor(row_factory=class_row(UserTasksStatsOut)) as cur:
await cur.execute(
"""SELECT DISTINCT ON (tasks.id)
"""
SELECT DISTINCT ON (tasks.id)
tasks.id AS task_id,
tasks.project_task_index AS project_task_index,
task_events.project_id AS project_id,
Expand All @@ -169,7 +170,7 @@ async def get_tasks_by_user(
task_events.updated_at,
CASE
WHEN task_events.state = 'REQUEST_FOR_MAPPING' THEN 'request logs'
WHEN task_events.state = 'LOCKED_FOR_MAPPING' OR task_events.state = 'IMAGE_UPLOADED' THEN 'ongoing'
WHEN task_events.state IN ('LOCKED_FOR_MAPPING', 'IMAGE_UPLOADED') THEN 'ongoing'
WHEN task_events.state = 'IMAGE_PROCESSED' THEN 'completed'
WHEN task_events.state = 'UNFLYABLE_TASK' THEN 'unflyable task'
ELSE ''
Expand All @@ -182,16 +183,31 @@ async def get_tasks_by_user(
projects ON task_events.project_id = projects.id
WHERE
(
%(role)s = 'DRONE_PILOT' AND task_events.user_id = %(user_id)s
%(role)s = 'DRONE_PILOT'
AND task_events.user_id = %(user_id)s
)
OR
(
%(role)s!= 'DRONE_PILOT' AND task_events.project_id IN (SELECT id FROM projects WHERE author_id = %(user_id)s)
%(role)s = 'PROJECT_CREATOR'
AND task_events.project_id IN (
SELECT p.id
FROM projects p
WHERE p.id IN (
SELECT t.project_id
FROM tasks t
WHERE t.project_id IN (
SELECT DISTINCT te2.project_id
FROM task_events te2
WHERE te2.user_id = %(user_id)s
)
)
)
)
ORDER BY
tasks.id, task_events.created_at DESC
OFFSET %(skip)s
LIMIT %(limit)s;""",
LIMIT %(limit)s;
""",
{"user_id": user_id, "role": role, "skip": skip, "limit": limit},
)
try:
Expand Down
4 changes: 3 additions & 1 deletion src/backend/app/users/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def login(self) -> dict:
login_url, _ = self.oauth.authorization_url(self.authorization_url)
return json.loads(Login(login_url=login_url).model_dump_json())

def callback(self, callback_url: str) -> str:
def callback(self, callback_url: str, role: str) -> str:
"""Performs token exchange between Google and the callback website.
Core will use Oauth secret key from configuration while deserializing token,
Expand All @@ -83,11 +83,13 @@ def callback(self, callback_url: str) -> str:

data = resp.json()
serializer = URLSafeSerializer(self.secret_key)

user_data = {
"id": data.get("id"),
"email": data.get("email"),
"name": data.get("name"),
"profile_img": data.get("picture") if data.get("picture") else None,
"role": role,
}
token = serializer.dumps(user_data)
access_token = base64.b64encode(bytes(token, "utf-8")).decode("utf-8")
Expand Down
Loading

0 comments on commit da594b0

Please sign in to comment.