Skip to content

Commit

Permalink
mppi dynamic reconfigure (#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
tamago117 authored Oct 9, 2024
1 parent b4ed725 commit de5e51b
Show file tree
Hide file tree
Showing 6 changed files with 259 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
type: "cached_lanelet"
map_topic: "/map/vector_map"
costmap_topic: "~/debug/cached_costmap"
inflation_radius: 1.4 # [m]
inflation_radius: 1.8 # [m]
cached_costmap:
min_x: 89607.0 # [m]
max_x: 89687.0 # [m]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@
ros__parameters:
# mppi
horizon : 25
num_samples : 4000
u_min : [-4.0, -0.35] # accel(m/s2), steer angle(rad)
u_max : [3.0, 0.35]
sigmas : [2.0, 0.35] # sample range
num_samples : 5000
u_min : [-4.0, -0.25] # accel(m/s2), steer angle(rad)
u_max : [2.0, 0.25]
sigmas : [2.0, 0.25] # sample range
lambda : 1.0
auto_lambda : false
auto_lambda : true
# reference path
DL : 0.1
lookahead_distance : 0.1
reference_path_interval : 0.8
reference_path_interval : 0.83
# cost weights
Qc : 10.0
Ql : 1.0
Qv : 4.0
Qc : 20.0
Ql : 5.0
Qv : 0.5
Qo : 1000.0
Qin : 0.01
Qdin : 0.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@

<!-- <node pkg="simple_pure_pursuit" exec="simple_pure_pursuit" name="simple_pure_pursuit_node"
output="screen" unless="$(var use_stanley)">
<param name="use_external_target_vel" value="true" />
<param name="external_target_vel" value="8.0" />
<param name="use_external_target_vel" value="false" />
<param name="external_target_vel" value="8.3" />
<param name="lookahead_gain" value="0.24" />
<param name="lookahead_min_distance" value="2.0" />
<param name="speed_proportional_gain" value="2.0" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,23 @@ def __init__(self, config, debug=False, device=torch.device("cuda"), dtype=torch
self.reference_path: torch.Tensor = None
self.cost_map: CostMapTensor = None

def update_params(self, config):
self.config = config

# model parameter
self.delta_t = torch.tensor(self.config["delta_t"], device=self._device, dtype=self._dtype)
self.vehicle_L = torch.tensor(self.config["vehicle_L"], device=self._device, dtype=self._dtype)
self.V_MAX = torch.tensor(self.config["V_MAX"], device=self._device, dtype=self._dtype)

# cost weights
self.Qc = self.config["Qc"]
self.Ql = self.config["Ql"]
self.Qv = self.config["Qv"]
self.Qo = self.config["Qo"]
self.Qin = self.config["Qin"]
self.Qdin = self.config["Qdin"]


def update(self, state: torch.Tensor, racing_center_path: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Update the controller with the current state and reference path.
Expand Down Expand Up @@ -170,6 +187,12 @@ def cost_function(self, state: torch.Tensor, action: torch.Tensor, info: dict) -
input_cost = self.Qin * action.pow(2).sum(dim=1)
input_cost += self.Qdin * (action - prev_action).pow(2).sum(dim=1)

# dtheta cost
# steer = action[:, 1]
# dtheta = v * torch.tan(steer) / self.vehicle_L
# MAX_DTHETA = 1
# dtheta_cost = 400*torch.relu(torch.abs(dtheta) - MAX_DTHETA)

cost = path_cost + velocity_cost + obstacle_cost + input_cost

return cost
Expand Down Expand Up @@ -218,6 +241,45 @@ def resample_path(path, DL):
start_idx += segment_length

return new_path

def compute_curvature(path, interval=5):
"""
Compute curvature for each segment of the path at specified intervals and linearly interpolate the skipped points.
Args:
path (torch.Tensor): path points with shape (N, 4) [x, y, yaw, target_v]
interval (int): number of points to skip when computing curvature (the larger the value, the fewer points are used)
Returns:
torch.Tensor: curvature values for each path point, shape (N,)
"""
# Initialize curvature tensor
curvature = torch.zeros(path.shape[0], dtype=path.dtype, device=path.device)

# Compute curvature at every 'interval' points
for i in range(interval, len(path) - interval, interval):
# Compute central difference for first derivative (dx, dy)
dx = (path[i + interval, 0] - path[i - interval, 0]) / (2.0 * interval)
dy = (path[i + interval, 1] - path[i - interval, 1]) / (2.0 * interval)

# Compute central difference for second derivative (ddx, ddy)
ddx = path[i + interval, 0] - 2 * path[i, 0] + path[i - interval, 0]
ddy = path[i + interval, 1] - 2 * path[i, 1] + path[i - interval, 1]

# Curvature formula: k = (x'y'' - y'x'') / (x'^2 + y'^2)^(3/2)
curvature[i] = torch.abs(dx * ddy - dy * ddx) / (dx.pow(2) + dy.pow(2)).pow(3 / 2)

# Linearly interpolate skipped points
for i in range(interval, len(path) - interval, interval):
if i + interval < len(path):
for j in range(1, interval):
curvature[i + j] = curvature[i] + (curvature[i + interval] - curvature[i]) * (j / interval)

# Handle the boundary cases (start and end)
curvature[:interval] = curvature[interval]
curvature[-interval:] = curvature[-interval - 1]

return curvature


# Resample the path with the specified DL
path = resample_path(path, DL)
Expand All @@ -229,13 +291,30 @@ def resample_path(path, DL):
# Ensure the index is not less than the current index
ind = max(cind, ind)

# Compute curvature along the path
# curvature = compute_curvature(path)

# Generate the rest of the reference trajectory
travel = lookahead_distance

for i in range(horizon + 1):
travel += reference_path_interval
dind = int(round(travel / DL))

# if (ind + dind) < ncourse:
# xref[i] = path[ind + dind]

# # Adjust target velocity based on curvature
# if (ind + dind) < len(curvature):
# current_curvature = curvature[ind + dind]
# # Lower the target velocity when curvature is high
# print("curvature:", current_curvature)
# if current_curvature > 2.0: # Set threshold for high curvature
# xref[i, 3] = 1.0

# else:
# xref[i] = path[-1]

if (ind + dind) < ncourse:
xref[i] = path[ind + dind]
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import rclpy
from rclpy.node import Node
from rcl_interfaces.msg import ParameterEvent
from rclpy.parameter import Parameter
from rcl_interfaces.msg import SetParametersResult
from nav_msgs.msg import Odometry
from geometry_msgs.msg import Pose, Point, Quaternion
from autoware_auto_planning_msgs.msg import Trajectory, TrajectoryPoint
Expand Down Expand Up @@ -40,7 +43,7 @@ def __init__(self):
self.declare_parameter('vehicle_L', 1.0)
self.declare_parameter('V_MAX', 8.0)
# get
config = {
self.config = {
"horizon": self.get_parameter('horizon').get_parameter_value().integer_value,
"num_samples": self.get_parameter('num_samples').get_parameter_value().integer_value,
"u_min": self.get_parameter('u_min').get_parameter_value().double_array_value,
Expand All @@ -61,8 +64,10 @@ def __init__(self):
"vehicle_L": self.get_parameter('vehicle_L').get_parameter_value().double_value,
"V_MAX": self.get_parameter('V_MAX').get_parameter_value().double_value,
}
self.get_logger().info(f'config: {self.config}')

self.get_logger().info(f'config: {config}')
# Add parameter change callback
self.add_on_set_parameters_callback(self.parameter_callback)

# publisher
# control command
Expand Down Expand Up @@ -100,7 +105,7 @@ def __init__(self):
self.dtype = torch.float32

# mppi controller
self.controller = mppi_controller(config=config, debug=True, device=self.device, dtype=self.dtype)
self.controller = mppi_controller(config=self.config, debug=True, device=self.device, dtype=self.dtype)

self.odometry: Odometry = None
self.trajectory: Trajectory = None
Expand All @@ -117,6 +122,25 @@ def trajectory_callback(self, msg : Trajectory):
def costmap_callback(self, msg : OccupancyGrid):
self.costmap = msg

def parameter_callback(self, params):
for param in params:
if param.name in self.config:
if param.type_ == Parameter.Type.DOUBLE:
self.config[param.name] = param.value
elif param.type_ == Parameter.Type.INTEGER:
self.config[param.name] = param.value
elif param.type_ == Parameter.Type.BOOL:
self.config[param.name] = param.value
elif param.type_ == Parameter.Type.DOUBLE_ARRAY:
self.config[param.name] = param.value
self.get_logger().info(f"Parameter {param.name} changed to {param.value}")

# update controller
self.controller.update_params(self.config)

# Return a success result
return SetParametersResult(successful=True)

def zero_ackermann_control_command(self):
cmd = AckermannControlCommand()
now = self.get_clock().now().to_msg()
Expand Down
Loading

0 comments on commit de5e51b

Please sign in to comment.