From 94484688d4bccbdeb441e27f5897cf3358264c67 Mon Sep 17 00:00:00 2001 From: Samra Solomon Barnabas Date: Wed, 6 Nov 2024 17:54:11 +0300 Subject: [PATCH 1/2] Crouton: -Feature Filter Function --- pkgs/crouton/crouton/core/sqlalchemy.py | 49 +++++++++++++++++++++---- 1 file changed, 41 insertions(+), 8 deletions(-) diff --git a/pkgs/crouton/crouton/core/sqlalchemy.py b/pkgs/crouton/crouton/core/sqlalchemy.py index 58270f3..f55e884 100644 --- a/pkgs/crouton/crouton/core/sqlalchemy.py +++ b/pkgs/crouton/crouton/core/sqlalchemy.py @@ -1,6 +1,8 @@ -from typing import Any, Callable, List, Type, Generator, Optional, Union +from typing import Any, Callable, List, Type, Generator, Optional, Union, Dict from fastapi import Depends, HTTPException +import logging +import json from . import CRUDGenerator, NOT_FOUND, _utils from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA @@ -21,6 +23,8 @@ CALLABLE = Callable[..., Model] CALLABLE_LIST = Callable[..., List[Model]] +logger = logging.getLogger('uvicorn.error') + class SQLAlchemyCRUDRouter(CRUDGenerator[SCHEMA]): def __init__( @@ -66,20 +70,49 @@ def __init__( **kwargs ) + def get_filter_by(self, query: dict) -> dict: + + # The Fields in the Schema + accepted_fields = self.schema.__dict__["model_fields"].keys() + + # Check if the values passed match those in the schema + for key in query.keys(): + if key not in accepted_fields: + raise HTTPException(400, "Invalid Query") + + return query + def _get_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST: def route( db: Session = Depends(self.db_func), pagination: PAGINATION = self.pagination, + # query: Optional[Dict[str, Any]] = None, + **kwargs ) -> List[Model]: + skip, limit = pagination.get("skip"), pagination.get("limit") - db_models: List[Model] = ( - db.query(self.db_model) - .order_by(getattr(self.db_model, self._pk)) - .limit(limit) - .offset(skip) - .all() - ) + if kwargs: + # Pass the given query to get checked + new_query = self.get_filter_by(eval(kwargs['kwargs'])) + + db_models: List[Model] = ( + db.query(self.db_model) + .filter_by(**new_query) + .limit(limit) + .offset(skip) + .all() + ) + + else: + db_models: List[Model] = ( + db.query(self.db_model) + .order_by(getattr(self.db_model, self._pk)) + .limit(limit) + .offset(skip) + .all() + ) + return db_models return route From 0a7a9066c74fcd833fdd51e84644dd4e4a4e983d Mon Sep 17 00:00:00 2001 From: Samra Solomon Barnabas Date: Thu, 7 Nov 2024 10:58:46 +0300 Subject: [PATCH 2/2] Crouton: -Fix Field Typing --- pkgs/crouton/crouton/core/sqlalchemy.py | 54 ++++++++++++++++++++----- 1 file changed, 44 insertions(+), 10 deletions(-) diff --git a/pkgs/crouton/crouton/core/sqlalchemy.py b/pkgs/crouton/crouton/core/sqlalchemy.py index f55e884..6509bb8 100644 --- a/pkgs/crouton/crouton/core/sqlalchemy.py +++ b/pkgs/crouton/crouton/core/sqlalchemy.py @@ -1,6 +1,6 @@ -from typing import Any, Callable, List, Type, Generator, Optional, Union, Dict +from typing import Any, Callable, List, Type, Generator, Optional, Union, Annotated, get_origin -from fastapi import Depends, HTTPException +from fastapi import Depends, HTTPException, Query import logging import json @@ -70,32 +70,66 @@ def __init__( **kwargs ) - def get_filter_by(self, query: dict) -> dict: + def get_filter_by(self, query: str) -> dict: # The Fields in the Schema accepted_fields = self.schema.__dict__["model_fields"].keys() - # Check if the values passed match those in the schema - for key in query.keys(): + # Prepare dictionary by splitting the query + new_query = {} + key_value = query.split('&') + for values in key_value: + key, value = values.split('=')[0], values.split('=')[1] + + # Check if the values passed in query match those in the schema if key not in accepted_fields: raise HTTPException(400, "Invalid Query") - return query + # Check if the values are repeated in dictionary + if key not in new_query.keys(): + + # Check the current value of the key + column = getattr(self.db_model, key) + try: + # Find the type of the current value + column_type = column.type.python_type + + # Assign correct value to the bool in the query + if column_type == bool: + if value == "True" or value == "true" or value == "TRUE": + value = True + else: + value = False + + # Assign the correct type to the value + parsed_value = column_type(value) + + # create a key-value dictionary of the query + new_query[key] = parsed_value + + # Handle excpetion when error occurs + except (ValueError, TypeError) as e: + raise HTTPException( + status_code=422, detail=f"Invalid value for {key}: {e}" + ) + + return new_query def _get_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST: def route( db: Session = Depends(self.db_func), pagination: PAGINATION = self.pagination, - # query: Optional[Dict[str, Any]] = None, - **kwargs + query: Annotated[str, Query()] = None, ) -> List[Model]: skip, limit = pagination.get("skip"), pagination.get("limit") - if kwargs: + if query: + # Pass the given query to get checked - new_query = self.get_filter_by(eval(kwargs['kwargs'])) + new_query = self.get_filter_by(query) + # Find the data that has been filtered db_models: List[Model] = ( db.query(self.db_model) .filter_by(**new_query)