Skip to content

Commit

Permalink
Merge pull request #33 from ssciwr/fix_20_store_job_info
Browse files Browse the repository at this point in the history
Store job information
  • Loading branch information
lkeegan authored Oct 22, 2024
2 parents f73af6d + 5a92bc8 commit afc9d53
Show file tree
Hide file tree
Showing 10 changed files with 267 additions and 82 deletions.
39 changes: 34 additions & 5 deletions backend/src/predicTCR_server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
from flask_jwt_extended import JWTManager
from flask_cors import cross_origin
from predicTCR_server.logger import get_logger
from predicTCR_server.utils import timestamp_now
from predicTCR_server.model import (
db,
Sample,
User,
Job,
Status,
Settings,
add_new_user,
add_new_runner_user,
Expand Down Expand Up @@ -281,6 +284,16 @@ def admin_users():
)
return jsonify(users=[user.as_dict() for user in users])

@app.route("/api/admin/jobs", methods=["GET"])
@jwt_required()
def admin_jobs():
    """Return all stored jobs as JSON, highest job id first.

    The response body is ``{"jobs": [...]}`` where each entry is the
    ``as_dict()`` form of a ``Job`` row. Callers whose ``current_user`` is
    not an admin receive a 400 response with "Admin account required".
    """
    if not current_user.is_admin:
        return jsonify(message="Admin account required"), 400
    # Build the query first, then execute it, so each step reads on its own.
    newest_first = db.select(Job).order_by(db.desc(Job.id))
    job_rows = db.session.execute(newest_first).scalars().all()
    return jsonify(jobs=[row.as_dict() for row in job_rows])

@app.route("/api/admin/runner_token", methods=["GET"])
@jwt_required()
def admin_runner_token():
Expand All @@ -305,7 +318,17 @@ def runner_request_job():
sample_id = request_job()
if sample_id is None:
return jsonify(message="No job available"), 204
return {"sample_id": sample_id}
new_job = Job(
id=None,
sample_id=sample_id,
timestamp_start=timestamp_now(),
timestamp_end=0,
status=Status.RUNNING,
error_message="",
)
db.session.add(new_job)
db.session.commit()
return {"job_id": new_job.id, "sample_id": sample_id}

@app.route("/api/runner/result", methods=["POST"])
@cross_origin()
Expand All @@ -317,6 +340,9 @@ def runner_result():
sample_id = form_as_dict.get("sample_id", None)
if sample_id is None:
return jsonify(message="Missing key: sample_id"), 400
job_id = form_as_dict.get("job_id", None)
if job_id is None:
return jsonify(message="Missing key: job_id"), 400
success = form_as_dict.get("success", None)
if success is None or success.lower() not in ["true", "false"]:
logger.info(" -> missing success key")
Expand All @@ -328,19 +354,22 @@ def runner_result():
return jsonify(message="Result has success=True but no file"), 400
runner_hostname = form_as_dict.get("runner_hostname", "")
logger.info(
f"Result upload for '{sample_id}' from runner {current_user.email} / {runner_hostname}"
f"Job '{job_id}' uploaded result for '{sample_id}' from runner {current_user.email} / {runner_hostname}"
)
error_message = form_as_dict.get("error_message", None)
if error_message is not None:
error_message = form_as_dict.get("error_message", "")
if error_message != "":
logger.info(f" -> error message: {error_message}")
message, code = process_result(sample_id, success, zipfile)
message, code = process_result(
int(job_id), int(sample_id), success, error_message, zipfile
)
return jsonify(message=message), code

