Skip to content

Commit

Permalink
Provide base class for dataset loaders (#59)
Browse files Browse the repository at this point in the history
* Update class hierarchy for dataset loader

This will allow for the future addition of an eager dataset loader

* Class usage cleanup
  • Loading branch information
cthoyt authored Jan 21, 2022
1 parent f5f0ffb commit 47af8b4
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 70 deletions.
11 changes: 9 additions & 2 deletions chemicalx/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@
from class_resolver import Resolver

from .contextfeatureset import ContextFeatureSet
from .datasetloader import DatasetLoader, DrugbankDDI, DrugComb, DrugCombDB, TwoSides
from .datasetloader import (
DatasetLoader,
DrugbankDDI,
DrugComb,
DrugCombDB,
RemoteDatasetLoader,
TwoSides,
)
from .drugfeatureset import DrugFeatureSet
from .drugpairbatch import DrugPairBatch
from .labeledtriples import LabeledTriples
Expand All @@ -22,4 +29,4 @@
"DrugCombDB",
]

dataset_resolver = Resolver.from_subclasses(base=DatasetLoader)
dataset_resolver = Resolver.from_subclasses(base=DatasetLoader, skip={RemoteDatasetLoader})
150 changes: 91 additions & 59 deletions chemicalx/data/datasetloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io
import json
import urllib.request
from abc import ABC, abstractmethod
from functools import lru_cache
from textwrap import dedent
from typing import Dict, Optional, Tuple, cast
Expand All @@ -18,6 +19,7 @@

__all__ = [
"DatasetLoader",
"RemoteDatasetLoader",
# Actual datasets
"DrugCombDB",
"DrugComb",
Expand All @@ -26,18 +28,8 @@
]


class DatasetLoader:
"""General dataset loader for the integrated drug pair scoring datasets."""

def __init__(self, dataset_name: str):
"""Instantiate the dataset loader.
Args:
dataset_name (str): The name of the dataset.
"""
self.base_url = "https://raw.githubusercontent.com/AstraZeneca/chemicalx/main/dataset"
self.dataset_name = dataset_name
assert dataset_name in ["drugcombdb", "drugcomb", "twosides", "drugbankddi"]
class DatasetLoader(ABC):
"""A generic dataset."""

def get_generators(
self,
Expand Down Expand Up @@ -95,6 +87,86 @@ def get_generator(
labeled_triples=self.get_labeled_triples() if labeled_triples is None else labeled_triples,
)

@abstractmethod
def get_context_features(self) -> ContextFeatureSet:
"""
Get the context feature set.
Returns:
: The ContextFeatureSet of the dataset of interest.
"""

@property
def num_contexts(self) -> int:
"""Get the number of contexts."""
return len(self.get_context_features())

@property
def context_channels(self) -> int:
"""Get the number of features for each context."""
return next(iter(self.get_context_features().values())).shape[1]

@abstractmethod
def get_drug_features(self):
"""
Get the drug feature set.
Returns:
: The DrugFeatureSet of the dataset of interest.
"""

@property
def num_drugs(self) -> int:
"""Get the number of drugs."""
return len(self.get_drug_features())

@property
def drug_channels(self) -> int:
"""Get the number of features for each drug."""
return next(iter(self.get_drug_features().values()))["features"].shape[1]

def get_labeled_triples(self) -> LabeledTriples:
"""
Get the labeled triples file from the storage.
Returns:
: The labeled triples in the dataset.
"""

@property
def num_labeled_triples(self) -> int:
"""Get the number of labeled triples."""
return len(self.get_labeled_triples())

def summarize(self) -> None:
"""Summarize the dataset."""
print(
dedent(
f"""\
Name: {self.__class__.__name__}
Contexts: {self.num_contexts}
Context Feature Size: {self.context_channels}
Drugs: {self.num_drugs}
Drug Feature Size: {self.drug_channels}
Triples: {self.num_labeled_triples}
"""
)
)


class RemoteDatasetLoader(DatasetLoader):
"""General dataset loader for the integrated drug pair scoring datasets."""

def __init__(self, dataset_name: str):
"""Instantiate the dataset loader.
Args:
dataset_name (str): The name of the dataset.
"""
self.base_url = "https://raw.githubusercontent.com/AstraZeneca/chemicalx/main/dataset"
self.dataset_name = dataset_name
assert dataset_name in ["drugcombdb", "drugcomb", "twosides", "drugbankddi"]

def generate_path(self, file_name: str) -> str:
"""
Generate a complete url for a dataset file.
Expand Down Expand Up @@ -140,30 +212,20 @@ def get_context_features(self):
Get the context feature set.
Returns:
context_feature_set (ContextFeatureSet): The ContextFeatureSet of the dataset of interest.
: The ContextFeatureSet of the dataset of interest.
"""
path = self.generate_path("context_set.json")
raw_data = self.load_raw_json_data(path)
raw_data = {k: torch.FloatTensor(np.array(v).reshape(1, -1)) for k, v in raw_data.items()}
return ContextFeatureSet(raw_data)

@property
def num_contexts(self) -> int:
"""Get the number of contexts."""
return len(self.get_context_features())

@property
def context_channels(self) -> int:
"""Get the number of features for each context."""
return next(iter(self.get_context_features().values())).shape[1]

@lru_cache(maxsize=1)
def get_drug_features(self):
"""
Get the drug feature set.
Returns:
drug_feature_set (DrugFeatureSet): The DrugFeatureSet of the dataset of interest.
: The DrugFeatureSet of the dataset of interest.
"""
path = self.generate_path("drug_set.json")
raw_data = self.load_raw_json_data(path)
Expand All @@ -173,74 +235,44 @@ def get_drug_features(self):
}
return DrugFeatureSet.from_dict(raw_data)

