diff --git a/nemo_curator/utils/distributed_utils.py b/nemo_curator/utils/distributed_utils.py index a9d792b4..c8c33789 100644 --- a/nemo_curator/utils/distributed_utils.py +++ b/nemo_curator/utils/distributed_utils.py @@ -16,6 +16,7 @@ import ast import os import shutil +import subprocess import dask @@ -566,6 +567,9 @@ def read_data( """ if isinstance(input_files, str): input_files = [input_files] + + check_dask_cwd(input_files) + if file_type == "pickle": df = read_pandas_pickle( input_files[0], add_filename=add_filename, columns=columns, **kwargs @@ -1014,6 +1018,17 @@ def get_current_client(): return None +def check_dask_cwd(file_list): + if any(not os.path.isabs(file_path) for file_path in file_list): + dask_cwd = list(get_current_client().run(os.getcwd).values())[0] + os_pwd = subprocess.check_output("pwd", shell=True, text=True).strip() + if dask_cwd != os_pwd: + raise RuntimeError( + "Mismatch between Dask client and worker working directories. " + "Use absolute file paths to ensure the correct files are read as intended." + ) + + def performance_report_if( path: Optional[str] = None, report_name: str = "dask-profile.html" ):