Skip to content

Commit

Permalink
Add authentication to endpoint functionalities
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitijrajsharma committed Nov 14, 2023
1 parent 2061349 commit cfed737
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 53 deletions.
41 changes: 37 additions & 4 deletions API/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,53 @@
from enum import Enum
from typing import Union

from fastapi import Header
from fastapi import Depends, Header, HTTPException
from osm_login_python.core import Auth
from pydantic import BaseModel
from pydantic import BaseModel, Field

from src.config import get_oauth_credentials
from src.config import ADMIN_IDS, get_oauth_credentials


class UserRole(Enum):
ADMIN = 1
STAFF = 2
GUEST = 3


class AuthUser(BaseModel):
id: int
username: str
img_url: Union[str, None]
role: UserRole = Field(default=UserRole.GUEST.value)


osm_auth = Auth(*get_oauth_credentials())


def is_admin(osm_id: int):
admin_ids = [int(admin_id) for admin_id in ADMIN_IDS]
return osm_id in admin_ids


def login_required(access_token: str = Header(...)):
return osm_auth.deserialize_access_token(access_token)
user = AuthUser(**osm_auth.deserialize_access_token(access_token))
if is_admin(user.id):
user.role = UserRole.ADMIN
return user


def get_optional_user(access_token: str = Header(default=None)) -> AuthUser:
if access_token:
user = AuthUser(**osm_auth.deserialize_access_token(access_token))
if is_admin(user.id):
user.role = UserRole.ADMIN
return user
else:
# If no token provided, return a user with limited options or guest user
return AuthUser(id=0, username="guest", img_url=None)


def admin_required(user: AuthUser = Depends(login_required)):
if not is_admin(user.id):
raise HTTPException(status_code=403, detail="User is not an admin")
return user
2 changes: 1 addition & 1 deletion API/auth/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from fastapi import APIRouter, Depends, Request

from . import AuthUser, login_required, osm_auth
from . import AuthUser, admin_required, login_required, osm_auth

router = APIRouter(prefix="/auth")

Expand Down
12 changes: 9 additions & 3 deletions API/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from src.config import logger as logging
from src.db_session import database_instance

# from .auth import router as auth_router
from .auth.routers import router as auth_router
from .raw_data import router as raw_data_router
from .tasks import router as tasks_router

Expand All @@ -59,10 +59,16 @@
os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1"

app = FastAPI(title="Raw Data API ")
# app.include_router(auth_router)
app.include_router(auth_router)
app.include_router(raw_data_router)
app.include_router(tasks_router)

app.openapi = {
"info": {
"title": "Raw Data API",
"version": "1.0",
},
"security": [{"OAuth2PasswordBearer": []}],
}

app = VersionedFastAPI(
app, enable_latest=False, version_format="{major}", prefix_format="/v{major}"
Expand Down
49 changes: 47 additions & 2 deletions API/raw_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,20 @@

"""[Router Responsible for Raw data API ]
"""
import json
import os
import shutil
import time

import requests
from fastapi import APIRouter, Body, Request
from area import area
from fastapi import APIRouter, Body, Depends, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi_versioning import version
from geojson import FeatureCollection

from src.app import RawData
from src.config import ALLOW_BIND_ZIP_FILTER, EXPORT_MAX_AREA_SQKM
from src.config import LIMITER as limiter
from src.config import RATE_LIMIT_PER_MIN as export_rate_limit
from src.config import logger as logging
Expand All @@ -41,6 +44,7 @@
)

from .api_worker import process_raw_data
from .auth import AuthUser, UserRole, get_optional_user

router = APIRouter(prefix="")

Expand Down Expand Up @@ -421,6 +425,7 @@ def get_osm_current_snapshot_as_file(
},
},
),
user: AuthUser = Depends(get_optional_user),
):
"""Generates the current raw OpenStreetMap data available on database based on the input geometry, query and spatial features.
Expand All @@ -434,6 +439,32 @@ def get_osm_current_snapshot_as_file(
2. Now navigate to /tasks/ with your task id to track progress and result
"""

