-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel_package.py
60 lines (47 loc) · 1.68 KB
/
model_package.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
class ModelPackage:
    """Several models that can be requested dynamically.

    Models are kept in insertion order and can be addressed by position,
    by ``model_name`` or by ``model_path`` (see ``model_index``). Each model
    object is expected to provide ``model_name``, ``model_path``,
    ``is_ready()`` and ``classify(data)``.
    """

    def __init__(self):
        # Registered models, in the order they were added.
        self.models = []

    @property
    def ready(self) -> bool:
        """True when at least one model is registered and every model is ready."""
        # bool() so an empty package reports False instead of leaking ``[]``.
        return bool(self.models) and all(model.is_ready() for model in self.models)

    def add(self, model) -> None:
        """Append *model*; its position in the package becomes its numeric key."""
        self.models.append(model)

    def info(self) -> list:
        """Describe the registered models.

        Returns:
            One dict per model, in insertion order, with ``"key"`` (the
            model's position as a string), ``"name"`` and ``"path"``.
        """
        return [
            {"key": str(ix), "name": model.model_name, "path": model.model_path}
            for ix, model in enumerate(self.models)
        ]

    def model_index(self, key=None):
        """Find a model by a key which could be an index, name or path.

        Resolution order:
          * ``None`` selects the first model.
          * Anything ``int()`` accepts is treated as a position. It is reduced
            modulo the number of models, so negative and out-of-range
            positions wrap around rather than failing. A model whose name or
            path parses as an integer can therefore only be reached by
            position.
          * Otherwise the key is compared against each model's ``model_name``
            and ``model_path``; the first match wins.

        The returned index is guaranteed to exist in the model list.
        If the model is not found (or no models are registered), returns None.
        """
        if not self.models:
            return None
        if key is None:
            return 0
        try:
            position = int(key)
        except (TypeError, ValueError, OverflowError):
            # Not usable as a position (a name, a path, an arbitrary object,
            # or a non-finite float): fall through to the name/path lookup.
            pass
        else:
            return position % len(self.models)
        return next(
            (
                ix
                for ix, model in enumerate(self.models)
                if key in (model.model_name, model.model_path)
            ),
            None,
        )

    def classify(self, data, model_key: str):
        """Forward the classification task to the appropriate model.

        This checks that the requested model exists and is ready.

        Args:
            data: Passed unchanged to the selected model's ``classify``.
            model_key: Position, name or path resolved via ``model_index``.

        Returns:
            Whatever the selected model's ``classify(data)`` returns.

        Raises:
            ValueError: If no model matches *model_key*, or the matched model
                reports that it is not ready.
        """
        ix = self.model_index(model_key)
        if ix is None:
            raise ValueError(f"No model found for {model_key}")
        model = self.models[ix]
        if not model.is_ready():
            raise ValueError(f"The specified model {model_key} was not ready")
        return model.classify(data)