diff --git a/README.md b/README.md index 9f9c27c..5a13675 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,7 @@ Also have a look at the other installing method, if you want to use the commands #### Download pretrained weights ```bash -./weights/download_weights.sh +poetry run yoeo-download-weights ``` ## Test diff --git a/pyproject.toml b/pyproject.toml index 7af144e..b39b9f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,3 +34,4 @@ build-backend = "poetry.core.masonry.api" yoeo-detect = "yoeo.detect:run" yoeo-train = "yoeo.train:run" yoeo-test = "yoeo.test:run" +yoeo-download-weights = "scripts.download_weights:run" diff --git a/scripts/download_weights.py b/scripts/download_weights.py new file mode 100755 index 0000000..70f8c9a --- /dev/null +++ b/scripts/download_weights.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python3 +import os +import pathlib +import argparse +import urllib.request + + +def run(): + parser = argparse.ArgumentParser(description='Download YOEO pretrained weights') + parser.add_argument( + '--output', + '-o', + type=pathlib.Path, + default='weights/yoeo.pth', + help='The pretrained weights file (.pth) will be written to this path. Defaults to "weights/yoeo.pth"', + ) + args = parser.parse_args() + + url = "https://data.bit-bots.de/models/2021_12_06_flo_torso21_yoeo_7/yoeo.pth" + output_path = args.output + + with urllib.request.urlopen(url) as input_file: + try: + with open(output_path, 'xb') as output_file: + print(f"Saving pretrained weights to: {output_path}") + output_file.write(input_file.read()) + except FileExistsError as e: + print(f"ERROR: The output file {output_path} does already exist. Will abort and not overwrite.") + + +if __name__ == '__main__': + run() diff --git a/weights/download_weights.sh b/weights/download_weights.sh deleted file mode 100755 index 8e0c664..0000000 --- a/weights/download_weights.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -# Download weights for yoeo-rev-7 -wget "http://data.bit-bots.de/models/2021_12_06_flo_torso21_yoeo_7/yoeo.pth" -# Download weights for vanilla YOLOv3 -wget -c "https://pjreddie.com/media/files/yolov3.weights" --header "Referer: pjreddie.com" -# # Download weights for tiny YOLOv3 -wget -c "https://pjreddie.com/media/files/yolov3-tiny.weights" --header "Referer: pjreddie.com" -# Download weights for backbone network -wget -c "https://pjreddie.com/media/files/darknet53.conv.74" --header "Referer: pjreddie.com"