Skip to content

Commit

Permalink
fixes in tensor generator (#327)
Browse files: browse the repository at this point in the history
* #260 #324 #326

* get stats q

* variable names
  • Loading branch information
StevenSong authored Jun 18, 2020
1 parent 9754136 commit f747f7c
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions ml4cvd/tensor_generators.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -90,12 +90,12 @@ def __init__(
self._started = False
self.workers = []
self.worker_instances = []
if num_workers == 0:
num_workers = 1 # The one worker is the main thread
self.batch_size, self.input_maps, self.output_maps, self.num_workers, self.cache_size, self.weights, self.name, self.keep_paths = \
batch_size, input_maps, output_maps, num_workers, cache_size, weights, name, keep_paths
self.true_epochs = 0
self.stats_string = ""
if num_workers == 0:
num_workers = 1 # The one worker is the main thread
if weights is None:
worker_paths = np.array_split(paths, num_workers)
self.true_epoch_lens = list(map(len, worker_paths))
Expand Down Expand Up @@ -148,7 +148,7 @@ def _init_workers(self):
)
process.start()
self.workers.append(process)
logging.info(f"Started {i} {self.name.replace('_', ' ')}s with cache size {self.cache_size/1e9}GB.")
logging.info(f"Started {i + 1} {self.name.replace('_', ' ')}s with cache size {self.cache_size/1e9}GB.")

def set_worker_paths(self, paths: List[Path]):
"""In the single worker case, set the worker's paths."""
Expand All @@ -161,11 +161,11 @@ def set_worker_paths(self, paths: List[Path]):
def __next__(self) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], Optional[List[str]]]:
if not self._started:
self._init_workers()
if self.stats_q.qsize() == self.num_workers:
self.aggregate_and_print_stats()
if self.run_on_main_thread:
return next(self.worker_instances[0])
else:
if self.stats_q.qsize() == self.num_workers:
self.aggregate_and_print_stats()
return self.q.get(TENSOR_GENERATOR_TIMEOUT)

def aggregate_and_print_stats(self):
Expand Down Expand Up @@ -227,7 +227,7 @@ def aggregate_and_print_stats(self):
f"{stats['Tensors presented']:0.0f} tensors were presented.",
f"{stats['skipped_paths']} paths were skipped because they previously failed.",
f"{error_info}",
f"{self.stats_string}"
f"{self.stats_string}",
])
logging.info(f"\n!!!!>~~~~~~~~~~~~ {self.name} completed true epoch {self.true_epochs} ~~~~~~~~~~~~<!!!!\nAggregated information string:\n\t{info_string}")

Expand Down Expand Up @@ -725,6 +725,8 @@ def test_train_valid_tensor_generators(
tensors: str,
batch_size: int,
num_workers: int,
training_steps: int,
validation_steps: int,
cache_size: float,
balance_csvs: List[str],
keep_paths: bool = False,
Expand Down Expand Up @@ -786,8 +788,11 @@ def test_train_valid_tensor_generators(
test_csv=test_csv,
)
weights = None
generate_train = TensorGenerator(batch_size, tensor_maps_in, tensor_maps_out, train_paths, num_workers, cache_size, weights, keep_paths, mixup_alpha, name='train_worker', siamese=siamese, augment=True, sample_weight=sample_weight)
generate_valid = TensorGenerator(batch_size, tensor_maps_in, tensor_maps_out, valid_paths, num_workers // 2, cache_size, weights, keep_paths, name='validation_worker', siamese=siamese, augment=False)

num_train_workers = int(training_steps / (training_steps + validation_steps) * num_workers) or (1 if num_workers else 0)
num_valid_workers = int(validation_steps / (training_steps + validation_steps) * num_workers) or (1 if num_workers else 0)
generate_train = TensorGenerator(batch_size, tensor_maps_in, tensor_maps_out, train_paths, num_train_workers, cache_size, weights, keep_paths, mixup_alpha, name='train_worker', siamese=siamese, augment=True, sample_weight=sample_weight)
generate_valid = TensorGenerator(batch_size, tensor_maps_in, tensor_maps_out, valid_paths, num_valid_workers, cache_size, weights, keep_paths, name='validation_worker', siamese=siamese, augment=False)
generate_test = TensorGenerator(batch_size, tensor_maps_in, tensor_maps_out, test_paths, num_workers, 0, weights, keep_paths or keep_paths_test, name='test_worker', siamese=siamese, augment=False)
return generate_train, generate_valid, generate_test

Expand Down

0 comments on commit f747f7c

Please sign in to comment.