Skip to content

Commit

Permalink
style(pre-commit): autofix
Browse files Browse the repository at this point in the history
  • Loading branch information
pre-commit-ci[bot] committed Feb 26, 2024
1 parent 3f1ed5e commit 16fb567
Showing 1 changed file with 105 additions and 93 deletions.
198 changes: 105 additions & 93 deletions localization/ndt_evaluation/plot_box.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
import numpy as np
import matplotlib.pyplot as plt
import tf_transformations
import rosbag2_py
import argparse
import pathlib
from rosidl_runtime_py.utilities import get_message
from rclpy.serialization import deserialize_message
import pandas as pd
from scipy.spatial.transform import Rotation, Slerp

import geometry_msgs.msg
import matplotlib.pyplot as plt
import nav_msgs.msg
import numpy as np
import pandas as pd
from rclpy.serialization import deserialize_message
import rosbag2_py
from rosidl_runtime_py.utilities import get_message
from scipy.spatial.transform import Rotation
from scipy.spatial.transform import Slerp
import tf_transformations


def euler_from_quaternion(quaternion):
return tf_transformations.euler_from_quaternion([quaternion.qx, quaternion.qy, quaternion.qz, quaternion.qw])
return tf_transformations.euler_from_quaternion(
[quaternion.qx, quaternion.qy, quaternion.qz, quaternion.qw]
)


def extract_pose_data(msg, msg_type):
Expand All @@ -30,23 +34,23 @@ def extract_pose_data(msg, msg_type):


def interpolate_pose(df_pose: pd.DataFrame, target_timestamp: pd.Series) -> pd.DataFrame:
POSITIONS_KEY = ['x', 'y', 'z']
ORIENTATIONS_KEY = ['qw', 'qx', 'qy', 'qz']
POSITIONS_KEY = ["x", "y", "z"]
ORIENTATIONS_KEY = ["qw", "qx", "qy", "qz"]
target_index = 0
df_index = 0
data_dict = {
'x': [],
'y': [],
'z': [],
'qx': [],
'qy': [],
'qz': [],
'qw': [],
'timestamp': [],
"x": [],
"y": [],
"z": [],
"qx": [],
"qy": [],
"qz": [],
"qw": [],
"timestamp": [],
}
while df_index < len(df_pose) - 1 and target_index < len(target_timestamp):
curr_time = df_pose.iloc[df_index]['timestamp']
next_time = df_pose.iloc[df_index + 1]['timestamp']
curr_time = df_pose.iloc[df_index]["timestamp"]
next_time = df_pose.iloc[df_index + 1]["timestamp"]
target_time = target_timestamp[target_index]

# Find a df_index that includes target_time
Expand All @@ -65,42 +69,42 @@ def interpolate_pose(df_pose: pd.DataFrame, target_timestamp: pd.Series) -> pd.D
next_orientation = df_pose.iloc[df_index + 1][ORIENTATIONS_KEY]
curr_r = Rotation.from_quat(curr_orientation)
next_r = Rotation.from_quat(next_orientation)
slerp = Slerp([curr_time, next_time],
Rotation.concatenate([curr_r, next_r]))
slerp = Slerp([curr_time, next_time], Rotation.concatenate([curr_r, next_r]))
target_orientation = slerp([target_time]).as_quat()[0]

data_dict['timestamp'].append(target_timestamp[target_index])
data_dict['x'].append(target_position[0])
data_dict['y'].append(target_position[1])
data_dict['z'].append(target_position[2])
data_dict['qw'].append(target_orientation[0])
data_dict['qx'].append(target_orientation[1])
data_dict['qy'].append(target_orientation[2])
data_dict['qz'].append(target_orientation[3])
data_dict["timestamp"].append(target_timestamp[target_index])
data_dict["x"].append(target_position[0])
data_dict["y"].append(target_position[1])
data_dict["z"].append(target_position[2])
data_dict["qw"].append(target_orientation[0])
data_dict["qx"].append(target_orientation[1])
data_dict["qy"].append(target_orientation[2])
data_dict["qz"].append(target_orientation[3])
target_index += 1
result_df = pd.DataFrame(data_dict)
return result_df


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('rosbag_path', type=pathlib.Path)
parser.add_argument('--pose_topic1', type=str, default="/localization/pose_estimator/pose")
parser.add_argument('--pose_topic2', type=str, default="/sensing/gnss/pose")
parser.add_argument("rosbag_path", type=pathlib.Path)
parser.add_argument("--pose_topic1", type=str, default="/localization/pose_estimator/pose")
parser.add_argument("--pose_topic2", type=str, default="/sensing/gnss/pose")
return parser.parse_args()


