Skip to content

Commit

Permalink
Merge pull request #52 from nerdstrike/pydantic_second_pass
Browse files Browse the repository at this point in the history
Use best/better practices for type hints and pydantic constraints.
  • Loading branch information
mgcam authored Oct 11, 2023
2 parents aae9ee3 + 7432a8e commit 3bae5f3
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 41 deletions.
9 changes: 4 additions & 5 deletions server/npg/porch/endpoints/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from fastapi import APIRouter, HTTPException, Depends
import logging
from typing import List, Optional
import re
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.exc import NoResultFound
Expand All @@ -44,19 +43,19 @@

@router.get(
"/",
response_model=List[Pipeline],
response_model=list[Pipeline],
summary="Get information about all pipelines.",
description='''
Returns a list of pydantic Pipeline models.
A uri and/or version filter can be used.
A valid token issued for any pipeline is required for authorisation.'''
)
async def get_pipelines(
uri: Optional[str] = None,
version: Optional[str] = None,
uri: str | None = None,
version: str | None = None,
db_accessor=Depends(get_DbAccessor),
permissions=Depends(validate)
) -> List[Pipeline]:
) -> list[Pipeline]:

return await db_accessor.get_all_pipelines(uri, version)

Expand Down
29 changes: 14 additions & 15 deletions server/npg/porch/endpoints/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,17 @@
# this program. If not, see <http://www.gnu.org/licenses/>.

import logging
from fastapi import APIRouter, HTTPException, Depends
from pydantic import PositiveInt
from typing import List, Optional
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.exc import NoResultFound
from starlette import status
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException, Query
from npg.porch.auth.token import validate
from npg.porch.models.permission import PermissionValidationException
from npg.porch.models.pipeline import Pipeline
from npg.porch.models.task import Task, TaskStateEnum
from npg.porch.models.permission import PermissionValidationException
from npg.porchdb.connection import get_DbAccessor
from npg.porch.auth.token import validate
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.exc import NoResultFound
from starlette import status


def _validate_request(permission, pipeline):
Expand Down Expand Up @@ -61,18 +60,18 @@ def _validate_request(permission, pipeline):

@router.get(
"/",
response_model=List[Task],
response_model=list[Task],
summary="Returns all tasks, and can be filtered to task status or pipeline name",
description='''
Return all tasks. The list of tasks can be filtered by supplying a pipeline
name and/or task status'''
)
async def get_tasks(
pipeline_name: Optional[str] = None,
status: Optional[TaskStateEnum] = None,
pipeline_name: str | None = None,
status: TaskStateEnum | None = None,
db_accessor=Depends(get_DbAccessor),
permission=Depends(validate)
) -> List[Task]:
) -> list[Task]:
print(pipeline_name, status)
return await db_accessor.get_tasks(pipeline_name=pipeline_name, task_status=status)

