Skip to content

Commit

Permalink
Fix issue with calling cached_file for local directories, closes #825
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmezzetti committed Dec 1, 2024
1 parent cb22635 commit 4b5164b
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 22 deletions.
39 changes: 19 additions & 20 deletions src/python/txtai/models/pooling/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,12 @@ def method(path):
# Default method
method = "meanpooling"

try:
# Load 1_Pooling/config.json file
config = PoolingFactory.load(path, "1_Pooling/config.json")

# Set to CLS pooling if it's enabled and mean pooling is disabled
if config["pooling_mode_cls_token"] and not config["pooling_mode_mean_tokens"]:
method = "clspooling"
# Load 1_Pooling/config.json file
config = PoolingFactory.load(path, "1_Pooling/config.json")

# Ignore this error
except OSError:
pass
# Set to CLS pooling if it's enabled and mean pooling is disabled
if config and config["pooling_mode_cls_token"] and not config["pooling_mode_mean_tokens"]:
method = "clspooling"

return method

Expand All @@ -97,13 +92,8 @@ def maxlength(path):
maxlength = None

# Read max_seq_length from sentence_bert_config.json
try:
config = PoolingFactory.load(path, "sentence_bert_config.json")
maxlength = config.get("max_seq_length")

# Ignore this error
except OSError:
pass
config = PoolingFactory.load(path, "sentence_bert_config.json")
maxlength = config.get("max_seq_length") if config else maxlength

return maxlength

Expand All @@ -121,6 +111,15 @@ def load(path, name):
"""

# Download file and parse JSON
path = cached_file(path_or_repo_id=path, filename=name)
with open(path, encoding="utf-8") as f:
return json.load(f)
config = None
try:
path = cached_file(path_or_repo_id=path, filename=name)
if path:
with open(path, encoding="utf-8") as f:
config = json.load(f)

# Ignore this error - invalid repo or directory
except OSError:
pass

return config
5 changes: 3 additions & 2 deletions src/python/txtai/pipeline/audio/texttospeech.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,14 @@ def hasfile(self, path, name):
True if name exists in path, False otherwise
"""

exists = False
try:
# Check if file exists
cached_file(path_or_repo_id=path, filename=name)
exists = cached_file(path_or_repo_id=path, filename=name) is not None
except OSError:
return False

return True
return exists

def stream(self, texts, speaker):
"""
Expand Down

0 comments on commit 4b5164b

Please sign in to comment.