if __name__ == '__main__':
if __name__ == "__main__":
args = parse_args()
rosbag_path = args.rosbag_path
pose_topic1 = args.pose_topic1
pose_topic2 = args.pose_topic2

serialization_format = 'cdr'
storage_options = rosbag2_py.StorageOptions(
uri=str(rosbag_path), storage_id='sqlite3')
serialization_format = "cdr"
storage_options = rosbag2_py.StorageOptions(uri=str(rosbag_path), storage_id="sqlite3")
converter_options = rosbag2_py.ConverterOptions(
input_serialization_format=serialization_format, output_serialization_format=serialization_format)
input_serialization_format=serialization_format,
output_serialization_format=serialization_format,
)

reader = rosbag2_py.SequentialReader()
reader.open(storage_options, converter_options)
Expand All @@ -119,32 +123,35 @@ def parse_args():
(topic_name, data, timestamp_rosbag) = reader.read_next()
msg_type = get_message(type_map[topic_name])
msg = deserialize_message(data, msg_type)
timestamp_header = int(msg.header.stamp.sec) + \
int(msg.header.stamp.nanosec) * 1e-9
timestamp_header = int(msg.header.stamp.sec) + int(msg.header.stamp.nanosec) * 1e-9
if topic_name == pose_topic1:
pose = extract_pose_data(msg, msg_type)
pose1_data.append({
'timestamp': timestamp_header,
'x': pose.position.x,
'y': pose.position.y,
'z': pose.position.z,
'qw': pose.orientation.w,
'qx': pose.orientation.x,
'qy': pose.orientation.y,
'qz': pose.orientation.z,
})
pose1_data.append(
{
"timestamp": timestamp_header,
"x": pose.position.x,
"y": pose.position.y,
"z": pose.position.z,
"qw": pose.orientation.w,
"qx": pose.orientation.x,
"qy": pose.orientation.y,
"qz": pose.orientation.z,
}
)
elif topic_name == pose_topic2:
pose = extract_pose_data(msg, msg_type)
pose2_data.append({
'timestamp': timestamp_header,
'x': pose.position.x,
'y': pose.position.y,
'z': pose.position.z,
'qw': pose.orientation.w,
'qx': pose.orientation.x,
'qy': pose.orientation.y,
'qz': pose.orientation.z,
})
pose2_data.append(
{
"timestamp": timestamp_header,
"x": pose.position.x,
"y": pose.position.y,
"z": pose.position.z,
"qw": pose.orientation.w,
"qx": pose.orientation.x,
"qy": pose.orientation.y,
"qz": pose.orientation.z,
}
)
else:
assert False, f"Unknown topic: {topic_name}"

Expand All @@ -153,26 +160,27 @@ def parse_args():
df_pose2 = pd.DataFrame(pose2_data)

# Synchronize timestamps
df_pose2 = interpolate_pose(df_pose2, df_pose1['timestamp'])
df_pose2 = interpolate_pose(df_pose2, df_pose1["timestamp"])

assert len(df_pose1) == len(df_pose2), \
f"Lengths of pose1({len(df_pose1)}) and pose2({len(df_pose1)}) are different"
assert len(df_pose1) == len(
df_pose2
), f"Lengths of pose1({len(df_pose1)}) and pose2({len(df_pose1)}) are different"

df_length = len(pose1_data)
translation_error = []
yaw_error = []
distance_traveled = [0]

for i in range(1, df_length):
pose1_pos = df_pose1.iloc[i][['x', 'y', 'z']]
pose2_pos = df_pose2.iloc[i][['x', 'y', 'z']]
pose1_pos = df_pose1.iloc[i][["x", "y", "z"]]
pose2_pos = df_pose2.iloc[i][["x", "y", "z"]]
translation_error.append(np.linalg.norm(pose1_pos - pose2_pos))

pose1_yaw = euler_from_quaternion(df_pose1.iloc[i])[2]
pose2_yaw = euler_from_quaternion(df_pose2.iloc[i])[2]
yaw_error.append(abs(pose1_yaw - pose2_yaw))

prev_pose1_pos = df_pose1.iloc[i-1][['x', 'y', 'z']]
prev_pose1_pos = df_pose1.iloc[i - 1][["x", "y", "z"]]
distance_traveled.append(distance_traveled[-1] + np.linalg.norm(pose1_pos - prev_pose1_pos))

num_subdivisions = 5
Expand All @@ -181,7 +189,8 @@ def parse_args():

