
Commit

resolved #308 (comment)
hvgazula committed Mar 23, 2024
1 parent ad24b26 commit 093daff
Showing 1 changed file with 9 additions and 11 deletions.
20 changes: 9 additions & 11 deletions nobrainer/processing/generation.py
@@ -5,7 +5,7 @@
 
 from .base import BaseEstimator
 from .. import losses
-from ..dataset import get_dataset
+from ..dataset import Dataset
 
 
 class ProgressiveGeneration(BaseEstimator):
@@ -136,8 +136,8 @@ def _compile():
                     d_loss_fn=d_loss,
                 )
 
-            print(self.model_.generator.summary())
-            print(self.model_.discriminator.summary())
+            self.model_.generator.summary()
+            self.model_.discriminator.summary()
 
         for resolution, info in dataset_train.items():
             if resolution < self.current_resolution_:
@@ -147,26 +147,24 @@ def _compile():
             if batch_size % self.strategy.num_replicas_in_sync:
                 raise ValueError("batch size must be a multiple of the number of GPUs")
 
-            dataset = get_dataset(
+            dataset = Dataset.from_tfrecords(
                 file_pattern=info.get("file_pattern"),
                 batch_size=batch_size,
                 num_parallel_calls=num_parallel_calls,
                 volume_shape=(resolution, resolution, resolution),
-                n_classes=1,
-                scalar_label=True,
-                normalizer=info.get("normalizer") or normalizer,
+                scalar_labels=True,
             )
 
+            if info.get("normalizer") or normalizer:
+                dataset = dataset.normalize(normalizer)
+
             with self.strategy.scope():
                 # grow the networks by one (2^x) resolution
                 if resolution > self.current_resolution_:
                     self.model_.generator.add_resolution()
                     self.model_.discriminator.add_resolution()
                 _compile()
 
-                steps_per_epoch = (info.get("epochs") or epochs) // info.get(
-                    "batch_size"
-                )
+                steps_per_epoch = dataset.get_steps_per_epoch()
 
                 # save_best_only is set to False as it is an adversarial loss
                 model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
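
For reference, a minimal sketch of how the refactored Dataset API exercised in this diff could be used on its own. It relies only on the calls visible above (Dataset.from_tfrecords, Dataset.normalize, Dataset.get_steps_per_epoch); the file pattern, volume shape, batch size, and the choice of standardize as a normalizer are illustrative assumptions, not values from the repository.

# Sketch only: mirrors the calls introduced in this commit. The file pattern,
# shapes, and normalizer below are assumptions for illustration.
from nobrainer.dataset import Dataset
from nobrainer.volume import standardize  # assumed choice of normalizer

dataset = Dataset.from_tfrecords(
    file_pattern="data/generate-*.tfrec",  # hypothetical TFRecord shards
    num_parallel_calls=1,
    volume_shape=(32, 32, 32),             # one (2^x) resolution level
    batch_size=8,                          # must divide evenly across GPUs
    scalar_labels=True,
)

normalizer = standardize
if normalizer:
    dataset = dataset.normalize(normalizer)  # reassign, as in the diff

steps_per_epoch = dataset.get_steps_per_epoch()

Compared with the old get_dataset call, normalization and steps-per-epoch bookkeeping now happen on the Dataset object itself rather than being passed in or computed by hand.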
