From 1faae5934deb65a701c769474e3c857c9d2c750e Mon Sep 17 00:00:00 2001
From: Levi McClenny
Date: Mon, 1 Mar 2021 22:25:42 -0600
Subject: [PATCH] make u_weights and col_weights not interdependent

---
 tensordiffeq/fit.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/tensordiffeq/fit.py b/tensordiffeq/fit.py
index 4e5ce68..3e3c98c 100644
--- a/tensordiffeq/fit.py
+++ b/tensordiffeq/fit.py
@@ -92,12 +92,22 @@ def apply_grads(n_batches, obj=obj):
         for _ in range(n_batches):
             # unstack = tf.unstack(obj.u_model.trainable_variables, axis = 2)
             obj.variables = obj.u_model.trainable_variables
-            if obj.isAdaptive:
+            if obj.isAdaptive and obj.u_weights is not None and obj.col_weights is not None:
                 obj.variables.extend([obj.u_weights, obj.col_weights])
                 loss_value, grads = obj.grad()
                 obj.tf_optimizer.apply_gradients(zip(grads[:-2], obj.u_model.trainable_variables))
                 obj.tf_optimizer_weights.apply_gradients(
                     zip([-grads[-2], -grads[-1]], [obj.u_weights, obj.col_weights]))
+            elif obj.isAdaptive and obj.u_weights is None and obj.col_weights is not None:
+                obj.variables.extend([obj.col_weights])
+                loss_value, grads = obj.grad()
+                obj.tf_optimizer.apply_gradients(zip(grads[:-1], obj.u_model.trainable_variables))
+                obj.tf_optimizer_weights.apply_gradients(zip([-grads[-1]], [obj.col_weights]))
+            elif obj.isAdaptive and obj.u_weights is not None and obj.col_weights is None:
+                obj.variables.extend([obj.u_weights])
+                loss_value, grads = obj.grad()
+                obj.tf_optimizer.apply_gradients(zip(grads[:-1], obj.u_model.trainable_variables))
+                obj.tf_optimizer_weights.apply_gradients(zip([-grads[-1]], [obj.u_weights]))
             else:
                 loss_value, grads = obj.grad()
                 obj.tf_optimizer.apply_gradients(zip(grads, obj.u_model.trainable_variables))