# KANtf.py
# TensorFlow/Keras implementation of KAN layers (KANLinear, KAN) and the
# B-spline helper functions they rely on.
import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import plot_model
import numpy as np
class KANLinear(Layer):
    """Dense KAN-style layer: a linear map plus a learnable B-spline term.

    For inputs of shape ``(..., in_features)`` the layer returns
    ``activation(inputs @ base_weight + spline(inputs))`` with shape
    ``(..., out_features)``. ``spline`` sums, over every input feature, a
    weighted combination of B-spline basis functions evaluated on that feature.

    Args:
        in_features: Size of the last input dimension.
        out_features: Size of the last output dimension.
        grid_size: Number of equal-width intervals spanning ``grid_range``.
        spline_order: B-spline order ``k`` (polynomial degree is ``k - 1``).
        activation: Name of a function in ``tf.nn`` applied to the summed
            output, e.g. ``'silu'``.
        regularization_factor: L2 coefficient applied to both weight tensors.
        grid_range: ``(low, high)`` interval covered by the spline grid.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.

    Raises:
        ValueError: If ``grid_size`` or ``spline_order`` is smaller than 1, or
            ``grid_range`` is not an increasing pair.
    """

    def __init__(self, in_features, out_features, grid_size=5, spline_order=3,
                 activation='silu', regularization_factor=0.01,
                 grid_range=(-1, 1), **kwargs):
        super().__init__(**kwargs)
        if grid_size < 1:
            raise ValueError("grid_size must be at least 1")
        if spline_order < 1:
            raise ValueError("spline_order must be at least 1")
        if not grid_range[0] < grid_range[1]:
            raise ValueError("grid_range must be (low, high) with low < high")

        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order
        # Keep the raw constructor arguments so get_config() can round-trip
        # them without introspecting the resolved function / regularizer.
        self.activation_name = activation
        self.activation_func = getattr(tf.nn, activation)
        self.regularization_factor = regularization_factor
        self.regularizer = l2(regularization_factor)
        self.grid_range = tuple(grid_range)

        # `name` is passed by keyword: Keras 2 takes it as the first positional
        # argument of add_weight, Keras 3 takes `shape` there instead.
        self.base_weight = self.add_weight(
            name="base_weight",
            shape=(in_features, out_features),
            initializer='glorot_uniform',
            regularizer=self.regularizer,
            trainable=True)
        # One coefficient per (input feature, output feature, basis function);
        # a grid of G intervals and order k yields G + k - 1 basis functions.
        self.spline_weight = self.add_weight(
            name="spline_weight",
            shape=(in_features, out_features, grid_size + spline_order - 1),
            initializer='glorot_uniform',
            regularizer=self.regularizer,
            trainable=True)
        self.build_grid()

    def build_grid(self):
        """Create the uniform knot vector shared by all basis functions.

        The knots cover ``grid_range`` with ``grid_size`` equal intervals and
        continue with ``spline_order - 1`` extra intervals on each side, so
        the bases sum to one everywhere inside ``grid_range``. The result is
        stored in ``self.grid`` with shape
        ``(in_features, grid_size + 2 * spline_order - 1)``.
        """
        low, high = self.grid_range
        step = (high - low) / self.grid_size
        knot_index = np.arange(-(self.spline_order - 1),
                               self.grid_size + self.spline_order,
                               dtype=np.float64)
        knots = (low + step * knot_index).astype(np.float32)
        initial_grid = np.tile(knots, (self.in_features, 1))
        # Knots are kept fixed: gradient updates could reorder them, and the
        # B-spline recursion is only valid for non-decreasing knots.
        self.grid = self.add_weight(
            name="grid",
            shape=initial_grid.shape,
            initializer=tf.constant_initializer(initial_grid),
            trainable=False)

    def call(self, inputs):
        """Return ``activation(inputs @ base_weight + spline(inputs))``."""
        base_output = tf.matmul(inputs, self.base_weight)
        spline_output = self.compute_spline_output(inputs)
        return self.activation_func(base_output + spline_output)

    def compute_spline_output(self, inputs):
        """Return the spline contribution, shape ``(..., out_features)``."""
        bases = self._b_spline_bases(inputs)
        # Contract over input features (i) and basis functions (k) so the
        # result lines up with the (..., out_features) base output.
        return tf.einsum('...ik,ijk->...j', bases, self.spline_weight)

    def _b_spline_bases(self, inputs):
        """Evaluate every basis function at every input value.

        Uses the Cox-de Boor recursion on ``self.grid``. Returns a tensor of
        shape ``(..., in_features, grid_size + spline_order - 1)``.
        """
        x = tf.expand_dims(inputs, -1)           # (..., in_features, 1)
        knots = tf.cast(self.grid, x.dtype)      # (in_features, num_knots)
        # Order-1 (piecewise constant) bases: indicator of each knot interval.
        bases = tf.cast(
            tf.logical_and(x >= knots[:, :-1], x < knots[:, 1:]), x.dtype)
        for d in range(1, self.spline_order):
            # divide_no_nan maps 0/0 to 0, the usual convention for
            # zero-width knot spans.
            left = tf.math.divide_no_nan(
                x - knots[:, :-(d + 1)],
                knots[:, d:-1] - knots[:, :-(d + 1)])
            right = tf.math.divide_no_nan(
                knots[:, d + 1:] - x,
                knots[:, d + 1:] - knots[:, 1:-d])
            bases = left * bases[..., :-1] + right * bases[..., 1:]
        return bases

    def get_config(self):
        """Return the constructor arguments as a serializable dict."""
        config = super().get_config()
        config.update({
            'in_features': self.in_features,
            'out_features': self.out_features,
            'grid_size': self.grid_size,
            'spline_order': self.spline_order,
            'activation': self.activation_name,
            'regularization_factor': float(self.regularization_factor),
            'grid_range': self.grid_range,
        })
        return config
