From cf8ab38ca1791e90f61ed2b70298d060fb7488af Mon Sep 17 00:00:00 2001 From: Philip May Date: Wed, 13 Dec 2023 10:46:13 +0100 Subject: [PATCH] Compare data loaders with original implementation. (#111) * comment test_load_colon_data and test_load_colon_label functions * Add ori_data_loader * fix linting for ori_data_loader * Add tests to compare loaded data with original data --- tests/ori_data_loader.py | 127 +++++++++++++++++++++++++++++++++++++++ tests/test_data.py | 34 ++++++++++- 2 files changed, 159 insertions(+), 2 deletions(-) create mode 100644 tests/ori_data_loader.py diff --git a/tests/ori_data_loader.py b/tests/ori_data_loader.py new file mode 100644 index 0000000..0426731 --- /dev/null +++ b/tests/ori_data_loader.py @@ -0,0 +1,127 @@ +# Copyright (c) 2021 Sigrun May, Helmholtz-Zentrum für Infektionsforschung GmbH (HZI) +# Copyright (c) 2021 Sigrun May, Ostfalia Hochschule für angewandte Wissenschaften +# Copyright (c) 2020 Philip May +# This software is distributed under the terms of the MIT license +# which is available at https://opensource.org/licenses/MIT + +# this is the original implementation from +# https://github.com/sigrun-may/cv-pruner/blob/ac35eba88a824e6bb6a6435cda67224a4db69e65/examples/data_loader.py + +"""Data loader module.""" + +from typing import Tuple + +import numpy as np +import pandas as pd +import requests +from bs4 import BeautifulSoup + + +def load_colon_data() -> Tuple[pd.Series, pd.DataFrame]: + """Load colon data. + + The data is loaded and parsed from the internet. + Also see + + Returns: + Tuple containing labels and data. + """ + html_data = "http://genomics-pubs.princeton.edu/oncology/affydata/I2000.html" + + page = requests.get(html_data, timeout=10) + + soup = BeautifulSoup(page.content, "html.parser") + colon_text_data = soup.get_text() + + colon_text_data_lines = colon_text_data.splitlines() + colon_text_data_lines = [[float(s) for s in line.split()] for line in colon_text_data_lines if len(line) > 20] + assert len(colon_text_data_lines) == 2000 + assert len(colon_text_data_lines[0]) == 62 + + data = np.array(colon_text_data_lines).T + + html_label = "http://genomics-pubs.princeton.edu/oncology/affydata/tissues.html" + page = requests.get(html_label, timeout=10) + soup = BeautifulSoup(page.content, "html.parser") + colon_text_label = soup.get_text() + colon_text_label = colon_text_label.splitlines() + + label = [] + + for line in colon_text_label: + try: + i = int(line) + label.append(0 if i > 0 else 1) + except: # noqa: S110, E722 + pass + + assert len(label) == 62 + + data_df = pd.DataFrame(data) + + # generate feature names + column_names = [] + for column_name in data_df.columns: + column_names.append("gene_" + str(column_name)) + + data_df.columns = column_names + + return pd.Series(label), data_df + + +# TODO append random features and shuffle + + +def load_prostate_data() -> Tuple[pd.Series, pd.DataFrame]: + """Load prostate data. + + The data is loaded and parsed from + + Returns: + Tuple containing labels and data. + """ + df = pd.read_csv("https://web.stanford.edu/~hastie/CASI_files/DATA/prostmat.csv") + data = df.T + + # labels + labels = [] + for label in df.columns: # pylint:disable=no-member + if "control" in label: + labels.append(0) + elif "cancer" in label: + labels.append(1) + else: + assert False, "This must not happen!" + + assert len(labels) == 102 + assert data.shape == (102, 6033) + + return pd.Series(labels), data + + +def load_leukemia_data() -> Tuple[pd.Series, pd.DataFrame]: + """Load leukemia data. + + The data is loaded and parsed from the internet. + Also see + + Returns: + Tuple containing labels and data. + """ + df = pd.read_csv("https://web.stanford.edu/~hastie/CASI_files/DATA/leukemia_big.csv") + data = df.T + + # labels + labels = [] + for label in df.columns: # pylint:disable=no-member + if "ALL" in label: + labels.append(0) + elif "AML" in label: + labels.append(1) + else: + assert False, "This must not happen!" + + assert len(labels) == 72 + assert data.shape == (72, 7128) + + return pd.Series(labels), data diff --git a/tests/test_data.py b/tests/test_data.py index 46113c0..3ea9165 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -3,19 +3,22 @@ # which is available at https://opensource.org/licenses/MIT import pandas as pd +from numpy.testing import assert_almost_equal from mltb2.data import _load_colon_data, _load_colon_label, load_colon, load_leukemia_big, load_prostate +from .ori_data_loader import load_colon_data, load_leukemia_data, load_prostate_data + def test_load_colon_data(): - result = _load_colon_data() + result = _load_colon_data() # only load data not labels assert result is not None assert isinstance(result, pd.DataFrame) assert result.shape == (62, 2000) def test_load_colon_label(): - result = _load_colon_label() + result = _load_colon_label() # only load labels not data assert result is not None assert isinstance(result, pd.Series) assert len(result) == 62 @@ -32,6 +35,15 @@ def test_load_colon(): assert result[1].shape == (62, 2000) +def test_load_colon_compare_original(): + result = load_colon() + ori_result = load_colon_data() + assert result[0].shape == ori_result[0].shape + assert result[1].shape == ori_result[1].shape + assert_almost_equal(result[0].to_numpy(), ori_result[0].to_numpy()) + assert_almost_equal(result[1].to_numpy(), ori_result[1].to_numpy()) + + def test_load_prostate(): result = load_prostate() assert result is not None @@ -43,6 +55,15 @@ def test_load_prostate(): assert result[1].shape == (102, 6033) +def test_load_prostate_compare_original(): + result = load_prostate() + ori_result = load_prostate_data() + assert result[0].shape == ori_result[0].shape + assert result[1].shape == ori_result[1].shape + assert_almost_equal(result[0].to_numpy(), ori_result[0].to_numpy()) + assert_almost_equal(result[1].to_numpy(), ori_result[1].to_numpy()) + + def test_load_leukemia_big(): result = load_leukemia_big() assert result is not None @@ -52,3 +73,12 @@ def test_load_leukemia_big(): assert isinstance(result[1], pd.DataFrame) assert result[0].shape == (72,) assert result[1].shape == (72, 7128) + + +def test_load_leukemia_big_compare_original(): + result = load_leukemia_big() + ori_result = load_leukemia_data() + assert result[0].shape == ori_result[0].shape + assert result[1].shape == ori_result[1].shape + assert_almost_equal(result[0].to_numpy(), ori_result[0].to_numpy()) + assert_almost_equal(result[1].to_numpy(), ori_result[1].to_numpy())