make u_weights and col_weights not interdependent
levimcclenny committed Mar 2, 2021
1 parent 06a82a5 commit 1faae59
Showing 1 changed file with 11 additions and 1 deletion.
tensordiffeq/fit.py (12 changes: 11 additions & 1 deletion)
@@ -92,12 +92,22 @@ def apply_grads(n_batches, obj=obj):
     for _ in range(n_batches):
         # unstack = tf.unstack(obj.u_model.trainable_variables, axis = 2)
         obj.variables = obj.u_model.trainable_variables
-        if obj.isAdaptive:
+        if obj.isAdaptive and obj.u_weights is not None and obj.col_weights is not None:
             obj.variables.extend([obj.u_weights, obj.col_weights])
             loss_value, grads = obj.grad()
             obj.tf_optimizer.apply_gradients(zip(grads[:-2], obj.u_model.trainable_variables))
             obj.tf_optimizer_weights.apply_gradients(
                 zip([-grads[-2], -grads[-1]], [obj.u_weights, obj.col_weights]))
+        elif obj.isAdaptive and obj.u_weights is None and obj.col_weights is not None:
+            obj.variables.extend([obj.col_weights])
+            loss_value, grads = obj.grad()
+            obj.tf_optimizer.apply_gradients(zip(grads[:-1], obj.u_model.trainable_variables))
+            obj.tf_optimizer_weights.apply_gradients(zip([-grads[-1]], [obj.col_weights]))
+        elif obj.isAdaptive and obj.u_weights is not None and obj.col_weights is None:
+            obj.variables.extend([obj.u_weights])
+            loss_value, grads = obj.grad()
+            obj.tf_optimizer.apply_gradients(zip(grads[:-1], obj.u_model.trainable_variables))
+            obj.tf_optimizer_weights.apply_gradients(zip([-grads[-1]], [obj.u_weights]))
         else:
             loss_value, grads = obj.grad()
             obj.tf_optimizer.apply_gradients(zip(grads, obj.u_model.trainable_variables))
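Note on the gradient split in each adaptive branch: all but the trailing gradients update the network parameters, while the trailing gradients are negated before being applied to the self-adaptive weights, so the weights perform gradient ascent on the loss. Below is a minimal standalone sketch of that pattern for the col_weights-only case, written in plain TensorFlow; the names (model, col_weights, opt_model, opt_weights, residual) are illustrative assumptions and are not the tensordiffeq API.

import tensorflow as tf

# Small stand-in network and one self-adaptive weight per collocation point (assumed shapes).
model = tf.keras.Sequential([tf.keras.layers.Dense(20, activation="tanh"),
                             tf.keras.layers.Dense(1)])
col_weights = tf.Variable(tf.ones([100, 1]))

opt_model = tf.keras.optimizers.Adam(1e-3)    # descends the loss w.r.t. network parameters
opt_weights = tf.keras.optimizers.Adam(1e-3)  # "ascends" the loss w.r.t. the adaptive weights

x = tf.random.uniform([100, 1])

with tf.GradientTape() as tape:
    residual = model(x)                                  # stand-in for the PDE residual
    loss = tf.reduce_mean(tf.square(col_weights * residual))

variables = model.trainable_variables + [col_weights]
grads = tape.gradient(loss, variables)

# All but the last gradient go to the network; the negated last gradient goes to the weights.
opt_model.apply_gradients(zip(grads[:-1], model.trainable_variables))
opt_weights.apply_gradients(zip([-grads[-1]], [col_weights]))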
