From a447f8373fd4e1ae72289218570bbbb14ef898b5 Mon Sep 17 00:00:00 2001 From: Samuel Felton Date: Wed, 26 Jun 2024 18:46:52 +0200 Subject: [PATCH] Added fallback to ViSP server when downloading megapose models --- script/megapose_server/install.py | 33 ++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/script/megapose_server/install.py b/script/megapose_server/install.py index 6ecf305984..bac589a7fb 100644 --- a/script/megapose_server/install.py +++ b/script/megapose_server/install.py @@ -107,13 +107,32 @@ def download_models(megapose_env: str, megapose_path: Path, megapose_data_path: Download the megapose deep learning models ''' models_path = megapose_data_path / 'megapose-models' - conf_path = megapose_path / 'rclone.conf' - rclone = str(get_rclone_for_conda_env(megapose_env).absolute()) - arguments = [rclone, 'copyto', 'inria_data:megapose-models/', - str(models_path), '--exclude', '*epoch*', - '--config', str(conf_path), '--progress'] - print(' '.join(arguments)) - subprocess.run(arguments, check=True) + models_path.mkdir(exist_ok=True) + try: + conf_path = megapose_path / 'rclone.conf' + rclone = str(get_rclone_for_conda_env(megapose_env).absolute()) + arguments = [rclone, 'copyto', 'inria_data:megapose-models/', + str(models_path), '--exclude', '*epoch*', + '--config', str(conf_path), '--progress'] + print(' '.join(arguments)) + subprocess.run(arguments, check=True) + except: + print('Could not download MegaPose data from the original repo, trying to fetch from the ViSP website') + from urllib.request import urlretrieve + + base_url = 'https://visp-doc.inria.fr/download/model-zoo/megapose-models/' + dirs = ['coarse-rgb-906902141/', 'refiner-rgb-653307694/', 'refiner-rgbd-288182519/'] + files_in_each_dir = ['checkpoint.pth.tar', 'config.yaml', 'log.txt'] + for folder_name in dirs: + dir_url = base_url + folder_name + save_dir = models_path / folder_name + save_dir.mkdir(exist_ok=True) + for file_name in files_in_each_dir: + full_url = dir_url + file_name + print(full_url) + _, headers = urlretrieve(full_url, str(save_dir / file_name)) + + def install_server(megapose_env: str):