def extend_grid_tf(grid, k):
    """Extend a 1-D knot vector by ``k`` reflected knots on each side.

    The new knots mirror the first/last ``k`` interior knots about the
    respective end point, so a strictly increasing grid stays strictly
    increasing (for a uniform grid this simply continues the spacing).

    Args:
        grid: 1-D tensor (or array-like) of knots with more than ``k`` entries.
        k: Number of knots to add on each side; must be non-negative.

    Returns:
        1-D tensor of length ``len(grid) + 2 * k``.

    Raises:
        ValueError: If ``grid`` is not one-dimensional, ``k`` is negative, or
            the grid is statically known to have ``k`` or fewer points.
    """
    grid = tf.convert_to_tensor(grid)
    # Static rank check: works in eager and graph mode alike, unlike
    # branching on the tensor returned by tf.rank().
    if grid.shape.rank != 1:
        raise ValueError("Grid tensor must be one-dimensional")
    if k < 0:
        raise ValueError("k must be non-negative")
    if k == 0:
        return grid
    num_points = grid.shape[0]
    if num_points is not None and num_points <= k:
        raise ValueError("Grid must contain more than k points to be extended")
    # Reflect g[1..k] about g[0] and g[-k-1..-2] about g[-1]; reversing keeps
    # the extended knots in increasing order.
    left = 2 * grid[0] - tf.reverse(grid[1:k + 1], axis=[0])
    right = 2 * grid[-1] - tf.reverse(grid[-k - 1:-1], axis=[0])
    return tf.concat([left, grid, right], axis=0)
def B_batch_tf(x, grid, k=3, extend=True):
    """Evaluate B-spline basis functions with the Cox-de Boor recursion.

    Args:
        x: Floating-point tensor of input values whose last dimension has
            size 1, e.g. shape ``(num_samples, 1)``.
        grid: 1-D tensor of knots in non-decreasing order, shape
            ``(num_grid_points,)``.
        k: Order of the B-spline (polynomial degree is ``k - 1``).
        extend: If True, pad ``grid`` with ``k`` extra knots on each side via
            ``extend_grid_tf`` before evaluating.

    Returns:
        Tensor with the dtype of ``x`` and shape
        ``(num_samples, num_grid_points + k)`` when ``extend`` is True, or
        ``(num_samples, num_grid_points - k)`` otherwise.
    """
    x = tf.convert_to_tensor(x)
    grid = tf.cast(grid, x.dtype)
    if extend:
        grid = extend_grid_tf(grid, k)

    # Order-1 bases: indicator of each half-open knot interval. x keeps its
    # trailing size-1 axis so it broadcasts against every knot slice below.
    bases = tf.cast(tf.logical_and(x >= grid[:-1], x < grid[1:]), x.dtype)

    # Each pass raises the degree by one and drops one basis function.
    for d in range(1, k):
        # divide_no_nan maps 0/0 to 0, the usual convention for repeated
        # knots, instead of propagating NaN.
        left_term = tf.math.divide_no_nan(
            x - grid[:-d - 1], grid[d:-1] - grid[:-d - 1])
        right_term = tf.math.divide_no_nan(
            grid[d + 1:] - x, grid[d + 1:] - grid[1:-d])
        bases = left_term * bases[..., :-1] + right_term * bases[..., 1:]
    return bases
class KAN(tf.keras.models.Sequential):
    """Sequential stack of ``KANLinear`` layers.

    Args:
        layers_configurations: Iterable of dicts, one per layer, holding the
            keyword arguments for ``KANLinear`` (at least ``in_features`` and
            ``out_features``).
        **kwargs: Default ``KANLinear`` keyword arguments shared by every
            layer. A key that also appears in a layer's own configuration is
            overridden by that configuration.
    """

    def __init__(self, layers_configurations, **kwargs):
        super().__init__()
        for layer_config in layers_configurations:
            # Merge instead of double-splatting: passing the same key through
            # both dicts would raise "got multiple values for keyword".
            layer_kwargs = {**kwargs, **layer_config}
            self.add(KANLinear(**layer_kwargs))
def get_activations(model, model_inputs, layer_name=None):
    """Run ``model_inputs`` through ``model`` and return layer outputs.

    Args:
        model: A built Keras model exposing ``layers`` and ``input``.
        model_inputs: Input data accepted by ``Model.predict``.
        layer_name: Name of the layer whose output is wanted; ``None`` selects
            every layer.

    Returns:
        The result of ``predict`` on a model whose outputs are the selected
        layers' outputs.

    Raises:
        ValueError: If no layer matches ``layer_name``.
    """
    layer_outputs = [
        layer.output
        for layer in model.layers
        if layer_name is None or layer.name == layer_name
    ]
    if not layer_outputs:
        raise ValueError(f"No layer named {layer_name!r} found in the model")
    # `Model` is not imported at module level; reach it through `tf`.
    activation_model = tf.keras.Model(inputs=model.input, outputs=layer_outputs)
    return activation_model.predict(model_inputs)