Skip to content

Commit

Permalink
Create check_dask_cwd function
Browse files Browse the repository at this point in the history
Signed-off-by: Sarah Yurick <[email protected]>
  • Loading branch information
sarahyurick committed Jan 17, 2025
1 parent 7cfda44 commit 2dc9cf2
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions nemo_curator/utils/distributed_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import ast
import os
import shutil
import subprocess

import dask

Expand Down Expand Up @@ -566,6 +567,9 @@ def read_data(
"""
if isinstance(input_files, str):
input_files = [input_files]

check_dask_cwd(input_files)

if file_type == "pickle":
df = read_pandas_pickle(
input_files[0], add_filename=add_filename, columns=columns, **kwargs
Expand Down Expand Up @@ -1014,6 +1018,17 @@ def get_current_client():
return None


def check_dask_cwd(file_list):
    """Verify that relative input paths resolve identically on client and workers.

    Relative paths are resolved against the working directory of whichever
    process opens them. If a Dask worker runs in a different directory than
    the client, a relative path silently points at a different (or missing)
    file, so this check fails loudly instead.

    Args:
        file_list: A path or an iterable of paths (``str`` or ``os.PathLike``)
            that are about to be read.

    Raises:
        RuntimeError: If at least one path is relative and any worker's
            working directory differs from the client's.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(file_list, (str, os.PathLike)):
        file_list = [file_list]

    # Absolute paths do not depend on the working directory; nothing to check.
    if all(os.path.isabs(file_path) for file_path in file_list):
        return

    client = get_current_client()
    if client is None:
        # No active Dask client, so there are no remote workers to compare
        # against.
        return

    # Use os.getcwd() on both sides so the same (physical) notion of the
    # working directory is compared; a shell `pwd` reports the logical $PWD,
    # which can differ through symlinks and is not portable.
    client_cwd = os.getcwd()
    worker_cwds = client.run(os.getcwd)

    # Check every worker, not just the first one reported.
    if any(worker_cwd != client_cwd for worker_cwd in worker_cwds.values()):
        raise RuntimeError(
            "Mismatch between Dask client and worker working directories. "
            "Use absolute file paths to ensure the correct files are read as intended."
        )


def performance_report_if(
path: Optional[str] = None, report_name: str = "dask-profile.html"
):
Expand Down

0 comments on commit 2dc9cf2

Please sign in to comment.