if not (user.role == UserRole.STAFF or user.role == UserRole.ADMIN):
area_m2 = area(json.loads(params.geometry.json()))
area_km2 = area_m2 * 1e-6
RAWDATA_CURRENT_POLYGON_AREA = int(EXPORT_MAX_AREA_SQKM)
if area_km2 > RAWDATA_CURRENT_POLYGON_AREA:
raise HTTPException(
status_code=400,
detail=[
{
"msg": f"""Polygon Area {int(area_km2)} Sq.KM is higher than Threshold : {RAWDATA_CURRENT_POLYGON_AREA} Sq.KM"""
}
],
)
if not params.uuid:
raise HTTPException(
status_code=403,
detail=[{"msg": "Insufficient Permission for uuid = False"}],
)
if ALLOW_BIND_ZIP_FILTER:
if not params.bind_zip:
raise HTTPException(
status_code=403,
detail=[{"msg": "Insufficient Permission for bind_zip"}],
)

queue_name = "recurring_queue" if not params.uuid else "raw_default"
task = process_raw_data.apply_async(args=(params,), queue=queue_name)
return JSONResponse({"task_id": task.id, "track_link": f"/tasks/status/{task.id}/"})
Expand All @@ -442,7 +473,9 @@ def get_osm_current_snapshot_as_file(
@router.post("/snapshot/plain/", response_model=FeatureCollection)
@version(1)
def get_osm_current_snapshot_as_plain_geojson(
request: Request, params: RawDataCurrentParamsBase
request: Request,
params: RawDataCurrentParamsBase,
user: AuthUser = Depends(get_optional_user),
):
"""Generates the Plain geojson for the polygon within 30 Sqkm and returns the result right away
Expand All @@ -453,6 +486,18 @@ def get_osm_current_snapshot_as_plain_geojson(
Returns:
Featurecollection: Geojson
"""
if not (user.role == UserRole.STAFF or user.role == UserRole.ADMIN):
area_m2 = area(json.loads(params.geometry.json()))
area_km2 = area_m2 * 1e-6
if area_km2 > 30:
raise HTTPException(
status_code=400,
detail=[
{
"msg": f"""Polygon Area {int(area_km2)} Sq.KM is higher than Threshold : 30 Sq.KM"""
}
],
)
params.output_type = "geojson" # always geojson
result = RawData(params).extract_plain_geojson()
return result
Expand Down
7 changes: 4 additions & 3 deletions API/tasks.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from celery.result import AsyncResult
from fastapi import APIRouter
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from fastapi_versioning import version

from src.validation.models import SnapshotTaskResponse

from .api_worker import celery
from .auth import AuthUser, admin_required, login_required

router = APIRouter(prefix="/tasks")

Expand Down Expand Up @@ -39,7 +40,7 @@ def get_task_status(task_id):

@router.get("/revoke/{task_id}/")
@version(1)
def revoke_task(task_id):
def revoke_task(task_id, user: AuthUser = Depends(login_required)):
"""Revokes task , Terminates if it is executing
Args:
Expand Down Expand Up @@ -80,7 +81,7 @@ def ping_workers():

@router.get("/purge/")
@version(1)
def discard_all_waiting_tasks():
def discard_all_waiting_tasks(user: AuthUser = Depends(admin_required)):
"""
Discards all waiting tasks from the queue
Returns : Number of tasks discarded
Expand Down
2 changes: 2 additions & 0 deletions docs/src/installation/configurations.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ The following are the different configuration options that are accepted.
| `LOGIN_REDIRECT_URI` | `LOGIN_REDIRECT_URI` | `[OAUTH]` | _none_ | Redirect URL set in the OAuth2 application | REQUIRED |
| `APP_SECRET_KEY` | `APP_SECRET_KEY` | `[OAUTH]` | _none_ | High-entropy string generated for the application | REQUIRED |
| `OSM_URL` | `OSM_URL` | `[OAUTH]` | `https://www.openstreetmap.org` | OSM instance Base URL | OPTIONAL |
| `ADMIN_IDS` | `ADMIN_IDS` | `[OAUTH]` | `00000` | List of Admin OSMId separated by , | OPTIONAL |
| `LOG_LEVEL` | `LOG_LEVEL` | `[API_CONFIG]` | `debug` | Application log level; info,debug,warning,error | OPTIONAL |
| `RATE_LIMITER_STORAGE_URI` | `RATE_LIMITER_STORAGE_URI` | `[API_CONFIG]` | `redis://redis:6379` | Redis connection string for rate-limiter data | OPTIONAL |
| `RATE_LIMIT_PER_MIN` | `RATE_LIMIT_PER_MIN` | `[API_CONFIG]` | `5` | Number of requests per minute before being rate limited | OPTIONAL |
Expand Down Expand Up @@ -67,6 +68,7 @@ The following are the different configuration options that are accepted.
| `LOGIN_REDIRECT_URI` | TBD | Yes | No |
| `APP_SECRET_KEY` | TBD | Yes | No |
| `OSM_URL` | TBD | Yes | No |
| `ADMIN_IDS` | TBD | Yes | No |
| `LOG_LEVEL` | `[API_CONFIG]` | Yes | Yes |
| `RATE_LIMITER_STORAGE_URI` | `[API_CONFIG]` | Yes | No |
| `RATE_LIMIT_PER_MIN` | `[API_CONFIG]` | Yes | No |
Expand Down
7 changes: 7 additions & 0 deletions src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@
"API_CONFIG", "ENABLE_TILES", fallback=None
)

