Commit

aborted state reset implementation

Joceran committed Apr 8, 2022
1 parent 14c06e5 commit 1c3190e
Showing 2 changed files with 80 additions and 46 deletions.
4 changes: 2 additions & 2 deletions experiments/thetagpu/jobs/minimalistic-frnn-test.sh
@@ -1,7 +1,7 @@
#!/bin/bash
#COBALT -n 1
#COBALT -t 20
#COBALT -q single-gpu
#COBALT -t 140
#COBALT -q full-node
#COBALT -A datascience

source ../../../build/activate-dhenv.sh
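
For reference, the #COBALT directives are scheduler requests: -n sets the node count, -t the walltime in minutes, -q the target queue, and -A the project allocation; the edited job therefore asks for 140 minutes on a full ThetaGPU node instead of 20 minutes in the single-gpu queue.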
122 changes: 78 additions & 44 deletions src/scalbo/scalbo/benchmark/minimalistic_frnn.py
@@ -106,6 +106,24 @@ def on_batch_end(self, batch, logs=None):
self.model.reset_states()


class LayerReset(tf.keras.layers.Layer):
def __init__(self, layer):
super(LayerReset, self).__init__()
self._layer = layer

def call(self, batches_to_reset):
def reset_states(layer, batches_to_reset):
for j, reset in enumerate(tf.unstack(batches_to_reset)):
if reset != 0:
for state in layer.states:
assert len(batches_to_reset) == state.shape[0]
batch_states = tf.keras.backend.get_value(state)
batch_states[j] = 0
tf.keras.backend.set_value(state, batch_states)
# reset_states(self._layer, batches_to_reset)
return batches_to_reset
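
As committed, LayerReset.call simply passes batches_to_reset through unchanged, because the inner reset_states helper is never invoked (its call is commented out). The sketch below exercises the per-sample reset it encodes, run eagerly against a standalone stateful layer; the layer, shapes, and flag values are illustrative assumptions, not code from this repository.

import tensorflow as tf

# Hypothetical stateful recurrent layer with a fixed batch size of 4.
lstm = tf.keras.layers.LSTM(8, stateful=True, return_sequences=True,
                            batch_input_shape=(4, 16, 3))
lstm(tf.zeros((4, 16, 3)))  # run once so the layer builds its states

# Zero the carried state of selected samples only, mirroring the
# commented-out reset_states helper above.
batches_to_reset = tf.constant([1, 0, 0, 1])
for j, reset in enumerate(batches_to_reset.numpy()):
    if reset:
        for state in lstm.states:
            value = tf.keras.backend.get_value(state)
            value[j] = 0.0
            tf.keras.backend.set_value(state, value)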


class ModelBuilder(object):
def __init__(self, conf):
self.conf = conf
@@ -217,53 +235,34 @@ def build_model(self):
pre_rnn_model = tf.keras.Model(inputs=pre_rnn_input, outputs=pre_rnn)
pre_rnn_model.summary()

#

x_input = Input(batch_shape=batch_input_shape)
if num_1D > 0 or model_conf.get('extra_dense_input', False):
x_in = TimeDistributed(pre_rnn_model)(x_input)
else:
x_in = x_input

# ==========
# TCN MODEL
# RNN MODEL
# ==========
if model_conf.get('keras_tcn', False):
tcn_layers = model_conf['tcn_layers']
tcn_dropout = model_conf['tcn_dropout']
nb_filters = model_conf['tcn_hidden']
kernel_size = model_conf['kernel_size_temporal']
nb_stacks = model_conf['tcn_nbstacks']
use_skip_connections = model_conf['tcn_skip_connect']
activation = model_conf['tcn_activation']
use_batch_norm = model_conf['tcn_batch_norm']
# for _ in range(model_conf['tcn_pack_layers']):
# x_in = TCN(
# use_batch_norm=use_batch_norm, activation=activation,
# use_skip_connections=use_skip_connections,
# nb_stacks=nb_stacks, kernel_size=kernel_size,
# nb_filters=nb_filters, num_layers=tcn_layers,
# dropout_rate=tcn_dropout)(x_in)
# x_in = Dropout(dropout_prob)(x_in)
else:
# ==========
# RNN MODEL
# ==========
model_kwargs = dict(return_sequences=return_sequences,
# batch_input_shape=batch_input_shape,
stateful=stateful,
kernel_regularizer=l2(regularization),
recurrent_regularizer=l2(regularization),
bias_regularizer=l2(regularization),
)
if rnn_type != 'CuDNNLSTM':
# recurrent_dropout is unsupported in cuDNN library
model_kwargs['recurrent_dropout'] = dropout_prob # recurrent states
model_kwargs['dropout'] = dropout_prob # input states
for _ in range(model_conf['rnn_layers']):
x_in = rnn_model(rnn_size, **model_kwargs)(x_in)
x_in = Dropout(dropout_prob)(x_in)
if return_sequences:
x_out = TimeDistributed(Dense(1, activation=output_activation))(x_in)
model_kwargs = dict(return_sequences=return_sequences,
# batch_input_shape=batch_input_shape,
stateful=stateful,
kernel_regularizer=l2(regularization),
recurrent_regularizer=l2(regularization),
bias_regularizer=l2(regularization),
)
if rnn_type != 'CuDNNLSTM':
# recurrent_dropout is unsupported in cuDNN library
model_kwargs['recurrent_dropout'] = dropout_prob # recurrent states
model_kwargs['dropout'] = dropout_prob # input states
for _ in range(model_conf['rnn_layers']):
rnn_layer = rnn_model(rnn_size, **model_kwargs)
x_in = rnn_layer(x_in)
x_in = Dropout(dropout_prob)(x_in)
if return_sequences:
x_out = TimeDistributed(Dense(1, activation=output_activation))(x_in)

model = tf.keras.Model(inputs=x_input, outputs=x_out)
# bug with tensorflow/Keras
# TODO(KGF): what is this bug? this is the only direct "tensorflow"
@@ -278,6 +277,38 @@ def build_model(self):
model.reset_states()
return model
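
A note on the stateful path: with stateful=True, Keras carries hidden state across successive batches instead of clearing it per batch, which is why build_model fixes the batch dimension via Input(batch_shape=...) and ends with model.reset_states(). A toy illustration of the fixed-batch constraint, with made-up shapes and not taken from this module:

import tensorflow as tf

inp = tf.keras.Input(batch_shape=(4, 16, 3))
out = tf.keras.layers.LSTM(8, stateful=True, return_sequences=True)(inp)
toy = tf.keras.Model(inp, out)
toy.predict(tf.zeros((4, 16, 3)))    # OK: exactly 4 sequences per batch
# toy.predict(tf.zeros((2, 16, 3)))  # fails: stateful layers pin the batch size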

def build_model_zeros(self):
conf = self.conf
model_conf = conf['model']
length = model_conf['length']
use_signals = conf['paths']['use_signals']
num_signals = sum([sig.num_channels for sig in use_signals])
batch_size = conf['training']['batch_size']
batch_input_shape = (batch_size, length, num_signals)
batch_output_shape = (batch_size, length, 1)
x_input = Input(batch_shape=batch_input_shape)

        output = Lambda(lambda _: tf.zeros(batch_output_shape))(x_input)

model = tf.keras.Model(inputs=x_input, outputs=output)
return model

def build_model_ones(self):
conf = self.conf
model_conf = conf['model']
length = model_conf['length']
use_signals = conf['paths']['use_signals']
num_signals = sum([sig.num_channels for sig in use_signals])
batch_size = conf['training']['batch_size']
batch_input_shape = (batch_size, length, num_signals)
batch_output_shape = (batch_size, length, 1)
x_input = Input(batch_shape=batch_input_shape)

        output = Lambda(lambda _: tf.ones(batch_output_shape))(x_input)

model = tf.keras.Model(inputs=x_input, outputs=output)
return model
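
The two builders above return constant predictors that share the real model's (batch_size, length, 1) output contract, so they can stand in as trivial baselines for the same training and evaluation path. A hypothetical sanity check, assuming a conf dict and an x_batch array exist:

builder = ModelBuilder(conf)
baseline = builder.build_model_zeros()
y = baseline.predict(x_batch)  # x_batch: (batch_size, length, num_signals)
assert y.shape == (conf['training']['batch_size'], conf['model']['length'], 1)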

def build_optimizer(self):
conf = self.conf
lr = conf['model']["lr"]
@@ -355,6 +386,10 @@ def load_dataset(self, shot_list, loader):
batch_generator,
output_signature=(
tf.TensorSpec(shape=(batch_size, length, num_signals), dtype=tf.float32),
# (
# tf.TensorSpec(shape=(batch_size, length, num_signals), dtype=tf.float32),
# tf.TensorSpec(shape=(batch_size, 1), dtype=tf.int32),
# ),
tf.TensorSpec(shape=(batch_size, length, 1), dtype=tf.float32),
)
)
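
The commented-out alternative signature appears to anticipate batches that also carry per-batch reset flags alongside the inputs. For the active signature, the generator must yield (x, y) pairs with fixed batch shapes; the stand-in below uses random placeholder data rather than the FRNN loader, and assumes batch_size, length, and num_signals are in scope as in the surrounding method:

import numpy as np

def batch_generator():
    # Placeholder batches with the fixed shapes the signature promises.
    while True:
        x = np.random.randn(batch_size, length, num_signals).astype('float32')
        y = np.zeros((batch_size, length, 1), dtype='float32')
        yield x, y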
@@ -528,11 +563,10 @@ def run(config: None):
steps_per_epoch=steps_per_epoch,
validation_data=valid_dataset,
validation_steps=validation_steps,
# callbacks=ResetStatesCallback(),
)
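
The commented-out callbacks line above points at the whole-model reset path: ResetStatesCallback, defined near the top of this file, calls model.reset_states() after every batch. Keras documents the callbacks argument as a list of Callback instances, so re-enabling it would look like the sketch below, where train_dataset and epochs are assumed from the elided part of this call:

history = model.fit(
    train_dataset,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    validation_data=valid_dataset,
    validation_steps=validation_steps,
    callbacks=[ResetStatesCallback()],  # a list, not a bare callback
)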

# with open('/lus/grand/projects/datascience/jgouneau/deephyper/frnn/exp/outputs/history.json', 'w') as file:
# json.dump(history.history, file)
with open('/lus/grand/projects/datascience/jgouneau/deephyper/frnn/exp/outputs/stateless_model.json', 'w') as file:
json.dump(history.history, file)

# evaluate it
evaluator = ModelEvaluator(model, loader, conf)
@@ -549,7 +583,7 @@ def run(config: None):
print('Test Loss: {:.3e}'.format(test_loss))
print('Test ROC: {:.4f}'.format(test_roc))

return valid_roc
# return valid_roc

if __name__ == '__main__':
conf = {
