From 398e97261c1ba68ec35235e4b4dc82af7dd597cf Mon Sep 17 00:00:00 2001 From: PhilipMay Date: Fri, 8 Dec 2023 13:25:37 +0100 Subject: [PATCH] add impl with tests --- mltb2/data.py | 50 ++++++++++++++++++++++++++++++++++++++++++++++ tests/test_data.py | 15 ++++++++++++-- 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/mltb2/data.py b/mltb2/data.py index 6dfb92e..2d96f06 100644 --- a/mltb2/data.py +++ b/mltb2/data.py @@ -149,9 +149,59 @@ def load_prostate() -> Tuple[pd.Series, pd.DataFrame]: labels.append(1) else: assert False, "This must not happen!" + label_series = pd.Series(labels) + assert len(label_series) == 102 data_df = data_df.reset_index(drop=True) # reset the index to default integer index + assert data_df.shape == (102, 6033) + + result = (label_series, data_df) + joblib.dump(result, full_path, compress=("gzip", 3)) + else: + result = joblib.load(full_path) + return result + + +def load_leukemia_big() -> Tuple[pd.Series, pd.DataFrame]: + """Load leukemia (big) data. + + The data is loaded and parsed from the internet. + Also see `leukemia data + `_. + + Returns: + Tuple containing labels and data. + """ + filename = "leukemia_big.pkl.gz" + mltb2_data_home = get_and_create_mltb2_data_dir() + full_path = os.path.join(mltb2_data_home, filename) + if not os.path.exists(full_path): + # download data file + url = "https://web.stanford.edu/~hastie/CASI_files/DATA/leukemia_big.csv" + page = requests.get(url, timeout=10) + page_str = page.text + + # check checksum of data file + page_hash = sha256(page_str.encode("utf-8")).hexdigest() + assert page_hash == "35e84928da625da0787efb31a451dedbdf390e821a94ef74b7b7ab6cab9466d4", page_hash + + data_df = pd.read_csv(StringIO(page_str)) + data_df = data_df.T + + labels = [] + for label in data_df.index: + if "ALL" in label: + labels.append(0) + elif "AML" in label: + labels.append(1) + else: + assert False, "This must not happen!" label_series = pd.Series(labels) + assert len(label_series) == 72 + + data_df = data_df.reset_index(drop=True) # reset the index to default integer index + assert data_df.shape == (72, 7128) + result = (label_series, data_df) joblib.dump(result, full_path, compress=("gzip", 3)) else: diff --git a/tests/test_data.py b/tests/test_data.py index 219d33d..46113c0 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -4,7 +4,7 @@ import pandas as pd -from mltb2.data import _load_colon_data, _load_colon_label, load_colon, load_prostate +from mltb2.data import _load_colon_data, _load_colon_label, load_colon, load_leukemia_big, load_prostate def test_load_colon_data(): @@ -32,7 +32,7 @@ def test_load_colon(): assert result[1].shape == (62, 2000) -def test_load_prostate_data(): +def test_load_prostate(): result = load_prostate() assert result is not None assert isinstance(result, tuple) @@ -41,3 +41,14 @@ def test_load_prostate_data(): assert isinstance(result[1], pd.DataFrame) assert result[0].shape == (102,) assert result[1].shape == (102, 6033) + + +def test_load_leukemia_big(): + result = load_leukemia_big() + assert result is not None + assert isinstance(result, tuple) + assert len(result) == 2 + assert isinstance(result[0], pd.Series) + assert isinstance(result[1], pd.DataFrame) + assert result[0].shape == (72,) + assert result[1].shape == (72, 7128)