######

ADMIN_IDS = os.environ.get("ADMIN_IDS") or config.get(
"OAUTH", "ADMIN_IDS", fallback="00000"
).split(",")


####################

### EXPORT_UPLOAD CONFIG BLOCK
Expand Down
40 changes: 0 additions & 40 deletions src/validation/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from enum import Enum
from typing import Dict, List, Optional, Union

from area import area
from geojson_pydantic import MultiPolygon, Polygon
from geojson_pydantic.types import BBox
from pydantic import BaseModel as PydanticModel
Expand Down Expand Up @@ -85,22 +84,6 @@ class JoinFilterType(Enum):
AND = "AND"


#
# "tags": { # no of rows returned
# "point" : {"amenity":["shop"]},
# "line" : {},
# "polygon" : {"key":["value"]},
# "all_geometry" : {"building":['yes']}
# },
# "attributes": { # no of columns / name
# "point": [], column
# "line" : [],
# "polygon" : [],
# "all_geometry" : [],
# }
# }


class SQLFilter(BaseModel):
join_or: Optional[Dict[str, List[str]]]
join_and: Optional[Dict[str, List[str]]]
Expand Down Expand Up @@ -163,17 +146,6 @@ class RawDataCurrentParamsBase(BaseModel):
},
)

@validator("geometry", always=True)
def check_geometry_area(cls, value, values):
"""Validates geom area_m2"""
area_m2 = area(json.loads(value.json()))
area_km2 = area_m2 * 1e-6
if area_km2 > 30: # 30 square km
raise ValueError(
f"""Polygon Area {int(area_km2)} Sq.KM is higher than 30 sqkm , Consider using /snapshot/ for larger area"""
)
return value

@validator("geometry_type", allow_reuse=True)
def return_unique_value(cls, value):
"""return unique list"""
Expand Down Expand Up @@ -205,18 +177,6 @@ def check_bind_option(cls, value, values):
)
return value

@validator("geometry", always=True)
def check_geometry_area(cls, value, values):
"""Validates geom area_m2"""
area_m2 = area(json.loads(value.json()))
area_km2 = area_m2 * 1e-6
RAWDATA_CURRENT_POLYGON_AREA = int(EXPORT_MAX_AREA_SQKM)
if area_km2 > RAWDATA_CURRENT_POLYGON_AREA:
raise ValueError(
f"""Polygon Area {int(area_km2)} Sq.KM is higher than Threshold : {RAWDATA_CURRENT_POLYGON_AREA} Sq.KM"""
)
return value


class SnapshotResponse(BaseModel):
task_id: str
Expand Down

0 comments on commit cfed737

Please sign in to comment.