Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modified GET All Function #7

Merged
merged 2 commits into from
Nov 7, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 76 additions & 9 deletions pkgs/crouton/crouton/core/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Any, Callable, List, Type, Generator, Optional, Union
from typing import Any, Callable, List, Type, Generator, Optional, Union, Annotated, get_origin

from fastapi import Depends, HTTPException
from fastapi import Depends, HTTPException, Query
import logging
import json

from . import CRUDGenerator, NOT_FOUND, _utils
from ._types import DEPENDENCIES, PAGINATION, PYDANTIC_SCHEMA as SCHEMA
Expand All @@ -21,6 +23,8 @@
CALLABLE = Callable[..., Model]
CALLABLE_LIST = Callable[..., List[Model]]

logger = logging.getLogger('uvicorn.error')


class SQLAlchemyCRUDRouter(CRUDGenerator[SCHEMA]):
def __init__(
Expand Down Expand Up @@ -66,20 +70,83 @@ def __init__(
**kwargs
)

def get_filter_by(self, query: str) -> dict:

# The Fields in the Schema
accepted_fields = self.schema.__dict__["model_fields"].keys()

# Prepare dictionary by splitting the query
new_query = {}
key_value = query.split('&')
for values in key_value:
key, value = values.split('=')[0], values.split('=')[1]

# Check if the values passed in query match those in the schema
if key not in accepted_fields:
raise HTTPException(400, "Invalid Query")

# Check if the values are repeated in dictionary
if key not in new_query.keys():

# Check the current value of the key
column = getattr(self.db_model, key)
try:
# Find the type of the current value
column_type = column.type.python_type

# Assign correct value to the bool in the query
if column_type == bool:
if value == "True" or value == "true" or value == "TRUE":
value = True
else:
value = False

# Assign the correct type to the value
parsed_value = column_type(value)

# create a key-value dictionary of the query
new_query[key] = parsed_value

# Handle excpetion when error occurs
except (ValueError, TypeError) as e:
raise HTTPException(
status_code=422, detail=f"Invalid value for {key}: {e}"
)

return new_query

def _get_all(self, *args: Any, **kwargs: Any) -> CALLABLE_LIST:
def route(
db: Session = Depends(self.db_func),
pagination: PAGINATION = self.pagination,
query: Annotated[str, Query()] = None,
) -> List[Model]:

skip, limit = pagination.get("skip"), pagination.get("limit")

db_models: List[Model] = (
db.query(self.db_model)
.order_by(getattr(self.db_model, self._pk))
.limit(limit)
.offset(skip)
.all()
)
if query:

# Pass the given query to get checked
new_query = self.get_filter_by(query)

# Find the data that has been filtered
db_models: List[Model] = (
db.query(self.db_model)
.filter_by(**new_query)
.limit(limit)
.offset(skip)
.all()
)

else:
db_models: List[Model] = (
db.query(self.db_model)
.order_by(getattr(self.db_model, self._pk))
.limit(limit)
.offset(skip)
.all()
)

return db_models

return route
Expand Down