@property
def num_drugs(self) -> int:
"""Get the number of drugs."""
return len(self.get_drug_features())

@property
def drug_channels(self) -> int:
"""Get the number of features for each drug."""
return next(iter(self.get_drug_features().values()))["features"].shape[1]

@lru_cache(maxsize=1)
def get_labeled_triples(self):
"""
Get the labeled triples file from the storage.
Returns:
labeled_triples (LabeledTriples): The labeled triples in the dataset.
: The labeled triples in the dataset.
"""
path = self.generate_path("labeled_triples.csv")
df = self.load_raw_csv_data(path)
return LabeledTriples(df)

@property
def num_labeled_triples(self) -> int:
"""Get the number of labeled triples."""
return len(self.get_labeled_triples())

def summarize(self) -> None:
"""Summarize the dataset."""
print(
dedent(
f"""\
Name: {self.dataset_name}
Contexts: {self.num_contexts}
Context Feature Size: {self.context_channels}
Drugs: {self.num_drugs}
Drug Feature Size: {self.drug_channels}
Triples: {self.num_labeled_triples}
"""
)
)


class DrugCombDB(DatasetLoader):
class DrugCombDB(RemoteDatasetLoader):
"""A dataset loader for `DrugCombDB <http://drugcombdb.denglab.org>`_."""

def __init__(self):
"""Instantiate the DrugCombDB dataset loader."""
super().__init__("drugcombdb")


class DrugComb(DatasetLoader):
class DrugComb(RemoteDatasetLoader):
"""A dataset loader for `DrugComb <https://drugcomb.fimm.fi/>`_."""

def __init__(self):
"""Instantiate the DrugComb dataset loader."""
super().__init__("drugcomb")


class TwoSides(DatasetLoader):
class TwoSides(RemoteDatasetLoader):
"""A dataset loader for a sample of `TWOSIDES <http://tatonettilab.org/offsides/>`_."""

def __init__(self):
"""Instantiate the TWOSIDES dataset loader."""
super().__init__("twosides")


class DrugbankDDI(DatasetLoader):
class DrugbankDDI(RemoteDatasetLoader):
"""A dataset loader for `Drugbank DDI <https://www.pnas.org/content/115/18/E4304>`_."""

def __init__(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import unittest
from typing import ClassVar

from chemicalx.data import DatasetLoader
from chemicalx.data import DatasetLoader, DrugCombDB


class TestGeneratorDrugCombDB(unittest.TestCase):
Expand All @@ -14,7 +14,7 @@ class TestGeneratorDrugCombDB(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
"""Set up the class with a dataset loader."""
cls.loader = DatasetLoader("drugcombdb")
cls.loader = DrugCombDB()

def test_all_true(self):
"""Test sizes of drug features during batch generation."""
Expand Down
10 changes: 5 additions & 5 deletions tests/unit/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import unittest
from typing import ClassVar

from chemicalx.data import DatasetLoader
from chemicalx.data import DatasetLoader, DrugbankDDI, DrugComb, DrugCombDB, TwoSides


class TestDrugComb(unittest.TestCase):
Expand All @@ -14,7 +14,7 @@ class TestDrugComb(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
"""Set up the test case."""
cls.loader = DatasetLoader("drugcomb")
cls.loader = DrugComb()

def test_get_context_features(self):
"""Test the number of context features."""
Expand All @@ -40,7 +40,7 @@ class TestDrugCombDB(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
"""Set up the test case."""
cls.loader = DatasetLoader("drugcombdb")
cls.loader = DrugCombDB()

def test_get_context_features(self):
"""Test the number of context features."""
Expand All @@ -66,7 +66,7 @@ class TestDeepDDI(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
"""Set up the test case."""
cls.loader = DatasetLoader("drugbankddi")
cls.loader = DrugbankDDI()

def test_get_context_features(self):
"""Test the number of context features."""
Expand All @@ -92,7 +92,7 @@ class TestTwoSides(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
"""Set up the test case."""
cls.loader = DatasetLoader("twosides")
cls.loader = TwoSides()

def test_get_context_features(self):
"""Test the number of context features."""
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import chemicalx.models
from chemicalx import pipeline
from chemicalx.data import DatasetLoader, DrugCombDB
from chemicalx.data import DatasetLoader, DrugComb, DrugCombDB
from chemicalx.models import (
CASTER,
EPGCNDS,
Expand Down Expand Up @@ -111,7 +111,7 @@ class TestModels(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
"""Set up the test case with a dataset."""
cls.loader = DatasetLoader("drugcomb")
cls.loader = DrugComb()

def setUp(self):
"""Set up the test case."""
Expand Down

0 comments on commit 47af8b4

Please sign in to comment.