diff --git a/base/base_model.py b/base/base_model.py index 94fcab8..2e8623c 100644 --- a/base/base_model.py +++ b/base/base_model.py @@ -22,6 +22,8 @@ def load(self, sess): print("Loading model checkpoint {} ...\n".format(latest_checkpoint)) self.saver.restore(sess, latest_checkpoint) print("Model loaded") + else: + print("NO model loaded. Training from beginning") # just initialize a tensorflow variable to use it as epoch counter def init_cur_epoch(self): diff --git a/base/base_train.py b/base/base_train.py index 8627af1..bde135d 100644 --- a/base/base_train.py +++ b/base/base_train.py @@ -2,7 +2,7 @@ class BaseTrain: - def __init__(self, sess, model, data, config, logger): + def __init__(self, sess, model, data, config, logger, load=False): self.model = model self.logger = logger self.config = config @@ -11,6 +11,10 @@ def __init__(self, sess, model, data, config, logger): self.init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) self.sess.run(self.init) + # Load the model after initialization to ensure that the loaded values are kept + if load: + self.model.load(self.sess) + def train(self): for cur_epoch in range(self.model.cur_epoch_tensor.eval(self.sess), self.config.num_epochs + 1, 1): self.train_epoch() diff --git a/mains/example.py b/mains/example.py index 7b3ef92..009244f 100644 --- a/mains/example.py +++ b/mains/example.py @@ -26,14 +26,12 @@ def main(): sess = tf.Session() # create an instance of the model you want model = ExampleModel(config) - #load model if exists - model.load(sess) # create your data generator data = DataGenerator(config) # create tensorboard logger logger = Logger(sess, config) # create trainer and pass all the previous components to it - trainer = ExampleTrainer(sess, model, data, config, logger) + trainer = ExampleTrainer(sess, model, data, config, logger, load=True) # here you train your model trainer.train() diff --git a/trainers/example_trainer.py b/trainers/example_trainer.py index eac343c..b54ef42 100644 --- a/trainers/example_trainer.py +++ b/trainers/example_trainer.py @@ -4,8 +4,8 @@ class ExampleTrainer(BaseTrain): - def __init__(self, sess, model, data, config,logger): - super(ExampleTrainer, self).__init__(sess, model, data, config,logger) + def __init__(self, sess, model, data, config,logger, **kwargs): + super(ExampleTrainer, self).__init__(sess, model, data, config,logger, **kwargs) def train_epoch(self): loop = tqdm(range(self.config.num_iter_per_epoch))