Skip to content

Commit

Permalink
Merge pull request #76 from computas/bug/fix-prediction-endpoint
Browse files Browse the repository at this point in the history
Changed predict image by post function
  • Loading branch information
ThomasBakkenMoe-Computas authored Jun 27, 2024
2 parents 7965073 + 5765fcb commit d85ff89
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/customvision/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(self) -> None:
None
"""
self.ENDPOINT = Keys.get("CV_ENDPOINT")
self.PREDICTION_ENDPOINT = Keys.get("CV_PREDICTION_ENDPOINT")
self.project_id = Keys.get("CV_PROJECT_ID")
self.prediction_key = Keys.get("CV_PREDICTION_KEY")
self.training_key = Keys.get("CV_TRAINING_KEY")
Expand All @@ -58,7 +59,7 @@ def __init__(self) -> None:
in_headers={"Prediction-key": self.prediction_key}
)
self.predictor = CustomVisionPredictionClient(
self.ENDPOINT, self.prediction_credentials
self.PREDICTION_ENDPOINT, self.prediction_credentials
)
self.training_credentials = ApiKeyCredentials(
in_headers={"Training-key": self.training_key}
Expand Down Expand Up @@ -155,9 +156,11 @@ def predict_image_by_post(self, img) -> Dict[str, float]:
"""

headers = {'content-type': 'application/octet-stream', "prediction-key": self.prediction_key}
res = requests.post(Keys.get("CV_PREDICTION_ENDPOINT"), img.read(), headers=headers).json()
res = self.predictor.classify_image(
self.project_id, self.iteration_name, img.read(), custom_headers=headers
)
img.seek(0)
pred_kv = dict([(i["tagName"], i["probability"]) for i in res["predictions"]])
pred_kv = dict([(i.tag_name, i.probability) for i in res.predictions])
best_guess = max(pred_kv, key=pred_kv.get)
return pred_kv, best_guess

Expand Down

0 comments on commit d85ff89

Please sign in to comment.