Quality of life improvements when using s3 (ML4GW#214)
* double num workers

* suppress botocore logging

* make file download logging statements debug

* add max num workers

* add progress bar flag
EthanMarx authored Feb 20, 2025
1 parent 369f85b commit 48aef20
Showing 3 changed files with 18 additions and 4 deletions.
1 change: 1 addition & 0 deletions amplfi/train/configs/flow/cbc.yaml
@@ -5,6 +5,7 @@ trainer:
   strategy: auto
   devices: auto
   profiler: "simple"
+  enable_progress_bar: true
   logger:
     - class_path: lightning.pytorch.loggers.WandbLogger
       init_args:
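The new `enable_progress_bar` key is forwarded to the `lightning.pytorch.Trainer` constructor by the repo's LightningCLI-style configs; a minimal Python sketch of the equivalent, assuming a standard Lightning setup:

import lightning.pytorch as pl

# equivalent of `enable_progress_bar: true` in the YAML config;
# set it to False to silence the per-batch progress bar, e.g. on
# batch jobs where the bar just floods the log file
trainer = pl.Trainer(enable_progress_bar=True)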
12 changes: 11 additions & 1 deletion amplfi/train/data/datasets/base.py
@@ -62,6 +62,12 @@ class AmplfiDataset(pl.LightningDataModule):
             Defaults to `kernel_length`
         min_valid_duration:
             Minimum number of seconds of validation background data
+        num_files_per_batch:
+            Number of strain hdf5 files used to construct each batch;
+            tuning this can improve dataloading performance
+        max_num_workers:
+            Maximum number of workers to assign to each
+            training dataloader
     """

@@ -81,12 +87,14 @@ def __init__(
         fftlength: Optional[int] = None,
         min_valid_duration: float = 10000,
         num_files_per_batch: Optional[int] = None,
+        max_num_workers: int = 6,
         verbose: bool = False,
     ):
         super().__init__()
         self.save_hyperparameters(ignore=["waveform_sampler"])
         self.init_logging(verbose)
         self.waveform_sampler = waveform_sampler
+        self.max_num_workers = max_num_workers

         # generate our local node data directory
         # if our specified data source is remote
@@ -194,7 +202,9 @@ def num_params(self):
     @property
     def num_workers(self):
         local_world_size = len(self.trainer.device_ids)
-        return min(6, int(os.cpu_count() / local_world_size))
+        return min(
+            self.max_num_workers, int(os.cpu_count() / local_world_size)
+        )

     @property
     def val_batch_size(self):
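Standalone, the new `num_workers` logic splits the node's cores across the dataloaders spawned for each local device, then caps each at the configured maximum (previously hardcoded to 6). A sketch, where `resolve_num_workers` is a hypothetical helper for illustration only:

import os

def resolve_num_workers(max_num_workers: int, local_world_size: int) -> int:
    # one dataloader is spawned per local device, so divide the
    # node's cores evenly, then cap at the configured maximum
    return min(max_num_workers, int(os.cpu_count() / local_world_size))

# with max_num_workers=6 on a 32-core node driving 4 devices:
# min(6, 32 / 4) -> 6 workers per dataloader
print(resolve_num_workers(6, 4))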
9 changes: 6 additions & 3 deletions amplfi/train/data/utils/fs.py
@@ -15,6 +15,9 @@
 # s3 retry configuration
 retry_config = {"retries": {"total_max_attempts": 10, "mode": "adaptive"}}

+# suppress botocore logs
+logging.getLogger("botocore").setLevel(logging.WARNING)
+

 def split_data_dir(data_dir: Union[str, Path]) -> Tuple[Optional[str], str]:
     """
@@ -63,19 +66,19 @@ def _download(
     """

     lockfile = target + ".lock"
-    logging.info(f"Downloading {source} to {target}")
+    logging.debug(f"Downloading {source} to {target}")
     for i in range(num_retries):
         with FileLock(lockfile):
             if os.path.exists(target):
-                logging.info(
+                logging.debug(
                     f"Object {source} already downloaded by another process"
                 )
                 return
             try:
                 s3.get(source, target)
                 break
             except (ResponseStreamingError, FSTimeoutError, ClientError):
-                logging.info(
+                logging.debug(
                     "Download attempt {} for object {} "
                     "was interrupted, retrying".format(i + 1, source)
                 )
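The `_download` pattern above generalizes beyond s3: a file lock makes the download exactly-once across worker processes, and a retry loop absorbs transient stream errors. A self-contained sketch, where `locked_download` and its `fetch` argument are hypothetical stand-ins for `s3.get`:

import logging
import os

from filelock import FileLock


def locked_download(fetch, source: str, target: str, num_retries: int = 10):
    # the lock ensures only one process performs the download; the
    # others block on it, then find the file present and return early
    lockfile = target + ".lock"
    for i in range(num_retries):
        with FileLock(lockfile):
            if os.path.exists(target):
                logging.debug("%s already fetched by another process", source)
                return
            try:
                fetch(source, target)  # e.g. s3fs.S3FileSystem.get
                return
            except OSError:  # stand-in for the s3-specific errors above
                logging.debug(
                    "attempt %d for %s was interrupted, retrying", i + 1, source
                )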