Expand Down Expand Up @@ -152,7 +151,7 @@ async def update_task(

@router.post(
"/claim",
response_model=List[Task],
response_model=list[Task],
responses={
status.HTTP_200_OK: {"description": "Receive a list of tasks that have been claimed"}
},
Expand All @@ -173,10 +172,10 @@ async def update_task(
)
async def claim_task(
pipeline: Pipeline,
num_tasks: PositiveInt | None = 1,
num_tasks: Annotated[int | None, Query(gt=0)] = 1,
db_accessor=Depends(get_DbAccessor),
permission=Depends(validate)
) -> List[Task]:
) -> list[Task]:

_validate_request(permission, pipeline)
tasks = await db_accessor.claim_tasks(
Expand Down
5 changes: 2 additions & 3 deletions server/npg/porch/models/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,19 @@
# this program. If not, see <http://www.gnu.org/licenses/>.

from pydantic import BaseModel, Field
from typing import Optional

class Pipeline(BaseModel):
name: str = Field(
default = None,
title='Pipeline Name',
description='A user-controlled name for the pipeline'
)
uri: Optional[str] = Field(
uri: str | None = Field(
default = None,
title='URI',
description='URI to bootstrap the pipeline code'
)
version: Optional[str] = Field(
version: str | None = Field(
default = None,
title='Version',
description='Pipeline version to use with URI'
Expand Down
7 changes: 3 additions & 4 deletions server/npg/porch/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import hashlib
import ujson
from pydantic import BaseModel, Field
from typing import Optional, Dict

from npg.porch.models.pipeline import Pipeline

Expand All @@ -36,17 +35,17 @@ class TaskStateEnum(str, Enum):

class Task(BaseModel):
pipeline: Pipeline
task_input_id: Optional[str] = Field(
task_input_id: str | None = Field(
None,
title='Task Input ID',
description='A stringified unique identifier for a piece of work. Set by the npg_porch server, not the client' # noqa: E501
)
task_input: Dict = Field(
task_input: dict = Field(
None,
title='Task Input',
description='A structured parameter set that uniquely identifies a piece of work, and enables an iteration of a pipeline' # noqa: E501
)
status: Optional[TaskStateEnum] = None
status: TaskStateEnum | None = None

def generate_task_id(self):
return hashlib.sha256(ujson.dumps(self.task_input, sort_keys=True).encode()).hexdigest()
Expand Down
27 changes: 13 additions & 14 deletions server/npg/porchdb/data_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import contains_eager, joinedload
from sqlalchemy.orm.exc import NoResultFound
from typing import Optional, List

from npg.porchdb.models import Pipeline as DbPipeline, Task as DbTask, Event
from npg.porch.models import Task, Pipeline, TaskStateEnum
Expand Down Expand Up @@ -55,10 +54,10 @@ async def _get_pipeline_db_object(self, name: str) -> Pipeline:

async def _get_pipeline_db_objects(
self,
name: Optional[str] = None,
version: Optional[str] = None,
uri: Optional[str] = None
) -> List[Pipeline]:
name: str | None = None,
version: str | None = None,
uri: str | None = None
) -> list[Pipeline]:
query = select(DbPipeline)
if name:
query = query.filter_by(name=name)
Expand All @@ -72,9 +71,9 @@ async def _get_pipeline_db_objects(

async def get_all_pipelines(
self,
uri: Optional[str] = None,
version: Optional[str] = None
) -> List[Pipeline]:
uri: str | None = None,
version: str | None = None
) -> list[Pipeline]:
pipelines = []
pipelines = await self._get_pipeline_db_objects(uri=uri, version=version)
return [pipe.convert_to_model() for pipe in pipelines]
Expand Down Expand Up @@ -117,8 +116,8 @@ async def create_task(self, token_id: int, task: Task) -> Task:
return t.convert_to_model()

async def claim_tasks(
self, token_id: int, pipeline: Pipeline, claim_limit: Optional[int] = 1
) -> List[Task]:
self, token_id: int, pipeline: Pipeline, claim_limit: int | None = 1
) -> list[Task]:
session = self.session

try:
Expand Down Expand Up @@ -196,9 +195,9 @@ async def update_task(self, token_id: int, task: Task) -> Task:

async def get_tasks(
self,
pipeline_name: Optional[str] = None,
task_status: Optional[TaskStateEnum] = None
) -> List[Task]:
pipeline_name: str | None = None,
task_status: TaskStateEnum | None = None
) -> list[Task]:
'''
Gets all the tasks.
Expand Down Expand Up @@ -229,7 +228,7 @@ def convert_task_to_db(task: Task, pipeline: DbPipeline) -> DbTask:
state=task.status
)

async def get_events_for_task(self, task: Task) -> List[Event]:
async def get_events_for_task(self, task: Task) -> list[Event]:
events = await self.session.execute(
select(Event)
.join(Event.task)
Expand Down

0 comments on commit 3bae5f3

Please sign in to comment.