-
Notifications
You must be signed in to change notification settings - Fork 151
/
Copy pathModelDownloader.py
52 lines (39 loc) · 1.31 KB
/
ModelDownloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import logging
import time
from pathlib import Path
# Default directory assigned to ``ModelDownloader.dir_path``.
MODEL_DIR = Path("models")


class ModelDownloader:
    """Base class for fetching a model from a remote hub.

    Subclasses are expected to set ``model_name`` (and ``dir_path`` if the
    default is not suitable) and to implement :meth:`from_modelscope`,
    :meth:`from_huggingface`, :meth:`check_exist` and :meth:`gc`; the base
    implementations all raise ``NotImplementedError``.  Calling an instance
    with a source name runs the full flow via :meth:`execate`.
    """

    def __init__(self):
        # Placeholder shown in messages until a subclass sets a real name.
        self.model_name = "<no-name>"
        self.dir_path = MODEL_DIR
        self.logger = logging.getLogger(__name__)

    def from_modelscope(self):
        """Download the model from ModelScope. Must be overridden."""
        raise NotImplementedError()

    def from_huggingface(self):
        """Download the model from Hugging Face. Must be overridden."""
        raise NotImplementedError()

    def check_exist(self) -> bool:
        """Return True if the model is already available locally. Must be overridden."""
        # Must *raise*: a returned exception instance is truthy, which would
        # make ``execate`` believe the model exists and skip the download.
        raise NotImplementedError()

    def gc(self):
        """Hook run once the download is confirmed present. Must be overridden.

        NOTE(review): presumably meant for cleaning up temporary download
        artifacts — confirm against subclasses.
        """
        raise NotImplementedError()

    def __call__(self, source: str):
        """Run the download flow for this model from ``source``.

        See :meth:`execate` for accepted source names and raised errors.
        """
        self.execate(downloader=self, source=source)

    @staticmethod
    def execate(
        *,
        downloader: "ModelDownloader",
        source: str,
        retries: int = 5,
        interval: float = 5.0,
    ) -> None:
        """Download ``downloader``'s model from ``source`` unless it already exists.

        Args:
            downloader: The downloader whose hooks are driven.
            source: ``"modelscope"`` / ``"ms"`` or ``"huggingface"`` / ``"hf"``.
            retries: How many times to check for the model after the download
                call returns. A value below 1 performs no check, so
                ``TimeoutError`` is raised straight away.
            interval: Seconds to sleep between consecutive checks.

        Raises:
            ValueError: If ``source`` is not a recognised source name.
            TimeoutError: If the model is still missing after ``retries`` checks.
        """
        if downloader.check_exist():
            print(f"Model {downloader.model_name} already exists.")
            return

        if source in ("modelscope", "ms"):
            downloader.from_modelscope()
        elif source in ("huggingface", "hf"):
            downloader.from_huggingface()
        else:
            raise ValueError(f"Invalid source: {source!r}")

        # Poll for the model after the download call returns (presumably to
        # tolerate files becoming visible with a delay).
        for attempt in range(retries):
            if downloader.check_exist():
                break
            # No sleep after the last check: nothing would observe the wait.
            if attempt < retries - 1:
                time.sleep(interval)
        else:
            raise TimeoutError("Download timeout")

        downloader.gc()

    # Correctly spelled alias; ``execate`` is kept for existing callers.
    execute = execate