Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mppi dynamic reconfigure #55

Merged
merged 2 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
type: "cached_lanelet"
map_topic: "/map/vector_map"
costmap_topic: "~/debug/cached_costmap"
inflation_radius: 1.4 # [m]
inflation_radius: 1.8 # [m]
cached_costmap:
min_x: 89607.0 # [m]
max_x: 89687.0 # [m]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@
ros__parameters:
# mppi
horizon : 25
num_samples : 4000
u_min : [-4.0, -0.35] # accel(m/s2), steer angle(rad)
u_max : [3.0, 0.35]
sigmas : [2.0, 0.35] # sample range
num_samples : 5000
u_min : [-4.0, -0.25] # accel(m/s2), steer angle(rad)
u_max : [2.0, 0.25]
sigmas : [2.0, 0.25] # sample range
lambda : 1.0
auto_lambda : false
auto_lambda : true
# reference path
DL : 0.1
lookahead_distance : 0.1
reference_path_interval : 0.8
reference_path_interval : 0.83
# cost weights
Qc : 10.0
Ql : 1.0
Qv : 4.0
Qc : 20.0
Ql : 5.0
Qv : 0.5
Qo : 1000.0
Qin : 0.01
Qdin : 0.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@

<!-- <node pkg="simple_pure_pursuit" exec="simple_pure_pursuit" name="simple_pure_pursuit_node"
output="screen" unless="$(var use_stanley)">
<param name="use_external_target_vel" value="true" />
<param name="external_target_vel" value="8.0" />
<param name="use_external_target_vel" value="false" />
<param name="external_target_vel" value="8.3" />
<param name="lookahead_gain" value="0.24" />
<param name="lookahead_min_distance" value="2.0" />
<param name="speed_proportional_gain" value="2.0" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,23 @@ def __init__(self, config, debug=False, device=torch.device("cuda"), dtype=torch
self.reference_path: torch.Tensor = None
self.cost_map: CostMapTensor = None

def update_params(self, config):
self.config = config

# model parameter
self.delta_t = torch.tensor(self.config["delta_t"], device=self._device, dtype=self._dtype)
self.vehicle_L = torch.tensor(self.config["vehicle_L"], device=self._device, dtype=self._dtype)
self.V_MAX = torch.tensor(self.config["V_MAX"], device=self._device, dtype=self._dtype)

# cost weights
self.Qc = self.config["Qc"]
self.Ql = self.config["Ql"]
self.Qv = self.config["Qv"]
self.Qo = self.config["Qo"]
self.Qin = self.config["Qin"]
self.Qdin = self.config["Qdin"]


def update(self, state: torch.Tensor, racing_center_path: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Update the controller with the current state and reference path.
Expand Down Expand Up @@ -170,6 +187,12 @@ def cost_function(self, state: torch.Tensor, action: torch.Tensor, info: dict) -
input_cost = self.Qin * action.pow(2).sum(dim=1)
input_cost += self.Qdin * (action - prev_action).pow(2).sum(dim=1)

# dtheta cost
# steer = action[:, 1]
# dtheta = v * torch.tan(steer) / self.vehicle_L
# MAX_DTHETA = 1
# dtheta_cost = 400*torch.relu(torch.abs(dtheta) - MAX_DTHETA)

cost = path_cost + velocity_cost + obstacle_cost + input_cost

return cost
Expand Down Expand Up @@ -218,6 +241,45 @@ def resample_path(path, DL):
start_idx += segment_length

return new_path

def compute_curvature(path, interval=5):
"""
Compute curvature for each segment of the path at specified intervals and linearly interpolate the skipped points.
Args:
path (torch.Tensor): path points with shape (N, 4) [x, y, yaw, target_v]
interval (int): number of points to skip when computing curvature (the larger the value, the fewer points are used)

Returns:
torch.Tensor: curvature values for each path point, shape (N,)
"""
# Initialize curvature tensor
curvature = torch.zeros(path.shape[0], dtype=path.dtype, device=path.device)

# Compute curvature at every 'interval' points
for i in range(interval, len(path) - interval, interval):
# Compute central difference for first derivative (dx, dy)
dx = (path[i + interval, 0] - path[i - interval, 0]) / (2.0 * interval)
dy = (path[i + interval, 1] - path[i - interval, 1]) / (2.0 * interval)

# Compute central difference for second derivative (ddx, ddy)
ddx = path[i + interval, 0] - 2 * path[i, 0] + path[i - interval, 0]
ddy = path[i + interval, 1] - 2 * path[i, 1] + path[i - interval, 1]

# Curvature formula: k = (x'y'' - y'x'') / (x'^2 + y'^2)^(3/2)
curvature[i] = torch.abs(dx * ddy - dy * ddx) / (dx.pow(2) + dy.pow(2)).pow(3 / 2)

# Linearly interpolate skipped points
for i in range(interval, len(path) - interval, interval):
if i + interval < len(path):
for j in range(1, interval):
curvature[i + j] = curvature[i] + (curvature[i + interval] - curvature[i]) * (j / interval)

# Handle the boundary cases (start and end)
curvature[:interval] = curvature[interval]
curvature[-interval:] = curvature[-interval - 1]

return curvature


# Resample the path with the specified DL
path = resample_path(path, DL)
Expand All @@ -229,13 +291,30 @@ def resample_path(path, DL):
# Ensure the index is not less than the current index
ind = max(cind, ind)

# Compute curvature along the path
# curvature = compute_curvature(path)

# Generate the rest of the reference trajectory
travel = lookahead_distance

for i in range(horizon + 1):
travel += reference_path_interval
dind = int(round(travel / DL))

# if (ind + dind) < ncourse:
# xref[i] = path[ind + dind]

# # Adjust target velocity based on curvature
# if (ind + dind) < len(curvature):
# current_curvature = curvature[ind + dind]
# # Lower the target velocity when curvature is high
# print("curvature:", current_curvature)
# if current_curvature > 2.0: # Set threshold for high curvature
# xref[i, 3] = 1.0

# else:
# xref[i] = path[-1]

if (ind + dind) < ncourse:
xref[i] = path[ind + dind]
else:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import rclpy
from rclpy.node import Node
from rcl_interfaces.msg import ParameterEvent
from rclpy.parameter import Parameter
from rcl_interfaces.msg import SetParametersResult
from nav_msgs.msg import Odometry
from geometry_msgs.msg import Pose, Point, Quaternion
from autoware_auto_planning_msgs.msg import Trajectory, TrajectoryPoint
Expand Down Expand Up @@ -40,7 +43,7 @@ def __init__(self):
self.declare_parameter('vehicle_L', 1.0)
self.declare_parameter('V_MAX', 8.0)
# get
config = {
self.config = {
"horizon": self.get_parameter('horizon').get_parameter_value().integer_value,
"num_samples": self.get_parameter('num_samples').get_parameter_value().integer_value,
"u_min": self.get_parameter('u_min').get_parameter_value().double_array_value,
Expand All @@ -61,8 +64,10 @@ def __init__(self):
"vehicle_L": self.get_parameter('vehicle_L').get_parameter_value().double_value,
"V_MAX": self.get_parameter('V_MAX').get_parameter_value().double_value,
}
self.get_logger().info(f'config: {self.config}')

self.get_logger().info(f'config: {config}')
# Add parameter change callback
self.add_on_set_parameters_callback(self.parameter_callback)

# publisher
# control command
Expand Down Expand Up @@ -100,7 +105,7 @@ def __init__(self):
self.dtype = torch.float32

# mppi controller
self.controller = mppi_controller(config=config, debug=True, device=self.device, dtype=self.dtype)
self.controller = mppi_controller(config=self.config, debug=True, device=self.device, dtype=self.dtype)

self.odometry: Odometry = None
self.trajectory: Trajectory = None
Expand All @@ -117,6 +122,25 @@ def trajectory_callback(self, msg : Trajectory):
def costmap_callback(self, msg : OccupancyGrid):
self.costmap = msg

def parameter_callback(self, params):
for param in params:
if param.name in self.config:
if param.type_ == Parameter.Type.DOUBLE:
self.config[param.name] = param.value
elif param.type_ == Parameter.Type.INTEGER:
self.config[param.name] = param.value
elif param.type_ == Parameter.Type.BOOL:
self.config[param.name] = param.value
elif param.type_ == Parameter.Type.DOUBLE_ARRAY:
self.config[param.name] = param.value
self.get_logger().info(f"Parameter {param.name} changed to {param.value}")

# update controller
self.controller.update_params(self.config)

# Return a success result
return SetParametersResult(successful=True)

def zero_ackermann_control_command(self):
cmd = AckermannControlCommand()
now = self.get_clock().now().to_msg()
Expand Down
Loading