diff --git a/merlin/core/compat/__init__.py b/merlin/core/compat/__init__.py index dc5ec84ed..4d81d60b9 100644 --- a/merlin/core/compat/__init__.py +++ b/merlin/core/compat/__init__.py @@ -21,8 +21,7 @@ from merlin.core.has_gpu import HAS_GPU # noqa pylint: disable=unused-import -if not cuda.is_available(): - cuda = None +cuda = None if not HAS_GPU else cuda try: import psutil @@ -98,17 +97,8 @@ def device_mem_size(kind="total", cpu=False): if kind not in ["free", "total"]: raise ValueError(f"{kind} not a supported option for device_mem_size.") - try: - if kind == "free": - return int(cuda.current_context().get_memory_info()[0]) - else: - return int(cuda.current_context().get_memory_info()[1]) - except NotImplementedError: - if kind == "free": - # Not using NVML "free" memory, because it will not include RMM-managed memory - warnings.warn("get_memory_info is not supported. Using total device memory from NVML.") - size = pynvml_mem_size(kind="total", index=0) - return size + + return pynvml_mem_size(kind=kind) try: diff --git a/merlin/io/writer.py b/merlin/io/writer.py index 66c24170d..c4a5aa788 100644 --- a/merlin/io/writer.py +++ b/merlin/io/writer.py @@ -196,7 +196,11 @@ def _add_data_slice(self, df): if self.shuffle: df = shuffle_df(df) int_slice_size = df.shape[0] // self.num_out_files - slice_size = int_slice_size if df.shape[0] % int_slice_size == 0 else int_slice_size + 1 + slice_size = ( + int_slice_size + if int_slice_size > 0 and df.shape[0] % int_slice_size == 0 + else int_slice_size + 1 + ) for x in range(self.num_out_files): start = x * slice_size end = start + slice_size