diff --git a/intel_extension_for_transformers/neural_chat/models/base_model.py b/intel_extension_for_transformers/neural_chat/models/base_model.py index 1aa31ef1782..5f3616f89ff 100644 --- a/intel_extension_for_transformers/neural_chat/models/base_model.py +++ b/intel_extension_for_transformers/neural_chat/models/base_model.py @@ -58,6 +58,16 @@ def construct_parameters(query, model_name, device, assistant_model, config): params["device"] = device return params +def safe_path(*paths): + # Prevent path traversal by ensuring the final path is within the base path or assets_path + current_working_directory = os.getcwd() + path_parts = current_working_directory.split('/') + base_path = '/' + path_parts[1] + assets_path = '/intel-extension-for-transformers/intel_extension_for_transformers/neural_chat/assets/' + final_path = os.path.abspath(*paths) + if final_path.startswith(base_path) or final_path.startswith(assets_path): + return final_path + class BaseModel(ABC): """A base class for LLM.""" @@ -158,7 +168,7 @@ def predict_stream(self, query, origin_query="", config=None): my_origin_query = origin_query if is_audio_file(query): - if not os.path.exists(query): + if not os.path.exists(safe_path(query)): raise ValueError(f"The audio file path {query} is invalid.") query_include_prompt = False @@ -181,7 +191,7 @@ def predict_stream(self, query, origin_query="", config=None): if response: logging.info("Get response: %s from cache", response) return response['choices'][0]['text'], link - if plugin_name == "asr" and not os.path.exists(query): + if plugin_name == "asr" and not os.path.exists(safe_path(query)): continue if plugin_name == "retrieval": try: @@ -281,7 +291,7 @@ def predict(self, query, origin_query="", config=None): config.ipex_int8 = self.ipex_int8 if is_audio_file(query): - if not os.path.exists(query): + if not os.path.exists(safe_path(query)): raise ValueError(f"The audio file path {query} is invalid.") query_include_prompt = False @@ -302,7 +312,7 @@ def predict(self, query, origin_query="", config=None): if response: logging.info("Get response: %s from cache", response) return response['choices'][0]['text'] - if plugin_name == "asr" and not os.path.exists(query): + if plugin_name == "asr" and not os.path.exists(safe_path(query)): continue if plugin_name == "retrieval": try: