Merge pull request #5 from drucker/feature/evaluate-model
Merged
keigohtr authored Nov 20, 2018
2 parents d8e5425 + e367b07 commit 2be1396
Showing 1 changed file with 15 additions and 6 deletions.
app.py: 21 changes (15 additions & 6 deletions)
@@ -8,10 +8,11 @@
 import io
 
 from enum import Enum
+from typing import Tuple, List
 
 from drucker.logger import JsonSystemLogger
 from drucker import Drucker
-from drucker.utils import PredictLabel, PredictResult, EvaluateResult
+from drucker.utils import PredictLabel, PredictResult, EvaluateResult, EvaluateDetail
 
 import numpy as np
 from sklearn.metrics import accuracy_score, precision_recall_fscore_support
@@ -72,9 +73,9 @@ def predict(self, input: PredictLabel, option: dict = None) -> PredictResult:
             self.logger.error(traceback.format_exc())
             raise e
 
-    def evaluate(self, file: bytes) -> EvaluateResult:
+    def evaluate(self, file: bytes) -> Tuple[EvaluateResult, List[EvaluateDetail]]:
         """ override
-        [WIP] Evaluate
+        Evaluate
         :param file: Evaluation data file. bytes
         :return:
@@ -83,23 +84,31 @@ def evaluate(self, file: bytes) -> EvaluateResult:
             precision: Precision. arr[float]
             recall: Recall. arr[float]
             fvalue: F1 value. arr[float]
+            option: optional metrics. dict[str, float]
+            details: detailed result of each prediction
         """
         try:
             f = io.StringIO(file.decode("utf-8"))
             reader = csv.reader(f, delimiter=",")
             num = 0
             label_gold = []
             label_predict = []
+            details = []
             for row in reader:
                 num += 1
-                label_gold.append(int(row[0]))
+                correct_label = int(row[0])
+                label_gold.append(correct_label)
                 result = self.predict(row[1:], option={})
+                is_correct = correct_label == int(result.label[0])
+                details.append(EvaluateDetail(row[1:], correct_label, result, is_correct))
                 label_predict.append(result.label)
 
             accuracy = accuracy_score(label_gold, label_predict)
             p_r_f = precision_recall_fscore_support(label_gold, label_predict)
-            return EvaluateResult(num, accuracy, p_r_f[0].tolist(), p_r_f[1].tolist(), p_r_f[2].tolist())
+            res = EvaluateResult(num, accuracy, p_r_f[0].tolist(), p_r_f[1].tolist(), p_r_f[2].tolist(), {})
+            return res, details
         except Exception as e:
             self.logger.error(str(e))
             self.logger.error(traceback.format_exc())
-            return EvaluateResult()
+            return EvaluateResult(), []
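
For reference, below is a minimal usage sketch of the new two-value return contract. `ExampleApp` is a hypothetical stand-in for the Drucker application class defined in app.py; the `num` and `accuracy` attribute names on EvaluateResult are assumed from its constructor arguments, and the CSV layout (gold label in column 0, feature values in the remaining columns) follows the parsing loop above.

    # Hypothetical usage sketch; ExampleApp stands in for the app class in app.py.
    app = ExampleApp()

    # Two evaluation rows: gold label first, then the feature columns.
    data = b"0,1.0,2.0,3.0,4.0\n1,5.0,6.0,7.0,8.0\n"

    # evaluate() now returns (EvaluateResult, [EvaluateDetail]) instead of a bare EvaluateResult.
    result, details = app.evaluate(data)

    # Attribute names assumed from the EvaluateResult constructor in the diff above.
    print(result.num, result.accuracy)   # number of evaluated rows and overall accuracy
    print(len(details))                  # one EvaluateDetail per evaluated row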