# Categorize the distance traveled into subdivisions
distance_categories = np.digitize(
distance_traveled, np.arange(0, max_distance, distance_interval))
distance_traveled, np.arange(0, max_distance, distance_interval)
)

# Ensure that the arrays are of the same length
length = min(len(distance_categories), len(translation_error), len(yaw_error))
Expand All @@ -190,41 +199,44 @@ def parse_args():
yaw_error = yaw_error[:length]

# Now create the DataFrames
df_translation_error = pd.DataFrame({
'DistanceCategory': distance_categories,
'DistanceTraveled': distance_categories * distance_interval,
'TranslationError': np.array(translation_error)
})
df_translation_error = pd.DataFrame(
{
"DistanceCategory": distance_categories,
"DistanceTraveled": distance_categories * distance_interval,
"TranslationError": np.array(translation_error),
}
)

print(df_translation_error)

df_yaw_error = pd.DataFrame({
'DistanceCategory': distance_categories,
'DistanceTraveled': distance_categories * distance_interval,
'YawError': np.array(yaw_error) * 180 / np.pi
})
df_yaw_error = pd.DataFrame(
{
"DistanceCategory": distance_categories,
"DistanceTraveled": distance_categories * distance_interval,
"YawError": np.array(yaw_error) * 180 / np.pi,
}
)

# Create a single figure for both subplots
fig, axs = plt.subplots(1, 2, figsize=(15, 6)) # 1 row, 2 columns

# Translation Error Box Plot
df_translation_error.boxplot(column='TranslationError',
by='DistanceCategory', grid=False, ax=axs[0])
axs[0].set_title('Translation Error vs Distance Traveled')
axs[0].set_xlabel('Distance traveled [m]')
axs[0].set_ylabel('Translation error [m]')
axs[0].set_xticklabels(
[f"{(i + 1) * distance_interval:.2f}" for i in range(num_subdivisions)])
df_translation_error.boxplot(

Check warning on line 224 in localization/ndt_evaluation/plot_box.py

View workflow job for this annotation

GitHub Actions / spell-check-partial

Unknown word (boxplot)
column="TranslationError", by="DistanceCategory", grid=False, ax=axs[0]
)
axs[0].set_title("Translation Error vs Distance Traveled")
axs[0].set_xlabel("Distance traveled [m]")
axs[0].set_ylabel("Translation error [m]")
axs[0].set_xticklabels([f"{(i + 1) * distance_interval:.2f}" for i in range(num_subdivisions)])

Check warning on line 230 in localization/ndt_evaluation/plot_box.py

View workflow job for this annotation

GitHub Actions / spell-check-partial

Unknown word (xticklabels)

# Yaw Error Box Plot
df_yaw_error.boxplot(column='YawError', by='DistanceCategory', grid=False, ax=axs[1])
axs[1].set_title('Yaw Error vs Distance Traveled')
axs[1].set_xlabel('Distance traveled [m]')
axs[1].set_ylabel('Yaw error [deg]')
axs[1].set_xticklabels(
[f"{(i + 1) * distance_interval:.2f}" for i in range(num_subdivisions)])
df_yaw_error.boxplot(column="YawError", by="DistanceCategory", grid=False, ax=axs[1])

Check warning on line 233 in localization/ndt_evaluation/plot_box.py

View workflow job for this annotation

GitHub Actions / spell-check-partial

Unknown word (boxplot)
axs[1].set_title("Yaw Error vs Distance Traveled")
axs[1].set_xlabel("Distance traveled [m]")
axs[1].set_ylabel("Yaw error [deg]")
axs[1].set_xticklabels([f"{(i + 1) * distance_interval:.2f}" for i in range(num_subdivisions)])

Check warning on line 237 in localization/ndt_evaluation/plot_box.py

View workflow job for this annotation

GitHub Actions / spell-check-partial

Unknown word (xticklabels)

plt.suptitle('')
plt.suptitle("")

Check warning on line 239 in localization/ndt_evaluation/plot_box.py

View workflow job for this annotation

GitHub Actions / spell-check-partial

Unknown word (suptitle)

# Adjust layout
plt.tight_layout()
Expand All @@ -233,6 +245,6 @@ def parse_args():
save_dir = rosbag_path.parent if rosbag_path.is_dir() else rosbag_path.parent.parent

save_path = save_dir / "performance_box.png"
plt.savefig(save_path, bbox_inches='tight', pad_inches=0.05)
plt.savefig(save_path, bbox_inches="tight", pad_inches=0.05)
print(f"Saved to {save_path}")
plt.close()

0 comments on commit 16fb567

Please sign in to comment.