Skip to content

Commit 5e05677

Browse files
committed
update models
1 parent 94cfc92 commit 5e05677

File tree

1 file changed

+187
-0
lines changed

1 file changed

+187
-0
lines changed

tensordiffeq/models.py

+187
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
import tensorflow as tf
2+
import numpy as np
3+
import time
4+
from .utils import *
5+
from .networks import *
6+
from .plotting import *
7+
from .fit import *
8+
9+
10+
class CollocationSolverND:
11+
def __init__(self, assimilate=False):
12+
self.assimilate = assimilate
13+
14+
def compile(self, layer_sizes, f_model, domain, bcs, isAdaptive=False,
15+
col_weights=None, u_weights=None, g=None, dist=False):
16+
self.layer_sizes = layer_sizes
17+
self.sizes_w, self.sizes_b = get_sizes(layer_sizes)
18+
self.bcs = bcs
19+
self.f_model = get_tf_model(f_model)
20+
self.isAdaptive = False
21+
self.g = g
22+
self.domain = domain
23+
self.dist = dist
24+
self.col_weights = col_weights
25+
self.u_weights = u_weights
26+
self.X_f_dims = tf.shape(self.domain.X_f)
27+
self.X_f_len = tf.slice(self.X_f_dims, [0], [1]).numpy()
28+
tmp = []
29+
for i, vec in enumerate(self.domain.X_f.T):
30+
tmp.append(np.reshape(vec, (-1,1)))
31+
self.X_f_in = np.asarray(tmp)
32+
33+
34+
35+
if isAdaptive:
36+
self.isAdaptive = True
37+
if self.col_weights is None and self.u_weights is None:
38+
raise Exception("Adaptive weights selected but no inputs were specified!")
39+
if not isAdaptive:
40+
if self.col_weights is not None and self.u_weights is not None:
41+
raise Exception(
42+
"Adaptive weights are turned off but weight vectors were provided. Set the weight vectors to "
43+
"\"none\" to continue")
44+
45+
def compile_data(self, x, t, y):
46+
if not self.assimilate:
47+
raise Exception(
48+
"Assimilate needs to be set to 'true' for data assimilation. Re-initialize CollocationSolver1D with "
49+
"assimilate=True.")
50+
self.data_x = x
51+
self.data_t = t
52+
self.data_s = y
53+
54+
def update_loss(self):
55+
loss_tmp = 0.0
56+
# Periodic BC iteration for all components of deriv_model
57+
for bc in self.bcs:
58+
if bc.isPeriodic:
59+
for i, dim in enumerate(bc.var):
60+
for j, lst in enumerate(dim):
61+
for k, tup in enumerate(lst):
62+
upper = bc.u_x_model(self.u_model, bc.upper[i])[j][k]
63+
lower = bc.u_x_model(self.u_model, bc.lower[i])[j][k]
64+
msq = MSE(upper, lower)
65+
loss_tmp = tf.math.add(loss_tmp, msq)
66+
# initial BCs, including adaptive model
67+
if bc.isInit:
68+
if self.isAdaptive:
69+
loss_tmp = tf.math.add(loss_tmp, MSE(self.u_model(bc.input), bc.val, self.u_weights))
70+
else:
71+
loss_tmp = tf.math.add(loss_tmp, MSE(self.u_model(bc.input), bc.val))
72+
# Dirichlect BC, will need to add more cases for Neumann BC, etc as more
73+
# BC types are added
74+
# This is true unless the BC loss can be evaluated using the MSE function explicitly
75+
else:
76+
loss_tmp = tf.math.add(loss_tmp, MSE(self.u_model(bc.input), bc.val))
77+
78+
f_u_pred = self.f_model(self.u_model, *self.X_f_in)
79+
80+
if self.isAdaptive:
81+
mse_f_u = MSE(f_u_pred, constant(0.0), self.col_weights)
82+
else:
83+
mse_f_u = MSE(f_u_pred, constant(0.0))
84+
85+
loss_tmp = tf.math.add(loss_tmp, mse_f_u)
86+
return loss_tmp
87+
88+
def grad(self):
89+
with tf.GradientTape() as tape:
90+
loss_value = self.update_loss()
91+
grads = tape.gradient(loss_value, self.variables)
92+
return loss_value, grads
93+
94+
def fit(self, tf_iter, newton_iter, batch_sz=None, newton_eager=True):
95+
if self.isAdaptive and (batch_sz is not None):
96+
raise Exception("Currently we dont support minibatching for adaptive PINNs")
97+
if self.dist:
98+
fit_dist(self, tf_iter=tf_iter, newton_iter=newton_iter, batch_sz=batch_sz, newton_eager=newton_eager)
99+
else:
100+
fit(self, tf_iter=tf_iter, newton_iter=newton_iter, batch_sz=batch_sz, newton_eager=newton_eager)
101+
102+
# L-BFGS implementation from https://github.com/pierremtb/PINNs-TF2.0
103+
def get_loss_and_flat_grad(self):
104+
def loss_and_flat_grad(w):
105+
with tf.GradientTape() as tape:
106+
set_weights(self.u_model, w, self.sizes_w, self.sizes_b)
107+
loss_value = self.update_loss()
108+
grad = tape.gradient(loss_value, self.u_model.trainable_variables)
109+
grad_flat = []
110+
for g in grad:
111+
grad_flat.append(tf.reshape(g, [-1]))
112+
grad_flat = tf.concat(grad_flat, 0)
113+
return loss_value, grad_flat
114+
115+
return loss_and_flat_grad
116+
117+
def predict(self, X_star):
118+
X_star = convertTensor(X_star)
119+
u_star = self.u_model(X_star)
120+
121+
f_u_star = self.f_model(self.u_model, X_star[:, 0:1],
122+
X_star[:, 1:2])
123+
124+
return u_star.numpy(), f_u_star.numpy()
125+
126+
127+
# WIP
128+
# TODO DiscoveryModel
129+
class DiscoveryModel():
130+
def compile(self, layer_sizes, f_model, X, u, vars, col_weights=None):
131+
self.layer_sizes = layer_sizes
132+
self.f_model = f_model
133+
self.X = X
134+
self.x_f = X[:, 0:1]
135+
self.t_f = X[:, 1:2]
136+
self.u = u
137+
self.vars = vars
138+
self.u_model = neural_net(self.layer_sizes)
139+
self.tf_optimizer = tf.keras.optimizers.Adam(lr=0.005, beta_1=.99)
140+
self.tf_optimizer_vars = tf.keras.optimizers.Adam(lr=0.0005, beta_1=.99)
141+
self.tf_optimizer_weights = tf.keras.optimizers.Adam(lr=0.005, beta_1=.99)
142+
self.col_weights = col_weights
143+
144+
def loss(self):
145+
u_pred = self.u_model(self.X)
146+
f_u_pred, self.vars = self.f_model(self.u_model, self.x_f, self.t_f, self.vars)
147+
148+
if self.col_weights is not None:
149+
return MSE(u_pred, self.u) + g_MSE(f_u_pred, constant(0.0), self.col_weights ** 2)
150+
else:
151+
return MSE(u_pred, self.u) + MSE(f_u_pred, constant(0.0))
152+
153+
def grad(self):
154+
with tf.GradientTape() as tape:
155+
loss_value = self.loss()
156+
grads = tape.gradient(loss_value, self.variables)
157+
return loss_value, grads
158+
159+
@tf.function
160+
def train_op(self):
161+
if self.col_weights is not None:
162+
len_ = len(self.vars)
163+
self.variables = self.u_model.trainable_variables
164+
self.variables.extend([self.col_weights])
165+
self.variables.extend(self.vars)
166+
loss_value, grads = self.grad()
167+
self.tf_optimizer.apply_gradients(zip(grads[:-(len_ + 2)], self.u_model.trainable_variables))
168+
self.tf_optimizer_weights.apply_gradients(zip([-grads[-(len_ + 1)]], [self.col_weights]))
169+
self.tf_optimizer_vars.apply_gradients(zip(grads[-len_:], self.vars))
170+
else:
171+
self.variables = self.u_model.trainable_variables
172+
loss_value, mse_0, mse_b, mse_f, grads = self.grad()
173+
self.tf_optimizer.apply_gradients(zip(grads, self.u_model.trainable_variables))
174+
175+
return loss_value
176+
177+
def train_loop(self, tf_iter):
178+
start_time = time.time()
179+
for i in range(tf_iter):
180+
loss_value = self.train_op()
181+
if i % 100 == 0:
182+
elapsed = time.time() - start_time
183+
print('It: %d, Time: %.2f' % (i, elapsed))
184+
tf.print(f"total loss: {loss_value}")
185+
var = [var.numpy() for var in self.vars]
186+
print("vars estimate(s):", var)
187+
start_time = time.time()

0 commit comments

Comments
 (0)