Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use best/better practices for type hints and pydantic constraints. #52

Merged
merged 1 commit into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions server/npg/porch/endpoints/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from fastapi import APIRouter, HTTPException, Depends
import logging
from typing import List, Optional
import re
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.exc import NoResultFound
Expand All @@ -44,19 +43,19 @@

@router.get(
"/",
response_model=List[Pipeline],
response_model=list[Pipeline],
summary="Get information about all pipelines.",
description='''
Returns a list of pydantic Pipeline models.
A uri and/or version filter can be used.
A valid token issued for any pipeline is required for authorisation.'''
)
async def get_pipelines(
uri: Optional[str] = None,
version: Optional[str] = None,
uri: str | None = None,
version: str | None = None,
db_accessor=Depends(get_DbAccessor),
permissions=Depends(validate)
) -> List[Pipeline]:
) -> list[Pipeline]:

return await db_accessor.get_all_pipelines(uri, version)

Expand Down
29 changes: 14 additions & 15 deletions server/npg/porch/endpoints/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,17 @@
# this program. If not, see <http://www.gnu.org/licenses/>.

import logging
from fastapi import APIRouter, HTTPException, Depends
from pydantic import PositiveInt
from typing import List, Optional
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.exc import NoResultFound
from starlette import status
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException, Query
from npg.porch.auth.token import validate
from npg.porch.models.permission import PermissionValidationException
from npg.porch.models.pipeline import Pipeline
from npg.porch.models.task import Task, TaskStateEnum
from npg.porch.models.permission import PermissionValidationException
from npg.porchdb.connection import get_DbAccessor
from npg.porch.auth.token import validate
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm.exc import NoResultFound
from starlette import status


def _validate_request(permission, pipeline):
Expand Down Expand Up @@ -61,18 +60,18 @@ def _validate_request(permission, pipeline):

@router.get(
"/",
response_model=List[Task],
response_model=list[Task],
summary="Returns all tasks, and can be filtered to task status or pipeline name",
description='''
Return all tasks. The list of tasks can be filtered by supplying a pipeline
name and/or task status'''
)
async def get_tasks(
pipeline_name: Optional[str] = None,
status: Optional[TaskStateEnum] = None,
pipeline_name: str | None = None,
status: TaskStateEnum | None = None,
db_accessor=Depends(get_DbAccessor),
permission=Depends(validate)
) -> List[Task]:
) -> list[Task]:
print(pipeline_name, status)
return await db_accessor.get_tasks(pipeline_name=pipeline_name, task_status=status)

Expand Down Expand Up @@ -152,7 +151,7 @@ async def update_task(

@router.post(
"/claim",
response_model=List[Task],
response_model=list[Task],
responses={
status.HTTP_200_OK: {"description": "Receive a list of tasks that have been claimed"}
},
Expand All @@ -173,10 +172,10 @@ async def update_task(
)
async def claim_task(
pipeline: Pipeline,
num_tasks: PositiveInt | None = 1,
num_tasks: Annotated[int | None, Query(gt=0)] = 1,
db_accessor=Depends(get_DbAccessor),
permission=Depends(validate)
) -> List[Task]:
) -> list[Task]:

_validate_request(permission, pipeline)
tasks = await db_accessor.claim_tasks(
Expand Down
5 changes: 2 additions & 3 deletions server/npg/porch/models/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,19 @@
# this program. If not, see <http://www.gnu.org/licenses/>.

from pydantic import BaseModel, Field
from typing import Optional

class Pipeline(BaseModel):
name: str = Field(
default = None,
title='Pipeline Name',
description='A user-controlled name for the pipeline'
)
uri: Optional[str] = Field(
uri: str | None = Field(
default = None,
title='URI',
description='URI to bootstrap the pipeline code'
)
version: Optional[str] = Field(
version: str | None = Field(
default = None,
title='Version',
description='Pipeline version to use with URI'
Expand Down
7 changes: 3 additions & 4 deletions server/npg/porch/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import hashlib
import ujson
from pydantic import BaseModel, Field
from typing import Optional, Dict

from npg.porch.models.pipeline import Pipeline

Expand All @@ -36,17 +35,17 @@ class TaskStateEnum(str, Enum):

class Task(BaseModel):
pipeline: Pipeline
task_input_id: Optional[str] = Field(
task_input_id: str | None = Field(
None,
title='Task Input ID',
description='A stringified unique identifier for a piece of work. Set by the npg_porch server, not the client' # noqa: E501
)
task_input: Dict = Field(
task_input: dict = Field(
None,
title='Task Input',
description='A structured parameter set that uniquely identifies a piece of work, and enables an iteration of a pipeline' # noqa: E501
)
status: Optional[TaskStateEnum] = None
status: TaskStateEnum | None = None

def generate_task_id(self):
return hashlib.sha256(ujson.dumps(self.task_input, sort_keys=True).encode()).hexdigest()
Expand Down
27 changes: 13 additions & 14 deletions server/npg/porchdb/data_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import contains_eager, joinedload
from sqlalchemy.orm.exc import NoResultFound
from typing import Optional, List

from npg.porchdb.models import Pipeline as DbPipeline, Task as DbTask, Event
from npg.porch.models import Task, Pipeline, TaskStateEnum
Expand Down Expand Up @@ -55,10 +54,10 @@ async def _get_pipeline_db_object(self, name: str) -> Pipeline:

async def _get_pipeline_db_objects(
self,
name: Optional[str] = None,
version: Optional[str] = None,
uri: Optional[str] = None
) -> List[Pipeline]:
name: str | None = None,
version: str | None = None,
uri: str | None = None
) -> list[Pipeline]:
query = select(DbPipeline)
if name:
query = query.filter_by(name=name)
Expand All @@ -72,9 +71,9 @@ async def _get_pipeline_db_objects(

async def get_all_pipelines(
self,
uri: Optional[str] = None,
version: Optional[str] = None
) -> List[Pipeline]:
uri: str | None = None,
version: str | None = None
) -> list[Pipeline]:
pipelines = []
pipelines = await self._get_pipeline_db_objects(uri=uri, version=version)
return [pipe.convert_to_model() for pipe in pipelines]
Expand Down Expand Up @@ -117,8 +116,8 @@ async def create_task(self, token_id: int, task: Task) -> Task:
return t.convert_to_model()

async def claim_tasks(
self, token_id: int, pipeline: Pipeline, claim_limit: Optional[int] = 1
) -> List[Task]:
self, token_id: int, pipeline: Pipeline, claim_limit: int | None = 1
) -> list[Task]:
session = self.session

try:
Expand Down Expand Up @@ -196,9 +195,9 @@ async def update_task(self, token_id: int, task: Task) -> Task:

async def get_tasks(
self,
pipeline_name: Optional[str] = None,
task_status: Optional[TaskStateEnum] = None
) -> List[Task]:
pipeline_name: str | None = None,
task_status: TaskStateEnum | None = None
) -> list[Task]:
'''
Gets all the tasks.
Expand Down Expand Up @@ -229,7 +228,7 @@ def convert_task_to_db(task: Task, pipeline: DbPipeline) -> DbTask:
state=task.status
)

async def get_events_for_task(self, task: Task) -> List[Event]:
async def get_events_for_task(self, task: Task) -> list[Event]:
events = await self.session.execute(
select(Event)
.join(Event.task)
Expand Down
Loading