forked from cortexlabs/cortex
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredictor.py
50 lines (39 loc) · 1.49 KB
/
predictor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# WARNING: you are on the master branch, please refer to the examples on the branch that matches your `cortex version`
import re
import torch
import os
import boto3
from botocore import UNSIGNED
from botocore.client import Config
from model import IrisNet
# Species names, indexed by the position of the largest value in the model's
# output row. NOTE(review): the order presumably matches the class indices
# IrisNet was trained with -- confirm against the training code.
labels = ["setosa", "versicolor", "virginica"]


class PythonPredictor:
    """Predict an iris species label using an ``IrisNet`` checkpoint from S3.

    The checkpoint is downloaded once, when the predictor is constructed;
    ``predict`` then maps one request payload to one entry of ``labels``.
    """

    # Request fields, in the order they are packed into the model's input row.
    _FEATURE_KEYS = ("sepal_length", "sepal_width", "petal_length", "petal_width")

    # Local path the checkpoint is downloaded to before it is loaded.
    _MODEL_PATH = "/tmp/model.pth"

    def __init__(self, config):
        """Download the checkpoint named by ``config["model"]`` and load it.

        Args:
            config: Mapping with a ``"model"`` entry of the form
                ``s3://<bucket>/<key>``.

        Raises:
            KeyError: If ``config`` has no ``"model"`` entry.
            ValueError: If ``config["model"]`` is not an ``s3://<bucket>/<key>``
                URI.
        """
        # Validate the URI up front so a bad config fails with a clear message
        # instead of an AttributeError on a ``None`` match object.
        model_uri = config["model"]
        match = re.match(r"s3://(.+?)/(.+)", model_uri)
        if match is None:
            raise ValueError(
                "config['model'] must look like s3://<bucket>/<key>, got {!r}".format(
                    model_uri
                )
            )
        bucket, key = match.groups()

        # Download the model.
        if os.environ.get("AWS_ACCESS_KEY_ID"):
            # The client will use your credentials if available.
            s3 = boto3.client("s3")
        else:
            # Anonymous client: requests are sent unsigned.
            s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED))
        s3.download_file(bucket, key, self._MODEL_PATH)

        # Initialize the model.
        # NOTE: torch.load unpickles the file, so config["model"] should only
        # ever point at a trusted checkpoint.
        # map_location="cpu" lets a checkpoint saved from a GPU load on a
        # CPU-only host; load_state_dict then copies the values into the
        # model's own parameters.
        model = IrisNet()
        model.load_state_dict(torch.load(self._MODEL_PATH, map_location="cpu"))
        model.eval()
        self.model = model

    def predict(self, payload):
        """Return the predicted label for a single request.

        Args:
            payload: Mapping providing ``sepal_length``, ``sepal_width``,
                ``petal_length`` and ``petal_width``.

        Returns:
            The entry of ``labels`` at the index of the largest value in the
            model's first output row.

        Raises:
            KeyError: If ``payload`` is missing one of the required fields.
        """
        # Convert the request to a single-row tensor.
        features = [payload[name] for name in self._FEATURE_KEYS]
        input_tensor = torch.FloatTensor([features])

        # Run the prediction; no gradients are needed at inference time, so
        # skip building the autograd graph.
        with torch.no_grad():
            output = self.model(input_tensor)

        # Translate the model output to the corresponding label string.
        return labels[int(torch.argmax(output[0]))]