From 7432a8ef44bddedfb7a873788040a48a5a9974b7 Mon Sep 17 00:00:00 2001 From: Kieron Taylor Date: Wed, 11 Oct 2023 15:36:45 +0000 Subject: [PATCH] Use best/better practices for type hints and pydantic constraints. --- server/npg/porch/endpoints/pipelines.py | 9 ++++---- server/npg/porch/endpoints/tasks.py | 29 ++++++++++++------------- server/npg/porch/models/pipeline.py | 5 ++--- server/npg/porch/models/task.py | 7 +++--- server/npg/porchdb/data_access.py | 27 +++++++++++------------ 5 files changed, 36 insertions(+), 41 deletions(-) diff --git a/server/npg/porch/endpoints/pipelines.py b/server/npg/porch/endpoints/pipelines.py index 2ee358e..d1afcd7 100644 --- a/server/npg/porch/endpoints/pipelines.py +++ b/server/npg/porch/endpoints/pipelines.py @@ -20,7 +20,6 @@ from fastapi import APIRouter, HTTPException, Depends import logging -from typing import List, Optional import re from sqlalchemy.exc import IntegrityError from sqlalchemy.orm.exc import NoResultFound @@ -44,7 +43,7 @@ @router.get( "/", - response_model=List[Pipeline], + response_model=list[Pipeline], summary="Get information about all pipelines.", description=''' Returns a list of pydantic Pipeline models. @@ -52,11 +51,11 @@ A valid token issued for any pipeline is required for authorisation.''' ) async def get_pipelines( - uri: Optional[str] = None, - version: Optional[str] = None, + uri: str | None = None, + version: str | None = None, db_accessor=Depends(get_DbAccessor), permissions=Depends(validate) -) -> List[Pipeline]: +) -> list[Pipeline]: return await db_accessor.get_all_pipelines(uri, version) diff --git a/server/npg/porch/endpoints/tasks.py b/server/npg/porch/endpoints/tasks.py index 02d7e8c..3a57d2a 100644 --- a/server/npg/porch/endpoints/tasks.py +++ b/server/npg/porch/endpoints/tasks.py @@ -19,18 +19,17 @@ # this program. If not, see . import logging -from fastapi import APIRouter, HTTPException, Depends -from pydantic import PositiveInt -from typing import List, Optional -from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm.exc import NoResultFound -from starlette import status +from typing import Annotated +from fastapi import APIRouter, Depends, HTTPException, Query +from npg.porch.auth.token import validate +from npg.porch.models.permission import PermissionValidationException from npg.porch.models.pipeline import Pipeline from npg.porch.models.task import Task, TaskStateEnum -from npg.porch.models.permission import PermissionValidationException from npg.porchdb.connection import get_DbAccessor -from npg.porch.auth.token import validate +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm.exc import NoResultFound +from starlette import status def _validate_request(permission, pipeline): @@ -61,18 +60,18 @@ def _validate_request(permission, pipeline): @router.get( "/", - response_model=List[Task], + response_model=list[Task], summary="Returns all tasks, and can be filtered to task status or pipeline name", description=''' Return all tasks. The list of tasks can be filtered by supplying a pipeline name and/or task status''' ) async def get_tasks( - pipeline_name: Optional[str] = None, - status: Optional[TaskStateEnum] = None, + pipeline_name: str | None = None, + status: TaskStateEnum | None = None, db_accessor=Depends(get_DbAccessor), permission=Depends(validate) -) -> List[Task]: +) -> list[Task]: print(pipeline_name, status) return await db_accessor.get_tasks(pipeline_name=pipeline_name, task_status=status) @@ -152,7 +151,7 @@ async def update_task( @router.post( "/claim", - response_model=List[Task], + response_model=list[Task], responses={ status.HTTP_200_OK: {"description": "Receive a list of tasks that have been claimed"} }, @@ -173,10 +172,10 @@ async def update_task( ) async def claim_task( pipeline: Pipeline, - num_tasks: PositiveInt | None = 1, + num_tasks: Annotated[int | None, Query(gt=0)] = 1, db_accessor=Depends(get_DbAccessor), permission=Depends(validate) -) -> List[Task]: +) -> list[Task]: _validate_request(permission, pipeline) tasks = await db_accessor.claim_tasks( diff --git a/server/npg/porch/models/pipeline.py b/server/npg/porch/models/pipeline.py index 2ffa1f1..ce212b8 100644 --- a/server/npg/porch/models/pipeline.py +++ b/server/npg/porch/models/pipeline.py @@ -19,7 +19,6 @@ # this program. If not, see . from pydantic import BaseModel, Field -from typing import Optional class Pipeline(BaseModel): name: str = Field( @@ -27,12 +26,12 @@ class Pipeline(BaseModel): title='Pipeline Name', description='A user-controlled name for the pipeline' ) - uri: Optional[str] = Field( + uri: str | None = Field( default = None, title='URI', description='URI to bootstrap the pipeline code' ) - version: Optional[str] = Field( + version: str | None = Field( default = None, title='Version', description='Pipeline version to use with URI' diff --git a/server/npg/porch/models/task.py b/server/npg/porch/models/task.py index 7233812..788518d 100644 --- a/server/npg/porch/models/task.py +++ b/server/npg/porch/models/task.py @@ -22,7 +22,6 @@ import hashlib import ujson from pydantic import BaseModel, Field -from typing import Optional, Dict from npg.porch.models.pipeline import Pipeline @@ -36,17 +35,17 @@ class TaskStateEnum(str, Enum): class Task(BaseModel): pipeline: Pipeline - task_input_id: Optional[str] = Field( + task_input_id: str | None = Field( None, title='Task Input ID', description='A stringified unique identifier for a piece of work. Set by the npg_porch server, not the client' # noqa: E501 ) - task_input: Dict = Field( + task_input: dict = Field( None, title='Task Input', description='A structured parameter set that uniquely identifies a piece of work, and enables an iteration of a pipeline' # noqa: E501 ) - status: Optional[TaskStateEnum] = None + status: TaskStateEnum | None = None def generate_task_id(self): return hashlib.sha256(ujson.dumps(self.task_input, sort_keys=True).encode()).hexdigest() diff --git a/server/npg/porchdb/data_access.py b/server/npg/porchdb/data_access.py index d36e2d9..effd0e4 100644 --- a/server/npg/porchdb/data_access.py +++ b/server/npg/porchdb/data_access.py @@ -23,7 +23,6 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import contains_eager, joinedload from sqlalchemy.orm.exc import NoResultFound -from typing import Optional, List from npg.porchdb.models import Pipeline as DbPipeline, Task as DbTask, Event from npg.porch.models import Task, Pipeline, TaskStateEnum @@ -55,10 +54,10 @@ async def _get_pipeline_db_object(self, name: str) -> Pipeline: async def _get_pipeline_db_objects( self, - name: Optional[str] = None, - version: Optional[str] = None, - uri: Optional[str] = None - ) -> List[Pipeline]: + name: str | None = None, + version: str | None = None, + uri: str | None = None + ) -> list[Pipeline]: query = select(DbPipeline) if name: query = query.filter_by(name=name) @@ -72,9 +71,9 @@ async def _get_pipeline_db_objects( async def get_all_pipelines( self, - uri: Optional[str] = None, - version: Optional[str] = None - ) -> List[Pipeline]: + uri: str | None = None, + version: str | None = None + ) -> list[Pipeline]: pipelines = [] pipelines = await self._get_pipeline_db_objects(uri=uri, version=version) return [pipe.convert_to_model() for pipe in pipelines] @@ -117,8 +116,8 @@ async def create_task(self, token_id: int, task: Task) -> Task: return t.convert_to_model() async def claim_tasks( - self, token_id: int, pipeline: Pipeline, claim_limit: Optional[int] = 1 - ) -> List[Task]: + self, token_id: int, pipeline: Pipeline, claim_limit: int | None = 1 + ) -> list[Task]: session = self.session try: @@ -196,9 +195,9 @@ async def update_task(self, token_id: int, task: Task) -> Task: async def get_tasks( self, - pipeline_name: Optional[str] = None, - task_status: Optional[TaskStateEnum] = None - ) -> List[Task]: + pipeline_name: str | None = None, + task_status: TaskStateEnum | None = None + ) -> list[Task]: ''' Gets all the tasks. @@ -229,7 +228,7 @@ def convert_task_to_db(task: Task, pipeline: DbPipeline) -> DbTask: state=task.status ) - async def get_events_for_task(self, task: Task) -> List[Event]: + async def get_events_for_task(self, task: Task) -> list[Event]: events = await self.session.execute( select(Event) .join(Event.task)