From 95873b693ee67222fa0d221fab5ca5a36d70d136 Mon Sep 17 00:00:00 2001 From: Eisha Tir Raazia <49692285+EishaMazhar@users.noreply.github.com> Date: Fri, 25 Oct 2024 14:29:09 +0200 Subject: [PATCH] Evaluation (#18) * update docker file * update evaluation_s3 function * Complete evaluate_s3 logic and upload data to s3 * add .csv format to .gitignore * Remove extra comments and imports from main * update requirements * update requirements * update the evaluation code to cater to algorithms * add loop for algos over datasets --------- Co-authored-by: Cyril Matthey-Doret --- .gitignore | 2 +- evaluation/__main__.py | 20 ++--- evaluation/evaluate.py | 153 +++++++++++++++++++++++++++++++++++- evaluation/requirements.txt | 34 +++++++- 4 files changed, 197 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index 2392931..0bae38d 100644 --- a/.gitignore +++ b/.gitignore @@ -3,8 +3,8 @@ __pycache__/ *.py[cod] *$py.class .DS_Store +*.csv website/data/* - # C extensions *.so diff --git a/evaluation/__main__.py b/evaluation/__main__.py index 5942634..b2a8733 100644 --- a/evaluation/__main__.py +++ b/evaluation/__main__.py @@ -1,12 +1,14 @@ if __name__ == "__main__": - import argparse - from evaluate.evaluate import evaluate + from evaluate import evaluate_s3, evaluate + import os + from dotenv import load_dotenv + load_dotenv('evaluation/.env') - parser = argparse.ArgumentParser( - description="Evaluation code to compare annotations from a seizure detection algorithm to ground truth annotations." 
- ) - parser.add_argument("ref", help="Path to the root folder containing the reference annotations.") - parser.add_argument("hyp", help="Path to the root folder containing the hypothesis annotations.") - args = parser.parse_args() - evaluate(args.ref, args.hyp) \ No newline at end of file + AWS_REGION = os.getenv('AWS_REGION') + AWS_BUCKET = os.getenv('AWS_BUCKET') + AWS_ACCESS_KEY = os.getenv('AWS_ACCESS_KEY') + AWS_SECRET_KEY = os.getenv('AWS_SECRET_KEY') + + + evaluate_s3(AWS_REGION, AWS_BUCKET, AWS_ACCESS_KEY, AWS_SECRET_KEY) diff --git a/evaluation/evaluate.py b/evaluation/evaluate.py index 155d164..417bd81 100644 --- a/evaluation/evaluate.py +++ b/evaluation/evaluate.py @@ -2,6 +2,9 @@ import numpy as np import pandas as pd +import boto3 +from tempfile import NamedTemporaryFile +import json from epilepsy2bids.annotations import Annotations, SeizureType from timescoring import annotations, scoring @@ -14,7 +17,7 @@ def toMask(annotations): for event in annotations.events: if event["eventType"].value != "bckg": mask[ - round(event["onset"] * FS) : round(event["onset"] + event["duration"]) + round(event["onset"] * FS): round(event["onset"] + event["duration"]) * FS ] = 1 return mask @@ -118,3 +121,151 @@ def evaluate(refFolder: str, hypFolder: str): + "- F1-score : {:.2f} \n".format(f1) + "- FP/24h : {:.2f} \n".format(fpRate) ) + +def evaluate_s3(AWS_REGION: str, AWS_BUCKET: str, AWS_ACCESS_KEY: str, AWS_SECRET_KEY: str): + + results = { + "dataset": [], + "subject": [], + "file": [], + "algorithm": [], + "duration": [], + "tp_sample": [], + "fp_sample": [], + "refTrue_sample": [], + "tp_event": [], + "fp_event": [], + "refTrue_event": [], + } + + s3 = boto3.client('s3', aws_access_key_id=AWS_ACCESS_KEY, + aws_secret_access_key=AWS_SECRET_KEY, region_name=AWS_REGION) + # List objects in the bucket + + response = s3.list_objects_v2(Bucket=AWS_BUCKET) + count = 0 + + for obj in response.get('Contents', []): + file_path = obj['Key'] + + if 
file_path.endswith('.tsv') and not file_path.endswith('participants.tsv') and file_path.startswith('datasets/'): + refTsv = file_path + DATASET = refTsv.split('/')[1] + + # Get the object from S3 + tsv_content = s3.get_object(Bucket=AWS_BUCKET, Key=refTsv) + +# Read the content of the file + tsv_obj = tsv_content['Body'].read().decode('utf-8') + + # Write the content to a temporary file (we can remove this later once library is updated) + with NamedTemporaryFile(delete=False, mode='w', suffix='.tsv') as temp_file: + temp_file.write(tsv_obj) + temp_file_path = temp_file.name + + ref = Annotations.loadTsv(temp_file_path) + ref = annotations.Annotation(toMask(ref), FS) + print(refTsv, "\nRef Annotation: ", ref) + + # Get the corresponding hypothesis file from the algo1 folder (we can change this to a regular expression based logic later) + hypTsv_base = "submissions/ghcr-io-esl-epfl-gotman-1982-latest/" + + hypTsv = hypTsv_base + refTsv.replace("datasets/", "", 1) + + + try: + hyp_tsv_content = s3.get_object(Bucket=AWS_BUCKET, Key=hypTsv) + print("\n ref_tsv_path:", refTsv, "\n hyp_tsv_path:", hypTsv, "\n hyp_tsv_content:", hyp_tsv_content, "datasetname: ", DATASET) + + with NamedTemporaryFile(delete=False, mode='w', suffix='.tsv') as temp_file2: + temp_file2.write(hyp_tsv_content['Body'].read().decode('utf-8')) + temp_file2_path = temp_file2.name + hyp = Annotations.loadTsv(temp_file2_path) + hyp = annotations.Annotation(toMask(hyp), FS) + except Exception as e: + print(f"Error loading hypothesis file: {e}") + hyp = annotations.Annotation(np.zeros_like(ref.mask), ref.fs) + + sampleScore = scoring.SampleScoring(ref, hyp) + eventScore = scoring.EventScoring(ref, hyp) + + # results["dataset"].append(DATASET) + + # dataset logic for testing + results["dataset"].append(DATASET) + + results["subject"].append(refTsv.split("/")[2]) + results["file"].append(refTsv.split("/")[-1]) + results["algorithm"].append(hypTsv.split("/")[1]) + results["duration"].append(len(ref.mask) / ref.fs) + 
results["tp_sample"].append(sampleScore.tp) + results["fp_sample"].append(sampleScore.fp) + results["refTrue_sample"].append(sampleScore.refTrue) + results["tp_event"].append(eventScore.tp) + results["fp_event"].append(eventScore.fp) + results["refTrue_event"].append(eventScore.refTrue) + count += 1 + + print(count) + + results = pd.DataFrame(results) + grouped_results = results.groupby('dataset')[ + ['tp_sample', 'fp_sample', 'refTrue_sample', 'duration']].sum().reset_index() + print(grouped_results.head()) + + results.to_csv("results.csv") + + result_dict = [] + + for algo in results['algorithm'].unique(): + # Sample results + temp_result = { + "algo_id": algo, + "datasets": [] + } + for dataset in results['dataset'].unique(): + temp = {} + dataset_results = results[(results['dataset'] == dataset) & (results["algorithm"] == algo)] + sensitivity_sample, precision_sample, f1_sample, fpRate_sample = computeScores( + dataset_results["tp_sample"].sum(), + dataset_results["fp_sample"].sum(), + dataset_results["refTrue_sample"].sum(), + dataset_results["duration"].sum(),) + + sensitivity_event, precision_event, f1_event, fpRate_event = computeScores( + dataset_results["tp_event"].sum(), + dataset_results["fp_event"].sum(), + dataset_results["refTrue_event"].sum(), + dataset_results["duration"].sum()) + + temp["dataset"] = dataset + + temp["sample_results"] = {} + temp["sample_results"]["sensitivity"] = sensitivity_sample + temp["sample_results"]["precision"] = precision_sample + temp["sample_results"]["f1"] = f1_sample + temp["sample_results"]["fpRate"] = fpRate_sample + + temp["event_results"] = {} + temp["event_results"]["sensitivity"] = sensitivity_event + temp["event_results"]["precision"] = precision_event + temp["event_results"]["f1"] = f1_event + temp["event_results"]["fpRate"] = fpRate_event + + temp_result['datasets'].append(temp) + + result_dict.append(temp_result) + + # Convert result_dict to JSON + json_object = json.dumps(result_dict) + + # Print JSON 
object + print(json_object) + + # Write JSON object to S3 + s3.put_object(Bucket=AWS_BUCKET, Key='results/results.json', Body=json_object) + + # Write results.csv to S3 + with open("results.csv", "rb") as csv_file: + s3.put_object(Bucket=AWS_BUCKET, Key='results/results.csv', Body=csv_file) + diff --git a/evaluation/requirements.txt b/evaluation/requirements.txt index 24ce15a..63257f8 100644 --- a/evaluation/requirements.txt +++ b/evaluation/requirements.txt @@ -1 +1,33 @@ -numpy +boto3==1.35.47 +botocore==1.35.47 +contourpy==1.3.0 +cycler==0.12.1 +epilepsy2bids==0.0.1 +fonttools==4.54.1 +gitdb==4.0.11 +GitPython==3.1.41 +jmespath==1.0.1 +kiwisolver==1.4.7 +llvmlite==0.43.0 +matplotlib==3.9.2 +nptyping==2.5.0 +numba==0.60.0 +numpy==1.26.4 +packaging==24.1 +pandas==2.2.3 +pillow==11.0.0 +pyarrow==17.0.0 +pyEDFlib==0.1.38 +pyparsing==3.2.0 +python-dateutil==2.9.0.post0 +python-dotenv==1.0.1 +pytz==2024.2 +resampy==0.4.3 +s3transfer==0.10.3 +setuptools==69.0.3 +six==1.16.0 +smmap==5.0.1 +termcolor==2.5.0 +timescoring==0.0.5 +tzdata==2024.2 +urllib3==1.26.20