
Commit

Merge pull request #9 from rekcurd/feature/EvaluateResult_proto
Merged
keigohtr authored Jan 15, 2019
2 parents 0afb52c + f2cbb08 commit 7a5859e
Showing 2 changed files with 34 additions and 15 deletions.
47 changes: 33 additions & 14 deletions app.py
@@ -6,11 +6,11 @@
 import csv
 import os
 
-from typing import Tuple, List
+from typing import Tuple, List, Generator
 
 from drucker.logger import JsonSystemLogger
 from drucker import Drucker
-from drucker.utils import PredictLabel, PredictResult, EvaluateResult, EvaluateDetail
+from drucker.utils import PredictLabel, PredictResult, EvaluateResult, EvaluateDetail, EvaluateResultDetail
 
 import numpy as np
 from sklearn.metrics import accuracy_score, precision_recall_fscore_support
@@ -62,7 +62,13 @@ def predict(self, input: PredictLabel, option: dict = None) -> PredictResult:
             self.logger.error(traceback.format_exc())
             raise e
 
-    def evaluate(self, file_path: str) -> Tuple[EvaluateResult, List[EvaluateDetail]]:
+    def __generate_eval_data(self, file_path: str) -> Generator[Tuple[int, List[str]], None, None]:
+        with open(file_path, 'r') as f:
+            reader = csv.reader(f, delimiter=",")
+            for row in reader:
+                yield int(row[0]), row[1:]
+
+    def evaluate(self, file_path: str) -> Tuple[EvaluateResult, List[EvaluateResultDetail]]:
         """ override
         Evaluate
@@ -73,7 +79,7 @@ def evaluate(self, file_path: str) -> Tuple[EvaluateResult, List[EvaluateDetail]
             precision: Precision. arr[float]
             recall: Recall. arr[float]
             fvalue: F1 value. arr[float]
-            option: optional metrics. dict[str, float]
+            option: Optional metrics. dict[str, float]
             details: detail result of each prediction
         """
@@ -82,16 +88,13 @@ def evaluate(self, file_path: str) -> Tuple[EvaluateResult, List[EvaluateDetail]
             label_gold = []
             label_predict = []
             details = []
-            with open(file_path, 'r') as f:
-                reader = csv.reader(f, delimiter=",")
-                for row in reader:
-                    num += 1
-                    correct_label = int(row[0])
-                    label_gold.append(correct_label)
-                    result = self.predict(row[1:], option={})
-                    is_correct = correct_label == int(result.label[0])
-                    details.append(EvaluateDetail(result, is_correct))
-                    label_predict.append(result.label)
+            for correct_label, data in self.__generate_eval_data(file_path):
+                num += 1
+                label_gold.append(correct_label)
+                result = self.predict(data, option={})
+                is_correct = correct_label == int(result.label[0])
+                details.append(EvaluateResultDetail(result, is_correct))
+                label_predict.append(result.label)
 
             accuracy = accuracy_score(label_gold, label_predict)
             p_r_f = precision_recall_fscore_support(label_gold, label_predict)
@@ -101,3 +104,19 @@ def evaluate(self, file_path: str) -> Tuple[EvaluateResult, List[EvaluateDetail]
             self.logger.error(str(e))
             self.logger.error(traceback.format_exc())
             return EvaluateResult(), []
+
+    def get_evaluate_detail(self, file_path: str, results: List[EvaluateResultDetail]) -> Generator[EvaluateDetail, None, None]:
+        """ override
+        Create EvaluateDetail by merging the evaluation data from file_path with each EvaluateResultDetail
+        :param file_path: Evaluation data file path. str
+        :param results: Detail result of each prediction
+        :return:
+            detail: Evaluation data & result of each prediction
+        """
+        try:
+            for i, (correct_label, data) in enumerate(self.__generate_eval_data(file_path)):
+                yield EvaluateDetail(input=data, label=correct_label, result=results[i])
+        except Exception as e:
+            self.logger.error(str(e))
+            self.logger.error(traceback.format_exc())
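
For context, a minimal usage sketch of the new flow (not part of the commit). It assumes the methods above live on a Drucker application object exposed here as `app`, that `eval.csv` has the shape `__generate_eval_data` expects (gold label in column 0, feature values in the remaining columns), and that the result objects expose the fields named in the docstrings as attributes:

# Hypothetical driver code; `app` and `eval.csv` are illustrative names.
# evaluate() returns aggregate metrics plus one EvaluateResultDetail per row.
result, result_details = app.evaluate('eval.csv')
print(result.num, result.accuracy)  # attribute names assumed from the docstring

# get_evaluate_detail() re-reads the same CSV via __generate_eval_data and
# pairs row i with results[i], yielding one EvaluateDetail per example.
for detail in app.get_evaluate_detail('eval.csv', result_details):
    print(detail.label, detail.input, detail.result)  # names assumed from the keyword arguments above
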
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
-drucker==0.4.3
+drucker==0.4.4
 scikit-learn==0.19.1
 numpy==1.14.3
 scipy==1.1.0
