Skip to content

Commit

Permalink
Improve rosbag2 dataloader performace (#75)
Browse files Browse the repository at this point in the history
* Copy read_points from common_interfaces

* Provide common implementation of read_point_cloud

* Switch rosbag and rosbag2 to new reader

* Improce error when rosbags is not installed

* Remove class check

- Prevents false positive when both rosbags and rosbag is installed
  • Loading branch information
markuspi authored Feb 23, 2023
1 parent 46ae1ab commit 00924ff
Show file tree
Hide file tree
Showing 3 changed files with 205 additions and 153 deletions.
19 changes: 3 additions & 16 deletions python/kiss_icp/datasets/rosbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import importlib
import os
from pathlib import Path
import sys
Expand All @@ -36,7 +35,8 @@ def __init__(self, data_dir: Path, topic: str, *_, **__):
print('python rosbag is not installed, run "sudo apt install python3-rosbag"')
sys.exit(1)

self.pc2 = importlib.import_module("sensor_msgs.point_cloud2")
from kiss_icp.tools.point_cloud2 import read_point_cloud
self.read_point_cloud = read_point_cloud
self.sequence_id = os.path.basename(data_dir).split(".")[0]

# bagfile
Expand All @@ -55,22 +55,9 @@ def __len__(self):
return self.n_scans

def __getitem__(self, idx):
return self.read_point_cloud(self.bagfile, self.topic, idx)

def read_point_cloud(self, bagfile: Path, topic: str, idx: int):
# TODO: implemnt [idx], expose field_names
_, msg, _ = next(self.msgs)
points = np.array(list(self.pc2.read_points(msg, field_names=["x", "y", "z"])))

t_field = None
for field in msg.fields:
if field.name in ["t", "timestamp", "time"]:
t_field = field.name
timestamps = np.ones(points.shape[0])
if t_field:
timestamps = np.array(list(self.pc2.read_points(msg, field_names=t_field)))
timestamps = timestamps / np.max(timestamps) if t_field != "time" else timestamps
return points.astype(np.float64), timestamps
return self.read_point_cloud(msg)

def check_topic(self, topic: str) -> str:
# when user specified the topic don't check
Expand Down
145 changes: 8 additions & 137 deletions python/kiss_icp/datasets/rosbag2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,11 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
NOTE: Original implementation of the read_points method taken from sensor_msgs.point_cloud2.py, all
rights reserver to the original author: Tim Field
"""
from dataclasses import dataclass

import importlib
import math
import os
from pathlib import Path
import struct
import sys
from typing import ClassVar

import numpy as np

Expand All @@ -40,9 +33,12 @@ class RosbagDataset:
def __init__(self, data_dir: Path, topic: str, *_, **__):
try:
from rosbags import rosbag2
except ModuleNotFoundError:
except ImportError:
print('rosbag2 reader is not installed, run "pip install rosbags"')
sys.exit(1)

from kiss_icp.tools.point_cloud2 import read_point_cloud
self.read_point_cloud = read_point_cloud

self.deserialize_cdr = importlib.import_module("rosbags.serde").deserialize_cdr

Expand All @@ -65,25 +61,16 @@ def __init__(self, data_dir: Path, topic: str, *_, **__):
self.use_global_visualizer = True

def __del__(self):
self.bag.close()
if hasattr(self, 'bag'):
self.bag.close()

def __len__(self):
return self.n_scans

def __getitem__(self, idx):
connection, _, rawdata = next(self.msgs)
msg = self.deserialize_cdr(rawdata, connection.msgtype)
points = np.array(list(read_points(msg, field_names=["x", "y", "z"])))

t_field = None
for field in msg.fields:
if field.name == "t" or field.name == "timestamp":
t_field = field.name
timestamps = np.ones(points.shape[0])
if t_field:
timestamps = np.array(list(read_points(msg, field_names=t_field)))
timestamps = timestamps / np.max(timestamps)
return points.astype(np.float64), timestamps
return self.read_point_cloud(msg)

def check_for_topics(self):
if self.topic:
Expand All @@ -94,119 +81,3 @@ def check_for_topics(self):
if topic_info.msgtype == "sensor_msgs/msg/PointCloud2":
print(topic)
sys.exit(1)


def _get_struct_fmt(is_bigendian, fields, field_names=None):
@dataclass
class PointField:
"""Class for sensor_msgs/msg/PointField."""

name: str
offset: int
datatype: int
count: int
INT8: ClassVar[int] = 1
UINT8: ClassVar[int] = 2
INT16: ClassVar[int] = 3
UINT16: ClassVar[int] = 4
INT32: ClassVar[int] = 5
UINT32: ClassVar[int] = 6
FLOAT32: ClassVar[int] = 7
FLOAT64: ClassVar[int] = 8
__msgtype__: ClassVar[str] = "sensor_msgs/msg/PointField"

_datatypes = {
PointField.INT8: ("b", 1),
PointField.UINT8: ("B", 1),
PointField.INT16: ("h", 2),
PointField.UINT16: ("H", 2),
PointField.INT32: ("i", 4),
PointField.UINT32: ("I", 4),
PointField.FLOAT32: ("f", 4),
PointField.FLOAT64: ("d", 8),
}

fmt = ">" if is_bigendian else "<"

offset = 0
for field in (
f
for f in sorted(fields, key=lambda f: f.offset)
if field_names is None or f.name in field_names
):
if offset < field.offset:
fmt += "x" * (field.offset - offset)
offset = field.offset
if field.datatype not in _datatypes:
print(
"Skipping unknown PointField datatype [%d]" % field.datatype,
file=sys.stderr,
)
else:
datatype_fmt, datatype_length = _datatypes[field.datatype]
fmt += field.count * datatype_fmt
offset += field.count * datatype_length

return fmt


def read_points(cloud, field_names=None, skip_nans=False, uvs=[]):
"""
Read points from a L{sensor_msgs.PointCloud2} message.
@param cloud: The point cloud to read from.
@type cloud: L{sensor_msgs.PointCloud2}
@param field_names: The names of fields to read. If None, read all fields. [default: None]
@type field_names: iterable
@param skip_nans: If True, then don't return any point with a NaN value.
@type skip_nans: bool [default: False]
@param uvs: If specified, then only return the points at the given coordinates. [default: empty list]
@type uvs: iterable
@return: Generator which yields a list of values for each point.
@rtype: generator
"""
fmt = _get_struct_fmt(cloud.is_bigendian, cloud.fields, field_names)
width, height, point_step, row_step, data, isnan = (
cloud.width,
cloud.height,
cloud.point_step,
cloud.row_step,
cloud.data,
math.isnan,
)
unpack_from = struct.Struct(fmt).unpack_from

if skip_nans:
if uvs:
for u, v in uvs:
p = unpack_from(data, (row_step * v) + (point_step * u))
has_nan = False
for pv in p:
if isnan(pv):
has_nan = True
break
if not has_nan:
yield p
else:
for v in range(height):
offset = row_step * v
for u in range(width):
p = unpack_from(data, offset)
has_nan = False
for pv in p:
if isnan(pv):
has_nan = True
break
if not has_nan:
yield p
offset += point_step
else:
if uvs:
for u, v in uvs:
yield unpack_from(data, (row_step * v) + (point_step * u))
else:
for v in range(height):
offset = row_step * v
for u in range(width):
yield unpack_from(data, offset)
offset += point_step
Loading

0 comments on commit 00924ff

Please sign in to comment.