Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
solve codescan#53-56 in neural-chat (#1627)
Browse files Browse the repository at this point in the history
Signed-off-by: Liangyx2 <[email protected]>
  • Loading branch information
Liangyx2 authored Jun 21, 2024
1 parent 25e3741 commit 1418339
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions intel_extension_for_transformers/neural_chat/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,16 @@ def construct_parameters(query, model_name, device, assistant_model, config):
params["device"] = device
return params

def safe_path(*paths):
# Prevent path traversal by ensuring the final path is within the base path or assets_path
current_working_directory = os.getcwd()
path_parts = current_working_directory.split('/')
base_path = '/' + path_parts[1]
assets_path = '/intel-extension-for-transformers/intel_extension_for_transformers/neural_chat/assets/'
final_path = os.path.abspath(*paths)
if final_path.startswith(base_path) or final_path.startswith(assets_path):
return final_path

class BaseModel(ABC):
"""A base class for LLM."""

Expand Down Expand Up @@ -158,7 +168,7 @@ def predict_stream(self, query, origin_query="", config=None):
my_origin_query = origin_query

if is_audio_file(query):
if not os.path.exists(query):
if not os.path.exists(safe_path(query)):
raise ValueError(f"The audio file path {query} is invalid.")

query_include_prompt = False
Expand All @@ -181,7 +191,7 @@ def predict_stream(self, query, origin_query="", config=None):
if response:
logging.info("Get response: %s from cache", response)
return response['choices'][0]['text'], link
if plugin_name == "asr" and not os.path.exists(query):
if plugin_name == "asr" and not os.path.exists(safe_path(query)):
continue
if plugin_name == "retrieval":
try:
Expand Down Expand Up @@ -281,7 +291,7 @@ def predict(self, query, origin_query="", config=None):
config.ipex_int8 = self.ipex_int8

if is_audio_file(query):
if not os.path.exists(query):
if not os.path.exists(safe_path(query)):
raise ValueError(f"The audio file path {query} is invalid.")

query_include_prompt = False
Expand All @@ -302,7 +312,7 @@ def predict(self, query, origin_query="", config=None):
if response:
logging.info("Get response: %s from cache", response)
return response['choices'][0]['text']
if plugin_name == "asr" and not os.path.exists(query):
if plugin_name == "asr" and not os.path.exists(safe_path(query)):
continue
if plugin_name == "retrieval":
try:
Expand Down

0 comments on commit 1418339

Please sign in to comment.