diff --git a/docs/CHANGELOG.rst b/docs/CHANGELOG.rst index 58562089..568d5d01 100644 --- a/docs/CHANGELOG.rst +++ b/docs/CHANGELOG.rst @@ -14,6 +14,16 @@ Added and propagate this change in ``Data`` resource +========== +Unreleased +========== + +Added +----- +- Add support for predictions +- Add version to annotation field + + =================== 21.2.1 - 2024-10-08 =================== diff --git a/src/resdk/query.py b/src/resdk/query.py index 3a0ceb97..a9229745 100644 --- a/src/resdk/query.py +++ b/src/resdk/query.py @@ -16,7 +16,7 @@ import tqdm -from resdk.resources import AnnotationField, DescriptorSchema, Process +from resdk.resources import AnnotationField, DescriptorSchema, PredictionField, Process from resdk.resources.base import BaseResource @@ -419,7 +419,48 @@ def _fetch(self): if missing: # Get corresponding annotation field details in a single query and attach it to # the values. - for field in self.resolwe.annotation_field.filter(id__in=missing.keys()): + for field in self.resolwe.annotation_field.filter( + id__in=missing.keys() + ).iterate(): + for value in missing[field.id]: + value._field = field + value._original_values["field"] = field._original_values + + +class PredictionFieldQuery(ResolweQuery): + """Add additional method to the prediction field query.""" + + def from_path(self, full_path: str) -> "PredictionField": + """Get the PredictionField from full path. + + :raises LookupError: when field at the specified path does not exist. + """ + group_name, field_name = full_path.split(".", maxsplit=1) + return self.get(name=field_name, group__name=group_name) + + +class PredictionValueQuery(ResolweQuery): + """Populate prediction fields with a single query.""" + + def _fetch(self): + """Make request to the server and populate cache. + + Fetch all values and their fields with 2 queries. + """ + # Execute the query in a single request. 
+ super()._fetch() + + missing = collections.defaultdict(list) + for value in self._cache: + if value._field is None: + missing[value.field_id].append(value) + + if missing: + # Get corresponding prediction field details in a single query and attach it to + # the values. + for field in self.resolwe.prediction_field.filter( + id__in=missing.keys() + ).iterate(): for value in missing[field.id]: value._field = field value._original_values["field"] = field._original_values diff --git a/src/resdk/resolwe.py b/src/resdk/resolwe.py index e31ec803..29c29433 100644 --- a/src/resdk/resolwe.py +++ b/src/resdk/resolwe.py @@ -33,7 +33,13 @@ from .constants import CHUNK_SIZE from .exceptions import ValidationError, handle_http_exception -from .query import AnnotationFieldQuery, AnnotationValueQuery, ResolweQuery +from .query import ( + AnnotationFieldQuery, + AnnotationValueQuery, + PredictionFieldQuery, + PredictionValueQuery, + ResolweQuery, +) from .resources import ( AnnotationField, AnnotationValue, @@ -43,6 +49,8 @@ Geneset, Group, Metadata, + PredictionField, + PredictionValue, Process, Relation, Sample, @@ -106,26 +114,30 @@ class Resolwe: # Map between resource and Query map. Default in ResorweQuery, only overrides must # be listed here. 
resource_query_class = { - AnnotationValue: AnnotationValueQuery, AnnotationField: AnnotationFieldQuery, + AnnotationValue: AnnotationValueQuery, + PredictionField: PredictionFieldQuery, + PredictionValue: PredictionValueQuery, } # Map resource class to ResolweQuery name resource_query_mapping = { AnnotationField: "annotation_field", AnnotationValue: "annotation_value", - Data: "data", Collection: "collection", - Sample: "sample", - Relation: "relation", - Process: "process", + Data: "data", DescriptorSchema: "descriptor_schema", - User: "user", - Group: "group", Feature: "feature", - Mapping: "mapping", Geneset: "geneset", + Group: "group", + Mapping: "mapping", Metadata: "metadata", + PredictionField: "prediction_field", + PredictionValue: "prediction_value", + Process: "process", + Relation: "relation", + Sample: "sample", + User: "user", } # Map ResolweQuery name to it's slug_field slug_field_mapping = { diff --git a/src/resdk/resources/__init__.py b/src/resdk/resources/__init__.py index 061a7964..a0bd66c8 100644 --- a/src/resdk/resources/__init__.py +++ b/src/resdk/resources/__init__.py @@ -54,6 +54,18 @@ :members: :inherited-members: +.. autoclass:: resdk.resources.PredictionValue + :members: + :inherited-members: + +.. autoclass:: resdk.resources.PredictionGroup + :members: + :inherited-members: + +.. autoclass:: resdk.resources.PredictionField + :members: + :inherited-members: + .. 
autoclass:: resdk.resources.User :members: :inherited-members: @@ -98,6 +110,7 @@ from .descriptor import DescriptorSchema from .geneset import Geneset from .metadata import Metadata +from .predictions import PredictionField, PredictionGroup, PredictionValue from .process import Process from .relation import Relation from .sample import Sample @@ -113,6 +126,9 @@ "Geneset", "Group", "Metadata", + "PredictionField", + "PredictionGroup", + "PredictionValue", "Sample", "Process", "Relation", diff --git a/src/resdk/resources/annotations.py b/src/resdk/resources/annotations.py index e35c38a5..d0369cc1 100644 --- a/src/resdk/resources/annotations.py +++ b/src/resdk/resources/annotations.py @@ -51,6 +51,7 @@ class AnnotationField(BaseResource): "validator_regex", "vocabulary", "required", + "version", ) def __init__(self, resolwe: "Resolwe", **model_data): diff --git a/src/resdk/resources/predictions.py b/src/resdk/resources/predictions.py new file mode 100644 index 00000000..d7bb3c80 --- /dev/null +++ b/src/resdk/resources/predictions.py @@ -0,0 +1,238 @@ +"""Predictions resources.""" + +import logging +from enum import Enum +from typing import TYPE_CHECKING, NamedTuple, Optional, Union + +from ..utils.decorators import assert_object_exists +from .base import BaseResource +from .sample import Sample +from .utils import parse_resolwe_datetime + +if TYPE_CHECKING: + from resdk.resolwe import Resolwe + + +class PredictionType(Enum): + """Supported prediction types.""" + + SCORE = "SCORE" + CLASS = "CLASS" + + +class ScorePredictionType(NamedTuple): + """Prediction score type.""" + + score: float + + +class ClassPredictionType(NamedTuple): + """Prediction class type.""" + + class_: str + probability: float + + +class PredictionGroup(BaseResource): + """Resolwe PredictionGroup resource.""" + + # There is currently no endpoint for PredictionGroup object, but it might be + # created in the future. The objects are created when PredictionField is + # initialized. 
+ endpoint = "prediction_group" + + READ_ONLY_FIELDS = BaseResource.READ_ONLY_FIELDS + ("name", "sort_order", "label") + + def __init__(self, resolwe: "Resolwe", **model_data): + """Initialize the instance. + + :param resolwe: Resolwe instance + :param model_data: Resource model data + """ + self.logger = logging.getLogger(__name__) + super().__init__(resolwe, **model_data) + + def __repr__(self): + """Return user friendly string representation.""" + return f"PredictionGroup " + + +class PredictionField(BaseResource): + """Resolwe PredictionField resource.""" + + endpoint = "prediction_field" + + READ_ONLY_FIELDS = BaseResource.READ_ONLY_FIELDS + ( + "description", + "group", + "label", + "name", + "sort_order", + "type", + "validator_regex", + "vocabulary", + "required", + "version", + ) + + def __init__(self, resolwe: "Resolwe", **model_data): + """Initialize the instance. + + :param resolwe: Resolwe instance + :param model_data: Resource model data + """ + self.logger = logging.getLogger(__name__) + #: prediction group + self._group = None + super().__init__(resolwe, **model_data) + + @property + def group(self) -> PredictionGroup: + """Get prediction group.""" + assert ( + self._group is not None + ), "PredictionGroup must be set before it can be used." 
+ return self._group + + @group.setter + def group(self, payload: dict): + """Set prediction group.""" + if self._group is None: + self._resource_setter(payload, PredictionGroup, "_group") + else: + raise AttributeError("PredictionGroup is read-only.") + + def __repr__(self): + """Return user friendly string representation.""" + return f"PredictionField " + + def __str__(self): + """Return full path of the prediction field.""" + return f"{self.group.name}.{self.name}" + + +class PredictionValue(BaseResource): + """Resolwe PredictionValue resource.""" + + endpoint = "prediction_value" + + READ_ONLY_FIELDS = BaseResource.READ_ONLY_FIELDS + ("label",) + + UPDATE_PROTECTED_FIELDS = BaseResource.UPDATE_PROTECTED_FIELDS + ("field", "sample") + + WRITABLE_FIELDS = BaseResource.WRITABLE_FIELDS + ("value",) + + def __init__(self, resolwe: "Resolwe", **model_data): + """Initialize the instance. + + :param resolwe: Resolwe instance + :param model_data: Resource model data + """ + self.logger = logging.getLogger(__name__) + + #: prediction field + self._field: Optional[PredictionField] = None + self._value: Optional[Union[ScorePredictionType, ClassPredictionType]] = None + self.field_id: Optional[int] = None + + #: sample + self.sample_id: Optional[int] = None + self._sample: Optional[Sample] = None + super().__init__(resolwe, **model_data) + + @property + @assert_object_exists + def modified(self): + """Modification time.""" + return parse_resolwe_datetime(self._original_values["created"]) + + @property + def sample(self): + """Get sample.""" + if self._sample is None: + if self.sample_id is None: + self.sample_id = self._original_values["entity"] + self._sample = Sample(resolwe=self.resolwe, id=self.sample_id) + # Without this save will fail due to change in read-only field. + self._original_values["sample"] = {"id": self.sample_id} + return self._sample + + @sample.setter + def sample(self, payload): + """Set the sample.""" + # Update fields sets sample to None. 
+ if payload is None: + return + if self.sample_id is not None: + raise AttributeError("Sample is read-only.") + if isinstance(payload, Sample): + self.sample_id = payload.id + elif isinstance(payload, dict): + self.sample_id = payload["id"] + else: + self.sample_id = payload + + @property + def value(self): + """Get the value.""" + if self._value is None: + if self.field.type == PredictionType.SCORE.value: + self._value = ScorePredictionType(*self._original_values["value"]) + elif self.field.type == PredictionType.CLASS.value: + self._value = ClassPredictionType(**self._original_values["value"]) + else: + raise TypeError(f"Unknown prediction type {self.field.type}.") + return self._value + + @value.setter + def value(self, value): + """Set the value.""" + if isinstance(value, (ScorePredictionType, ClassPredictionType)): + self._value = value + elif isinstance(value, list): + factory = ( + ScorePredictionType + if self.field.type == PredictionType.SCORE.value + else ClassPredictionType + ) + try: + value = factory(*value) + except TypeError: + raise TypeError( + "Value must be of type ScorePredictionType or ClassPredictionType." + ) + self._value = value + + @property + def field(self) -> PredictionField: + """Get the prediction field.""" + if self._field is None: + assert ( + self.field_id is not None + ), "PredictionField must be set before it can be used." + self._field = self.resolwe.prediction_field.get(id=self.field_id) + # The field is read-only but we have to modify original values here so save + # can detect there were no changes. 
+ self._original_values["field"] = self._field._original_values + return self._field + + @field.setter + def field(self, payload: Union[int, PredictionField, dict]): + """Set prediction field.""" + field_id = None + if isinstance(payload, int): + field_id = payload + elif isinstance(payload, dict): + field_id = payload["id"] + elif isinstance(payload, PredictionField): + field_id = payload.id + if field_id != self.field_id: + self._field = None + self.field_id = field_id + + def __repr__(self): + """Format resource name.""" + return ( + f"PredictionValue " + ) diff --git a/src/resdk/resources/sample.py b/src/resdk/resources/sample.py index 25cf7c14..96d5f034 100644 --- a/src/resdk/resources/sample.py +++ b/src/resdk/resources/sample.py @@ -1,7 +1,7 @@ """Sample resource.""" import logging -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional, Union from resdk.exceptions import ResolweServerError from resdk.shortcuts.sample import SampleUtilsMixin @@ -12,6 +12,7 @@ if TYPE_CHECKING: from .annotations import AnnotationValue + from .predictions import ClassPredictionType, ScorePredictionType class Sample(SampleUtilsMixin, BaseCollection): @@ -253,6 +254,11 @@ def annotations(self): """Get the annotations for the given sample.""" return self.resolwe.annotation_value.filter(entity=self.id) + @property + def predictions(self): + """Get the predictions for the given sample.""" + return self.resolwe.prediction_value.filter(entity=self.id) + def get_annotation(self, full_path: str) -> "AnnotationValue": """Get the AnnotationValue from full path. 
@@ -303,3 +309,20 @@ def set_annotations(self, annotations: Dict[str, Any]): {"field_path": key, "value": value} for key, value in annotations.items() ] self.api(self.id).set_annotations.post(payload) + + def get_predictions(self) -> Dict[str, Any]: + """Get all predictions for the given sample in a dictionary.""" + return { + str(prediction.field): prediction.value + for prediction in self.predictions.all() + } + + def set_predictions( + self, + predictions: Dict[str, Union["ScorePredictionType", "ClassPredictionType"]], + ): + """Bulk set predictions on the sample.""" + payload = [ + {"field_path": key, "value": value} for key, value in predictions.items() + ] + self.api(self.id).set_predictions.post(payload) diff --git a/tests/functional/predictions/__init__.py b/tests/functional/predictions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/functional/predictions/e2e_predictions.py b/tests/functional/predictions/e2e_predictions.py new file mode 100644 index 00000000..fb2279d4 --- /dev/null +++ b/tests/functional/predictions/e2e_predictions.py @@ -0,0 +1,52 @@ +import os + +from resdk.resources.predictions import ClassPredictionType, ScorePredictionType + +from ..base import BaseResdkFunctionalTest +from ..docs.e2e_docs import TEST_FILES_DIR + + +class TestPredictions(BaseResdkFunctionalTest): + def setUp(self): + super().setUp() + self.collection = self.res.collection.create(name="Test collection") + + self.geneset = None + + def tearDown(self): + # self.geneset is deleted along with collection + self.collection.delete(force=True) + + def test_predictions(self): + reads = self.res.run( + slug="upload-fastq-single", + input={"src": os.path.join(TEST_FILES_DIR, "reads.fastq.gz")}, + collection=self.collection.id, + ) + sample = reads.sample + self.assertEqual(sample.get_predictions(), {}) + + sample.set_predictions({"general.score": ScorePredictionType(0.5)}) + self.assertEqual( + sample.get_predictions(), {"general.score": 
ScorePredictionType(0.5)} + ) + + sample.set_predictions( + { + "general.score": ScorePredictionType(1.5), + "general.class": ClassPredictionType("positive", 0.5), + } + ) + self.assertEqual( + sample.get_predictions(), + { + "general.score": ScorePredictionType(1.5), + "general.class": ClassPredictionType("positive", 0.5), + }, + ) + + sample.set_predictions({"general.class": None}) + self.assertEqual( + sample.get_predictions(), + {"general.score": ScorePredictionType(1.5)}, + )