-
-
Notifications
You must be signed in to change notification settings - Fork 580
/
docker_prepare.py
55 lines (44 loc) · 1.43 KB
/
docker_prepare.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import asyncio
from argparse import ArgumentParser
from manga_translator.utils import ModelWrapper
from manga_translator.detection import DETECTORS
from manga_translator.ocr import OCRS
from manga_translator.inpainting import INPAINTERS
# Command-line options, parsed at import time because download()/main() read
# the module-level ``cli_args`` directly.
arg_parser = ArgumentParser(
    description=(
        "Download the detector, OCR and inpainting models used by "
        "manga_translator."
    )
)
arg_parser.add_argument(
    "--models",
    default="",
    help=(
        "Comma-separated list of models to download, each written as "
        "'detector.<name>', 'ocr.<name>' or 'inpaint.<name>'. "
        "When omitted, no name filter is applied."
    ),
)
arg_parser.add_argument(
    "--continue-on-error",
    action="store_true",
    help="Report a failed download and continue instead of aborting.",
)
cli_args = arg_parser.parse_args()
async def download(dict):
    """Download the weights of every model class in a name -> class mapping.

    Each value that is a ``ModelWrapper`` subclass is instantiated with no
    arguments and its ``download()`` coroutine is awaited. Values that are not
    ``ModelWrapper`` subclasses are skipped.

    Args:
        dict: Mapping of model name to model class. The name shadows the
            builtin ``dict``; it is kept so keyword callers keep working.

    Raises:
        Exception: Whatever instantiation or ``download()`` raised. It is
            re-raised unless ``--continue-on-error`` was given on the command
            line, in which case the failure is printed and the loop continues.
    """
    for key, value in dict.items():
        # issubclass() raises TypeError when its first argument is not a
        # class. Check the type first so a non-class entry is skipped rather
        # than aborting the whole run, even under --continue-on-error.
        if not (isinstance(value, type) and issubclass(value, ModelWrapper)):
            continue
        print(" -- Downloading", key)
        try:
            inst = value()
            await inst.download()
        except Exception as e:
            print("Failed to download", key, value)
            print(e)
            if not cli_args.continue_on_error:
                raise
def _select(registry, prefix, models, skip=()):
    """Return the ``registry`` entries chosen by the ``--models`` filter.

    An entry ``k`` is kept when ``models`` is empty or contains
    ``f"{prefix}.{k}"``. Keys listed in ``skip`` are always dropped.
    """
    return {
        k: v
        for k, v in registry.items()
        if ((not models) or f"{prefix}.{k}" in models) and k not in skip
    }


async def main():
    """Download the detector, OCR and inpainter models selected by ``--models``.

    Names are matched as ``detector.<key>``, ``ocr.<key>`` and
    ``inpaint.<key>``. An empty ``--models`` selects every registry entry.
    The ``sd`` inpainter is always skipped.
    """
    # Strip whitespace so "detector.a, ocr.b" matches both names; the set
    # difference drops empty items produced by stray or trailing commas.
    models: set[str] = {
        part.strip() for part in cli_args.models.split(",")
    } - {""}
    await download(_select(DETECTORS, "detector", models))
    await download(_select(OCRS, "ocr", models))
    # The "sd" exclusion is grouped explicitly so it applies in every case.
    # With bare `or`/`and` precedence it was only applied when --models was
    # given, so the default run still downloaded "sd". The reason for skipping
    # it is not visible here (presumably its size; TODO confirm).
    await download(_select(INPAINTERS, "inpaint", models, skip=("sd",)))
if __name__ == "__main__":
    # Drive the async entry point only when run as a script, never on import.
    asyncio.run(main())