diff --git a/src/python/txtai/models/pooling/factory.py b/src/python/txtai/models/pooling/factory.py index 131367d0e..7e399f0da 100644 --- a/src/python/txtai/models/pooling/factory.py +++ b/src/python/txtai/models/pooling/factory.py @@ -67,17 +67,12 @@ def method(path): # Default method method = "meanpooling" - try: - # Load 1_Pooling/config.json file - config = PoolingFactory.load(path, "1_Pooling/config.json") - - # Set to CLS pooling if it's enabled and mean pooling is disabled - if config["pooling_mode_cls_token"] and not config["pooling_mode_mean_tokens"]: - method = "clspooling" + # Load 1_Pooling/config.json file + config = PoolingFactory.load(path, "1_Pooling/config.json") - # Ignore this error - except OSError: - pass + # Set to CLS pooling if it's enabled and mean pooling is disabled + if config and config["pooling_mode_cls_token"] and not config["pooling_mode_mean_tokens"]: + method = "clspooling" return method @@ -97,13 +92,8 @@ def maxlength(path): maxlength = None # Read max_seq_length from sentence_bert_config.json - try: - config = PoolingFactory.load(path, "sentence_bert_config.json") - maxlength = config.get("max_seq_length") - - # Ignore this error - except OSError: - pass + config = PoolingFactory.load(path, "sentence_bert_config.json") + maxlength = config.get("max_seq_length") if config else maxlength return maxlength @@ -121,6 +111,15 @@ def load(path, name): """ # Download file and parse JSON - path = cached_file(path_or_repo_id=path, filename=name) - with open(path, encoding="utf-8") as f: - return json.load(f) + config = None + try: + path = cached_file(path_or_repo_id=path, filename=name) + if path: + with open(path, encoding="utf-8") as f: + config = json.load(f) + + # Ignore this error - invalid repo or directory + except OSError: + pass + + return config diff --git a/src/python/txtai/pipeline/audio/texttospeech.py b/src/python/txtai/pipeline/audio/texttospeech.py index 0315180eb..ed35cc545 100644 --- a/src/python/txtai/pipeline/audio/texttospeech.py +++ b/src/python/txtai/pipeline/audio/texttospeech.py @@ -115,13 +115,14 @@ def hasfile(self, path, name): True if name exists in path, False otherwise """ + exists = False try: # Check if file exists - cached_file(path_or_repo_id=path, filename=name) + exists = cached_file(path_or_repo_id=path, filename=name) is not None except OSError: return False - return True + return exists def stream(self, texts, speaker): """