From 68b69ddbd5876151d9d25ed845c22737ba609d1b Mon Sep 17 00:00:00 2001 From: Dongsun Yoo Date: Thu, 11 Oct 2018 13:27:37 +0900 Subject: [PATCH] Add new option `echeck`, `fcheck`, and `stddev` --- simple_nn/_version.py | 2 +- simple_nn/models/neural_network.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/simple_nn/_version.py b/simple_nn/_version.py index 3dd3d2d..3d18726 100644 --- a/simple_nn/_version.py +++ b/simple_nn/_version.py @@ -1 +1 @@ -__version__ = "0.4.6" +__version__ = "0.5.0" diff --git a/simple_nn/models/neural_network.py b/simple_nn/models/neural_network.py index 4518bab..d94becf 100644 --- a/simple_nn/models/neural_network.py +++ b/simple_nn/models/neural_network.py @@ -53,6 +53,9 @@ def __init__(self): 'intra_op_parallelism_threads': 0, 'print_structure_rmse': False, 'cache': False, + 'stddev': 0.3, + 'echeck': True, + 'fcheck': True, } } self.inputs = dict() @@ -93,10 +96,11 @@ def _make_model(self): else: dtype = tf.float32 + # TODO: input validation for stddev. dense_basic_setting = { 'dtype': dtype, - 'kernel_initializer': tf.initializers.truncated_normal(stddev=0.3, dtype=dtype), - 'bias_initializer': tf.initializers.truncated_normal(stddev=0.3, dtype=dtype) + 'kernel_initializer': tf.initializers.truncated_normal(stddev=self.inputs['stddev'], dtype=dtype), + 'bias_initializer': tf.initializers.truncated_normal(stddev=self.inputs['stddev'], dtype=dtype) } dense_last_setting = copy.deepcopy(dense_basic_setting) @@ -637,8 +641,9 @@ def train(self, user_optimizer=None, user_atomic_weights_function=None): # Temp saving #if (epoch+1) % self.inputs['save_interval'] == 0: - if save_stack > self.inputs['save_interval'] and prev_eloss > eloss and \ - ((prev_floss > floss) or floss == 0.): + if save_stack > self.inputs['save_interval'] and \ + (prev_eloss > eloss or not self.inputs['echeck']) and \ + (prev_floss > floss or not self.inputs['fcheck'] or floss == 0.): temp_time = timeit.default_timer() self._save(sess, saver) prev_eloss = eloss