with app.app_context():
db.create_all()
if db.session.get(Settings, 1) is None:
db.session.add(
Settings(
id=None,
default_personal_submission_quota=10,
default_personal_submission_interval_mins=30,
global_quota=1000,
Expand Down
109 changes: 70 additions & 39 deletions backend/src/predicTCR_server/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

import re
import flask
from enum import Enum
import enum
import argon2
import pathlib
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy.orm import DeclarativeBase, MappedAsDataclass, Mapped, mapped_column
from werkzeug.datastructures import FileStorage
from sqlalchemy.inspection import inspect
from sqlalchemy import Integer, String, Boolean, Enum
from dataclasses import dataclass
from predicTCR_server.email import send_email
from predicTCR_server.settings import predicTCR_url
Expand All @@ -20,12 +22,17 @@
decode_password_reset_token,
)

db = SQLAlchemy()

class Base(DeclarativeBase, MappedAsDataclass):
    """Declarative base class for the ORM models; it adds no members of its own."""


db = SQLAlchemy(model_class=Base)
ph = argon2.PasswordHasher()
logger = get_logger()


class Status(str, Enum):
class Status(str, enum.Enum):
QUEUED = "queued"
RUNNING = "running"
COMPLETED = "completed"
Expand All @@ -34,34 +41,45 @@ class Status(str, Enum):

@dataclass
class Settings(db.Model):
id: int = db.Column(db.Integer, primary_key=True)
default_personal_submission_quota: int = db.Column(db.Integer, nullable=False)
default_personal_submission_interval_mins: int = db.Column(
db.Integer, nullable=False
id: Mapped[int] = mapped_column(Integer, primary_key=True)
default_personal_submission_quota: Mapped[int] = mapped_column(
Integer, nullable=False
)
default_personal_submission_interval_mins: Mapped[int] = mapped_column(
Integer, nullable=False
)
global_quota: int = db.Column(db.Integer, nullable=False)
tumor_types: str = db.Column(db.String, nullable=False)
sources: str = db.Column(db.String, nullable=False)
csv_required_columns: str = db.Column(db.String, nullable=False)
global_quota: Mapped[int] = mapped_column(Integer, nullable=False)
tumor_types: Mapped[str] = mapped_column(String, nullable=False)
sources: Mapped[str] = mapped_column(String, nullable=False)
csv_required_columns: Mapped[str] = mapped_column(String, nullable=False)

def as_dict(self):
return {
c: getattr(self, c)
for c in inspect(self).attrs.keys()
if c != "password_hash"
}
return {c: getattr(self, c) for c in inspect(self).attrs.keys()}


@dataclass
class Job(db.Model):
    """Database record of one job handed out to a runner for a sample."""

    # Primary key; callers in this code base pass ``id=None`` on creation.
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # Id of the sample this job processes (plain integer, no foreign key declared).
    sample_id: Mapped[int] = mapped_column(Integer, nullable=False)
    # Value of ``timestamp_now()`` at the moment the job was created.
    timestamp_start: Mapped[int] = mapped_column(Integer, nullable=False)
    # Created as 0; set to ``timestamp_now()`` when a result is processed.
    timestamp_end: Mapped[int] = mapped_column(Integer, nullable=False)
    # Created as RUNNING; becomes FAILED or COMPLETED when a result is processed.
    status: Mapped[Status] = mapped_column(Enum(Status), nullable=False)
    # Created as an empty string; filled with the runner's message on failure.
    error_message: Mapped[str] = mapped_column(String, nullable=False)

    def as_dict(self):
        """Return a dict mapping each mapped attribute name to its current value."""
        attribute_names = inspect(self).attrs.keys()
        return {name: getattr(self, name) for name in attribute_names}


@dataclass
class Sample(db.Model):
id: int = db.Column(db.Integer, primary_key=True)
email: str = db.Column(db.String(256), nullable=False)
name: str = db.Column(db.String(128), nullable=False)
tumor_type: str = db.Column(db.String(128), nullable=False)
source: str = db.Column(db.String(128), nullable=False)
timestamp: int = db.Column(db.Integer, nullable=False)
status: Status = db.Column(db.Enum(Status), nullable=False)
has_results_zip: bool = db.Column(db.Boolean, nullable=False)
id: Mapped[int] = mapped_column(Integer, primary_key=True)
email: Mapped[str] = mapped_column(String(256), nullable=False)
name: Mapped[str] = mapped_column(String(128), nullable=False)
tumor_type: Mapped[str] = mapped_column(String(128), nullable=False)
source: Mapped[str] = mapped_column(String(128), nullable=False)
timestamp: Mapped[int] = mapped_column(Integer, nullable=False)
status: Mapped[Status] = mapped_column(Enum(Status), nullable=False)
has_results_zip: Mapped[bool] = mapped_column(Boolean, nullable=False)

def _base_path(self) -> pathlib.Path:
data_path = flask.current_app.config["PREDICTCR_DATA_PATH"]
Expand All @@ -79,17 +97,17 @@ def result_file_path(self) -> pathlib.Path:

@dataclass
class User(db.Model):
id: int = db.Column(db.Integer, primary_key=True)
email: str = db.Column(db.Text, nullable=False, unique=True)
password_hash: str = db.Column(db.Text, nullable=False)
activated: bool = db.Column(db.Boolean, nullable=False)
enabled: bool = db.Column(db.Boolean, nullable=False)
quota: int = db.Column(db.Integer, nullable=False)
submission_interval_minutes: int = db.Column(db.Integer, nullable=False)
last_submission_timestamp: int = db.Column(db.Integer, nullable=False)
is_admin: bool = db.Column(db.Boolean, nullable=False)
is_runner: bool = db.Column(db.Boolean, nullable=False)
full_results: bool = db.Column(db.Boolean, nullable=False)
id: int = mapped_column(Integer, primary_key=True)
email: str = mapped_column(String, nullable=False, unique=True)
password_hash: str = mapped_column(String, nullable=False)
activated: bool = mapped_column(Boolean, nullable=False)
enabled: bool = mapped_column(Boolean, nullable=False)
quota: int = mapped_column(Integer, nullable=False)
submission_interval_minutes: int = mapped_column(Integer, nullable=False)
last_submission_timestamp: int = mapped_column(Integer, nullable=False)
is_admin: bool = mapped_column(Boolean, nullable=False)
is_runner: bool = mapped_column(Boolean, nullable=False)
full_results: bool = mapped_column(Boolean, nullable=False)

def set_password_nocheck(self, new_password: str):
self.password_hash = ph.hash(new_password)
Expand Down Expand Up @@ -145,17 +163,26 @@ def request_job() -> int | None:


def process_result(
sample_id: str, success: bool, result_zip_file: FileStorage | None
job_id: int,
sample_id: int,
success: bool,
error_message: str,
result_zip_file: FileStorage | None,
) -> tuple[str, int]:
sample = db.session.execute(
db.select(Sample).filter_by(id=sample_id)
).scalar_one_or_none()
sample = db.session.get(Sample, sample_id)
if sample is None:
logger.warning(f" --> Unknown sample id {sample_id}")
return f"Unknown sample id {sample_id}", 400
job = db.session.get(Job, job_id)
if job is None:
logger.warning(f" --> Unknown job id {job_id}")
return f"Unknown job id {job_id}", 400
job.timestamp_end = timestamp_now()
if success is False:
sample.has_results_zip = False
sample.status = Status.FAILED
job.status = Status.FAILED
job.error_message = error_message
db.session.commit()
return "Result processed", 200
if result_zip_file is None:
Expand All @@ -165,6 +192,7 @@ def process_result(
result_zip_file.save(sample.result_file_path())
sample.has_results_zip = True
sample.status = Status.COMPLETED
job.status = Status.COMPLETED
db.session.commit()
return "Result processed", 200

Expand Down Expand Up @@ -244,6 +272,7 @@ def add_new_user(email: str, password: str, is_admin: bool) -> tuple[str, int]:
try:
db.session.add(
User(
id=None,
email=email,
password_hash=ph.hash(password),
activated=False,
Expand Down Expand Up @@ -282,6 +311,7 @@ def add_new_runner_user() -> User | None:
runner_name = f"runner{runner_number}"
db.session.add(
User(
id=None,
email=runner_name,
password_hash="",
activated=False,
Expand Down Expand Up @@ -419,6 +449,7 @@ def add_new_sample(
settings = db.session.get(Settings, 1)
settings.global_quota -= 1
new_sample = Sample(
id=None,
email=email,
name=name,
tumor_type=tumor_type,
Expand Down
2 changes: 2 additions & 0 deletions backend/tests/helpers/flask_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def add_test_users(app):
email = f"{name}@abc.xy"
db.session.add(
User(
id=None,
email=email,
password_hash=ph.hash(name),
activated=True,
Expand Down Expand Up @@ -46,6 +47,7 @@ def add_test_samples(app, data_path: pathlib.Path):
with open(f"{ref_dir}/input.{input_file_type}", "w") as f:
f.write(input_file_type)
new_sample = Sample(
id=None,
email="[email protected]",
name=name,
tumor_type=f"tumor_type{sample_id}",
Expand Down
59 changes: 46 additions & 13 deletions backend/tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,12 +210,13 @@ def test_result_invalid(client):
assert "No results available" in response.json["message"]


def _upload_result(client, result_zipfile: pathlib.Path, sample_id: int):
def _upload_result(client, result_zipfile: pathlib.Path, job_id: int, sample_id: int):
headers = _get_auth_headers(client, "[email protected]", "runner")
with open(result_zipfile, "rb") as f:
response = client.post(
"/api/runner/result",
data={
"job_id": job_id,
"sample_id": sample_id,
"success": True,
"file": (io.BytesIO(f.read()), result_zipfile.name),
Expand All @@ -225,19 +226,57 @@ def _upload_result(client, result_zipfile: pathlib.Path, sample_id: int):
return response


def test_result_valid(client, result_zipfile):
headers = _get_auth_headers(client, "[email protected]", "user")
sample_id = 1
assert _upload_result(client, result_zipfile, sample_id).status_code == 200
def test_runner_valid_success(client, result_zipfile):
headers = _get_auth_headers(client, "[email protected]", "runner")
# request job
request_job_response = client.post(
"/api/runner/request_job",
json={"runner_hostname": "me"},
headers=headers,
)
assert request_job_response.status_code == 200
assert request_job_response.json == {"sample_id": 1, "job_id": 1}
# upload successful result
assert _upload_result(client, result_zipfile, 1, 1).status_code == 200
response = client.post(
"/api/result",
json={"sample_id": sample_id},
headers=headers,
json={"sample_id": 1},
headers=_get_auth_headers(client, "[email protected]", "user"),
)
assert response.status_code == 200
assert len(response.data) > 1


def test_runner_valid_failure(client, result_zipfile):
    """A runner requests a job, reports failure, and the sample then has no result.

    The login e-mails follow the ``f"{name}@abc.xy"`` pattern (password equal
    to the name) used by the ``add_test_users`` helper.
    """
    runner_headers = _get_auth_headers(client, "runner@abc.xy", "runner")
    # request job
    request_job_response = client.post(
        "/api/runner/request_job",
        json={"runner_hostname": "me"},
        headers=runner_headers,
    )
    assert request_job_response.status_code == 200
    assert request_job_response.json == {"sample_id": 1, "job_id": 1}
    # upload failure result: no file is attached, only an error message
    result_response = client.post(
        "/api/runner/result",
        data={
            "job_id": 1,
            "sample_id": 1,
            "success": False,
            "error_message": "Something went wrong",
        },
        headers=runner_headers,
    )
    assert result_response.status_code == 200
    # the submitting user must not be able to download a result for the failed sample
    response = client.post(
        "/api/result",
        json={"sample_id": 1},
        headers=_get_auth_headers(client, "user@abc.xy", "user"),
    )
    assert response.status_code == 400


def test_admin_samples_valid(client):
headers = _get_auth_headers(client, "[email protected]", "admin")
response = client.get("/api/admin/samples", headers=headers)
Expand Down Expand Up @@ -288,12 +327,6 @@ def test_admin_users_valid(client):
assert "users" in response.json


def test_runner_result_valid(client, result_zipfile):
    """Upload a result through ``_upload_result`` and expect a 200 'result processed' reply."""
    upload = _upload_result(client, result_zipfile, 1)
    assert upload.status_code == 200
    reply_text = upload.json["message"].lower()
    assert "result processed" in reply_text


def test_admin_update_user_valid(client):
headers = _get_auth_headers(client, "[email protected]", "admin")
user = client.get("/api/admin/users", headers=headers).json["users"][0]
Expand Down
Loading

0 comments on commit afc9d53

Please sign in to comment.