From 5765fcb1c7577459bd2524c2377a27849d8dca2e Mon Sep 17 00:00:00 2001 From: Andreas Berg Date: Thu, 27 Jun 2024 10:04:36 +0200 Subject: [PATCH] Changed predict image by post function --- src/customvision/classifier.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/customvision/classifier.py b/src/customvision/classifier.py index 3ca892b..ca64869 100644 --- a/src/customvision/classifier.py +++ b/src/customvision/classifier.py @@ -48,6 +48,7 @@ def __init__(self) -> None: None """ self.ENDPOINT = Keys.get("CV_ENDPOINT") + self.PREDICTION_ENDPOINT = Keys.get("CV_PREDICTION_ENDPOINT") self.project_id = Keys.get("CV_PROJECT_ID") self.prediction_key = Keys.get("CV_PREDICTION_KEY") self.training_key = Keys.get("CV_TRAINING_KEY") @@ -58,7 +59,7 @@ def __init__(self) -> None: in_headers={"Prediction-key": self.prediction_key} ) self.predictor = CustomVisionPredictionClient( - self.ENDPOINT, self.prediction_credentials + self.PREDICTION_ENDPOINT, self.prediction_credentials ) self.training_credentials = ApiKeyCredentials( in_headers={"Training-key": self.training_key} @@ -155,9 +156,11 @@ def predict_image_by_post(self, img) -> Dict[str, float]: """ headers = {'content-type': 'application/octet-stream', "prediction-key": self.prediction_key} - res = requests.post(Keys.get("CV_PREDICTION_ENDPOINT"), img.read(), headers=headers).json() + res = self.predictor.classify_image( + self.project_id, self.iteration_name, img.read(), custom_headers=headers + ) img.seek(0) - pred_kv = dict([(i["tagName"], i["probability"]) for i in res["predictions"]]) + pred_kv = dict([(i.tag_name, i.probability) for i in res.predictions]) best_guess = max(pred_kv, key=pred_kv.get) return pred_kv, best_guess