Directional sparsity cost functional #259

Merged · 5 commits · Feb 13, 2024
64 changes: 56 additions & 8 deletions neurolib/control/optimal_control/cost_functions.py
@@ -19,7 +19,7 @@ def accuracy_cost(x, target_timeseries, weights, cost_matrix, dt, interval=(0, N
    :param interval: (t_start, t_end). Indices of start and end point of the slice (both inclusive) in time
        dimension. Only 'int' positive index-notation allowed (i.e. no negative indices or 'None').
    :type interval: tuple, optional

    :return: Accuracy cost.
    :rtype: float
    """
@@ -56,7 +56,7 @@ def derivative_accuracy_cost(x, target_timeseries, weights, cost_matrix, interva
    :param interval: (t_start, t_end). Indices of start and end point of the slice (both inclusive) in time
        dimension. Only 'int' positive index-notation allowed (i.e. no negative indices or 'None').
    :type interval: tuple, optional

    :return: Accuracy cost derivative.
    :rtype: ndarray
    """
@@ -84,7 +84,7 @@ def precision_cost(x_sim, x_target, cost_matrix, interval=(0, None)):
    :param interval: (t_start, t_end). Indices of start and end point of the slice (both inclusive) in time
        dimension. Only 'int' positive index-notation allowed (i.e. no negative indices or 'None').
    :type interval: tuple

    :return: Precision cost for time interval.
    :rtype: float
    """
@@ -114,7 +114,7 @@ def derivative_precision_cost(x_sim, x_target, cost_matrix, interval):
    :param interval: (t_start, t_end). Indices of start and end point of the slice (both inclusive) in time
        dimension. Only 'int' positive index-notation allowed (i.e. no negative indices or 'None').
    :type interval: tuple

    :return: Control-dimensions x T array of precision cost gradients.
    :rtype: np.ndarray
    """
@@ -140,7 +140,7 @@ def control_strength_cost(u, weights, dt):
    :type weights: dictionary
    :param dt: Time step.
    :type dt: float

    :return: Control strength cost of the control.
    :rtype: float
    """
@@ -159,17 +159,22 @@ def control_strength_cost(u, weights, dt):
            for t in range(u.shape[2]):
                cost += cost_timeseries[n, v, t] * dt

    if weights["w_1D"] != 0.0:
        cost += weights["w_1D"] * L1D_cost_integral(u, dt)

    return cost
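
For orientation, the total control-strength cost assembled here can be written as below; the 1/2 factor on the quadratic term is an inference from `derivative_L2_cost` returning plain `u`, so treat it as an assumption. The new `w_1D` term is an L1 norm across nodes and input channels of an L2-in-time norm, the usual reading of "directional sparsity": it drives whole control channels to zero while leaving the surviving time courses smooth.

```math
J_{\mathrm{cs}}(u) = \frac{w_2}{2} \sum_{n,v,t} u_{n,v,t}^2 \,\Delta t \;+\; w_{1D} \sum_{n,v} \sqrt{\sum_t u_{n,v,t}^2 \,\Delta t}
```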


@numba.njit
-def derivative_control_strength_cost(u, weights):
+def derivative_control_strength_cost(u, weights, dt):
    """Derivative of the 'control_strength_cost' wrt. the control 'u'.

    :param u: Control-dimensions x T array. Control signals.
    :type u: np.ndarray
    :param weights: Dictionary of weights.
    :type weights: dictionary
    :param dt: Time step.
    :type dt: float

    :return: Control-dimensions x T array of control-strength-cost gradients.
    :rtype: np.ndarray
@@ -179,6 +184,8 @@ def derivative_control_strength_cost(u, weights):

if weights["w_2"] != 0.0:
der += weights["w_2"] * derivative_L2_cost(u)
if weights["w_1D"] != 0.0:
der += weights["w_1D"] * derivative_L1D_cost(u, dt)

return der
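
Spelled out per entry (a sketch; as with `derivative_L2_cost`, the Δt factor of the cost integral is not included here but is applied downstream when the gradient is assembled):

```math
\mathrm{der}_{n,v,t} = w_2\, u_{n,v,t} + w_{1D}\, \frac{u_{n,v,t}}{\sqrt{\Delta t \sum_{t'} u_{n,v,t'}^{2}}}
```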

@@ -189,7 +196,7 @@ def L2_cost(u):

    :param u: Control-dimensions x T array. Control signals.
    :type u: np.ndarray

    :return: L2 cost of the control.
    :rtype: float
    """
@@ -203,8 +210,49 @@ def derivative_L2_cost(u):

    :param u: Control-dimensions x T array. Control signals.
    :type u: np.ndarray

    :return: Control-dimensions x T array of L2-cost gradients.
    :rtype: np.ndarray
    """
    return u


@numba.njit
def L1D_cost_integral(
    u,
    dt,
):
    """'Directional sparsity' or 'L1D' cost, integrated over time. Penalizes control strength.

    :param u: Control-dimensions x T array. Control signals.
    :type u: np.ndarray
    :param dt: Time step.
    :type dt: float

    :return: L1D cost of the control.
    :rtype: float
    """

    return np.sum(np.sum(np.sqrt(np.sum(u**2, axis=2) * dt), axis=1), axis=0)
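
A minimal standalone sketch of what this integral computes; the toy array below is made up for illustration:

```python
import numpy as np

# Toy control: 2 nodes, 1 input variable, 5 time steps (hypothetical values).
u = np.zeros((2, 1, 5))
u[0, 0, :] = [0.0, 1.0, -1.0, 1.0, 0.0]  # only node 0 is active
dt = 0.1

# L1 norm over (node, variable) of the L2-in-time norm: silent channels add nothing.
cost = np.sum(np.sqrt(np.sum(u**2, axis=2) * dt))
print(cost)  # sqrt(3 * 0.1) ~ 0.5477
```

Summing over all remaining axes at once is mathematically the same as the nested `np.sum` calls in the implementation.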


@numba.njit
def derivative_L1D_cost(
    u,
    dt,
):
    """Derivative of the 'L1D_cost_integral' wrt. the control 'u'.

    :param u: Control-dimensions x T array. Control signals.
    :type u: np.ndarray
    :param dt: Time step.
    :type dt: float

    :return: Control-dimensions x T array of L1D-cost gradients.
    :rtype: np.ndarray
    """

    denominator = np.sqrt(np.sum(u**2, axis=2) * dt)
    der = np.zeros(u.shape)
    for n in range(der.shape[0]):
        for v in range(der.shape[1]):
            if denominator[n, v] != 0.0:
                der[n, v, :] = u[n, v, :] / denominator[n, v]

    return der
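
A quick self-check of this derivative (a sketch, not part of the PR): central finite differences of `L1D_cost_integral` should equal `dt` times the array returned here, because, as with `derivative_L2_cost`, the `dt` factor of the cost integral is applied downstream in the gradient assembly. With the two functions above in scope:

```python
import numpy as np

rng = np.random.default_rng(0)
u = rng.normal(size=(2, 1, 6))  # toy N x V x T control
dt, eps = 0.1, 1e-6

num_grad = np.zeros_like(u)
for idx in np.ndindex(u.shape):
    up, down = u.copy(), u.copy()
    up[idx] += eps
    down[idx] -= eps
    num_grad[idx] = (L1D_cost_integral(up, dt) - L1D_cost_integral(down, dt)) / (2 * eps)

# The analytical derivative omits dt, so multiply it back in for the comparison.
assert np.allclose(num_grad, dt * derivative_L1D_cost(u, dt), atol=1e-7)
```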
7 changes: 4 additions & 3 deletions neurolib/control/optimal_control/oc.py
@@ -17,6 +17,7 @@ def getdefaultweights():
    )
    weights["w_p"] = 1.0
    weights["w_2"] = 0.0
    weights["w_1D"] = 0.0

    return weights
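
A usage sketch, assuming the import paths match the files touched in this PR (the weight value is arbitrary):

```python
import numpy as np

from neurolib.control.optimal_control import cost_functions
from neurolib.control.optimal_control.oc import getdefaultweights

weights = getdefaultweights()
weights["w_1D"] = 1.0  # any positive weight enables the directional sparsity term

u = np.random.default_rng(1).normal(size=(1, 2, 100))  # toy N x V x T control
dt = 0.1
cost = cost_functions.control_strength_cost(u, weights, dt)
```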

@@ -471,14 +472,14 @@ def __init__(
        for v, iv in enumerate(self.model.input_vars):
            control[:, v, :] = self.model.params[iv]

        self.control = control.copy()
        self.check_params()

        self.control = update_control_with_limit(
            self.N, self.dim_in, self.T, control, 0.0, np.zeros(control.shape), self.maximum_control_strength
        )

        self.model_params = self.get_model_params()

    def check_params(self):
        """Checks a subset of parameters and throws an error if a wrong dimension is found."""
@@ -624,7 +625,7 @@ def compute_gradient(self):
        :rtype: np.ndarray of shape N x V x T
        """
        self.solve_adjoint()
-        df_du = cost_functions.derivative_control_strength_cost(self.control, self.weights)
+        df_du = cost_functions.derivative_control_strength_cost(self.control, self.weights, self.dt)
        d_du = self.Duh()

        return compute_gradient(
24 changes: 24 additions & 0 deletions tests/control/optimal_control/test_oc_cost_functions.py
@@ -160,6 +160,30 @@ def test_derivative_L2_cost(self):
        desired_output = u
        self.assertTrue(np.all(cost_functions.derivative_L2_cost(u) == desired_output))

    def test_L1D_cost(self):
        print(" Test L1D cost")
        dt = 0.1
        reference_result = 2.0 * np.sum(np.sqrt(np.sum(p.TEST_INPUT_1N_6**2 * dt, axis=1)))
        weights = getdefaultweights()
        weights["w_1D"] = 1.0
        u = np.concatenate([p.TEST_INPUT_1N_6[:, np.newaxis, :], p.TEST_INPUT_1N_6[:, np.newaxis, :]], axis=1)
        L1D_cost = cost_functions.control_strength_cost(u, weights, dt)

        self.assertAlmostEqual(L1D_cost, reference_result, places=8)

    def test_derivative_L1D_cost(self):
        print(" Test L1D cost derivative")
        dt = 0.1
        denominator = np.sqrt(np.sum(p.TEST_INPUT_1N_6**2 * dt, axis=1))

        u = np.concatenate([p.TEST_INPUT_1N_6[:, np.newaxis, :], p.TEST_INPUT_1N_6[:, np.newaxis, :]], axis=1)
        reference_result = np.zeros(u.shape)
        for n in range(u.shape[0]):
            for v in range(u.shape[1]):
                reference_result[n, v, :] = u[n, v, :] / denominator[n]

        self.assertTrue(np.all(cost_functions.derivative_L1D_cost(u, dt) == reference_result))

    def test_weights_dictionary(self):
        print("Test dictionary of cost weights")
        model = FHNModel()