Skip to content

Commit

Permalink
add impl with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
PhilipMay committed Dec 8, 2023
1 parent 7a9ff83 commit 398e972
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 2 deletions.
50 changes: 50 additions & 0 deletions mltb2/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,59 @@ def load_prostate() -> Tuple[pd.Series, pd.DataFrame]:
labels.append(1)
else:
assert False, "This must not happen!"
label_series = pd.Series(labels)
assert len(label_series) == 102

data_df = data_df.reset_index(drop=True) # reset the index to default integer index
assert data_df.shape == (102, 6033)

result = (label_series, data_df)
joblib.dump(result, full_path, compress=("gzip", 3))
else:
result = joblib.load(full_path)
return result


def load_leukemia_big() -> Tuple[pd.Series, pd.DataFrame]:
"""Load leukemia (big) data.
The data is loaded and parsed from the internet.
Also see `leukemia data
<https://web.stanford.edu/~hastie/CASI_files/DATA/leukemia.html>`_.
Returns:
Tuple containing labels and data.
"""
filename = "leukemia_big.pkl.gz"
mltb2_data_home = get_and_create_mltb2_data_dir()
full_path = os.path.join(mltb2_data_home, filename)
if not os.path.exists(full_path):
# download data file
url = "https://web.stanford.edu/~hastie/CASI_files/DATA/leukemia_big.csv"
page = requests.get(url, timeout=10)
page_str = page.text

# check checksum of data file
page_hash = sha256(page_str.encode("utf-8")).hexdigest()
assert page_hash == "35e84928da625da0787efb31a451dedbdf390e821a94ef74b7b7ab6cab9466d4", page_hash

data_df = pd.read_csv(StringIO(page_str))
data_df = data_df.T

labels = []
for label in data_df.index:
if "ALL" in label:
labels.append(0)
elif "AML" in label:
labels.append(1)
else:
assert False, "This must not happen!"
label_series = pd.Series(labels)
assert len(label_series) == 72

data_df = data_df.reset_index(drop=True) # reset the index to default integer index
assert data_df.shape == (72, 7128)

result = (label_series, data_df)
joblib.dump(result, full_path, compress=("gzip", 3))
else:
Expand Down
15 changes: 13 additions & 2 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pandas as pd

from mltb2.data import _load_colon_data, _load_colon_label, load_colon, load_prostate
from mltb2.data import _load_colon_data, _load_colon_label, load_colon, load_leukemia_big, load_prostate


def test_load_colon_data():
Expand Down Expand Up @@ -32,7 +32,7 @@ def test_load_colon():
assert result[1].shape == (62, 2000)


def test_load_prostate_data():
def test_load_prostate():
result = load_prostate()
assert result is not None
assert isinstance(result, tuple)
Expand All @@ -41,3 +41,14 @@ def test_load_prostate_data():
assert isinstance(result[1], pd.DataFrame)
assert result[0].shape == (102,)
assert result[1].shape == (102, 6033)


def test_load_leukemia_big():
result = load_leukemia_big()
assert result is not None
assert isinstance(result, tuple)
assert len(result) == 2
assert isinstance(result[0], pd.Series)
assert isinstance(result[1], pd.DataFrame)
assert result[0].shape == (72,)
assert result[1].shape == (72, 7128)

0 comments on commit 398e972

Please sign in to comment.