diff --git a/src/coastseg/zoo_model.py b/src/coastseg/zoo_model.py index cb48f0a..6a23e3c 100644 --- a/src/coastseg/zoo_model.py +++ b/src/coastseg/zoo_model.py @@ -608,6 +608,8 @@ def set_settings(self, **kwargs): "use_GPU": "0", "implementation": "BEST", "model_type": "global_segformer_RGB_4class_14036903", + "local_model_path": "", # local path to the directory containing the model + "use_local_model": True, # Use local model (not one from zeneodo) "otsu": False, "tta": False, "cloud_thresh": 0.5, # threshold on maximum cloud cover @@ -995,6 +997,39 @@ def postprocess_data( file_utilities.move_files(outputs_path, session_path, delete_src=True) session.save(session.path) + def get_weights_directory(self,model_implementation:str, model_id: str) -> str: + """ + Retrieves the directory path where the model weights are stored. + This method determines whether to use a local model path or to download the model + from a remote source based on the settings provided. If the local model path is + specified and exists, it will use that path. Otherwise, it will create a directory + for the model and download the weights. + Args: + model_implementation (str): The implementation type of the model either 'BEST' or 'ENSEMBLE' + model_id (str): The identifier for the model. This is the zenodo ID located at the end of the URL + Returns: + str: The directory path where the model weights are stored. + Raises: + FileNotFoundError: If the local model path is specified but does not exist. + """ + + USE_LOCAL_MODEL = self.settings.get("use_local_model", False) + LOCAL_MODEL_PATH = self.settings.get("local_model_path", "") + + if USE_LOCAL_MODEL and not os.path.exists(LOCAL_MODEL_PATH): + raise FileNotFoundError(f"The local model path does not exist at {LOCAL_MODEL_PATH}") + + # check if a local model should be loaded or not + if USE_LOCAL_MODEL == False or LOCAL_MODEL_PATH == "": + # create the model directory & download the model + weights_directory = self.get_model_directory(model_id) + self.download_model(model_implementation, model_id, weights_directory) + else: + # load the model from the local model path + weights_directory = LOCAL_MODEL_PATH + + return weights_directory + def prepare_model(self, model_implementation: str, model_id: str): """ Prepares the model for use by downloading the required files and loading the model. @@ -1003,12 +1038,10 @@ def prepare_model(self, model_implementation: str, model_id: str): model_implementation (str): The model implementation either 'BEST' or 'ENSEMBLE' model_id (str): The ID of the model. """ - self.clear_zoo_model() - # create the model directory - self.weights_directory = self.get_model_directory(model_id) + # weights_directory is the directory that contains the model weights, the model card json files and the BEST_MODEL.txt file + self.weights_directory = self.get_weights_directory(model_implementation, model_id) logger.info(f"self.weights_directory:{self.weights_directory}") - self.download_model(model_implementation, model_id, self.weights_directory) weights_list = self.get_weights_list(model_implementation) # Load the model from the config files @@ -1096,6 +1129,7 @@ def run_model( logger.info(f"use_tta: {use_tta}") print(f"Running model {model_name}") + # print(f"self.settings: {self.settings}") self.prepare_model(model_implementation, model_name) # create a session