diff --git a/DATASETS.md b/DATASETS.md
index b2f9255..07de6a0 100644
--- a/DATASETS.md
+++ b/DATASETS.md
@@ -76,6 +76,32 @@ It should look like this after downloading:
 **Note**: Not all the dataset parts need to be downloaded, only the necessary directories in [the Google Cloud Bucket](https://console.cloud.google.com/storage/browser/waymo_open_dataset_motion_v_1_1_0/uncompressed/scenario) need to be downloaded (e.g., `validation` for the validation dataset).
 
+## Yandex Shifts Motion Prediction Dataset
+Nothing special needs to be done for the Yandex Shifts Motion Prediction Dataset, simply download it as per [the instructions on the dataset website](https://github.com/Shifts-Project/shifts#motion-prediction-1).
+
+It should look like this after downloading:
+```
+/path/to/ysdc/
+    ├── train/
+    |   ├── 000
+    |   |   ├── 000000.pb
+    |   |   └── ...
+    |   └── ...
+    ├── development/
+    |   ├── 000
+    |   |   ├── 000000.pb
+    |   |   └── ...
+    |   └── ...
+    └── eval/
+        ├── 000
+        |   ├── 000000.pb
+        |   └── ...
+        └── ...
+```
+
+**Note**: You may also download the complete, unpartitioned dataset. The dataset also contains prerendered examples,
+which are not required by `trajdata`.
+
 ## Lyft Level 5
 Nothing special needs to be done for the Lyft Level 5 dataset, simply download it as per [the instructions on the dataset website](https://woven-planet.github.io/l5kit/dataset.html).
diff --git a/README.md b/README.md
index a715ebc..5becea3 100644
--- a/README.md
+++ b/README.md
@@ -24,11 +24,14 @@ pip install "trajdata[lyft]"
 # For Waymo
 pip install "trajdata[waymo]"
 
+# For Yandex Shifts Motion Dataset
+pip install "trajdata[ysdc]"
+
 # For INTERACTION
 pip install "trajdata[interaction]"
 
 # All
-pip install "trajdata[nusc,lyft,waymo,interaction]"
+pip install "trajdata[nusc,lyft,waymo,interaction,ysdc]"
 ```
 
 Then, download the raw datasets (nuScenes, Lyft Level 5, ETH/UCY, etc.) in case you do not already have them. For more information about how to structure dataset folders/files, please see [`DATASETS.md`](./DATASETS.md).
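For orientation, once the `ysdc` extra is installed and the data is downloaded, the new splits are expected to load through the same `UnifiedDataset` entry point as the other datasets in this diff. A minimal sketch (the `/path/to/ysdc/train` directory is a placeholder for wherever the files from `DATASETS.md` live):

```python
from torch.utils.data import DataLoader

from trajdata import AgentBatch, UnifiedDataset

# Build an agent-centric dataset over the YSDC training split.
dataset = UnifiedDataset(
    desired_data=["ysdc_train"],
    desired_dt=0.2,  # YSDC's native 5 Hz timestep.
    data_dirs={"ysdc_train": "/path/to/ysdc/train"},
)

dataloader = DataLoader(
    dataset,
    batch_size=4,
    collate_fn=dataset.get_collate_fn(),
)

batch: AgentBatch = next(iter(dataloader))
print(batch.agent_hist.shape)
```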
@@ -99,6 +102,9 @@ Currently, the dataloader supports interfacing with the following datasets:
 | Waymo Open Motion Training | `waymo_train` | `train` | N/A | Waymo Open Motion Dataset `training` split | 0.1s (10Hz) | :white_check_mark: |
 | Waymo Open Motion Validation | `waymo_val` | `val` | N/A | Waymo Open Motion Dataset `validation` split | 0.1s (10Hz) | :white_check_mark: |
 | Waymo Open Motion Testing | `waymo_test` | `test` | N/A | Waymo Open Motion Dataset `testing` split | 0.1s (10Hz) | :white_check_mark: |
+| Yandex Shifts Motion Dataset Training | `ysdc_train` | `train` | N/A | Yandex Shifts Motion Dataset `training` split | 0.2s (5Hz) | :white_check_mark: |
+| Yandex Shifts Motion Dataset Development | `ysdc_development` | `development` | N/A | Yandex Shifts Motion Dataset `development` split | 0.2s (5Hz) | :white_check_mark: |
+| Yandex Shifts Motion Dataset Evaluation | `ysdc_eval` | `eval` | N/A | Yandex Shifts Motion Dataset `eval` split | 0.2s (5Hz) | :white_check_mark: |
 | Lyft Level 5 Train | `lyft_train` | `train` | `palo_alto` | Lyft Level 5 training data - part 1/2 (8.4 GB) | 0.1s (10Hz) | :white_check_mark: |
 | Lyft Level 5 Train Full | `lyft_train_full` | `train` | `palo_alto` | Lyft Level 5 training data - part 2/2 (70 GB) | 0.1s (10Hz) | :white_check_mark: |
 | Lyft Level 5 Validation | `lyft_val` | `val` | `palo_alto` | Lyft Level 5 validation data (8.2 GB) | 0.1s (10Hz) | :white_check_mark: |
diff --git a/pyproject.toml b/pyproject.toml
index 4cf3822..e40550a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,6 +42,7 @@ interaction = ["lanelet2==1.2.1"]
 lyft = ["l5kit==1.5.0"]
 nusc = ["nuscenes-devkit==1.1.9"]
 waymo = ["tensorflow==2.11.0", "waymo-open-dataset-tf-2-11-0", "intervaltree"]
+ysdc = ["ysdc-dataset-api @ git+https://github.com/yandex-research/shifts.git#subdirectory=sdc"]
 
 [project.urls]
 "Homepage" = "https://github.com/nvr-avg/trajdata"
diff --git a/src/trajdata/augmentation/noise_histories.py b/src/trajdata/augmentation/noise_histories.py
index 1aca9c6..b0fa4af 100644
--- a/src/trajdata/augmentation/noise_histories.py
+++ b/src/trajdata/augmentation/noise_histories.py
@@ -23,8 +23,13 @@ def apply_agent(self, agent_batch: AgentBatch) -> None:
         )
 
         if agent_batch.history_pad_dir == PadDirection.BEFORE:
-            agent_hist_noise[..., -1, :] = 0
-            neigh_hist_noise[..., -1, :] = 0
+            try:
+                # Keep the current timestep noise-free; histories can be
+                # zero-length, which makes the [-1] index invalid.
+                agent_hist_noise[..., -1, :] = 0
+                neigh_hist_noise[..., -1, :] = 0
+            except IndexError:
+                pass
         else:
             len_mask = ~mask_up_to(
                 agent_batch.agent_hist_len,
diff --git a/src/trajdata/data_structures/batch_element.py b/src/trajdata/data_structures/batch_element.py
index f18e34d..80042e6 100644
--- a/src/trajdata/data_structures/batch_element.py
+++ b/src/trajdata/data_structures/batch_element.py
@@ -39,6 +39,7 @@ def __init__(
         self.cache: SceneCache = cache
         self.data_index: int = data_index
         self.dt: float = scene_time_agent.scene.dt
+        self.track_info = scene_time_agent.scene.data_access_info
         self.scene_ts: int = scene_time_agent.ts
         self.history_sec = history_sec
         self.future_sec = future_sec
@@ -341,6 +342,7 @@ def __init__(
         self.data_index = data_index
         self.dt: float = scene_time.scene.dt
         self.scene_ts: int = scene_time.ts
+        self.track_info = scene_time.scene.data_access_info
 
         if max_agent_num is not None:
             scene_time.agents = scene_time.agents[:max_agent_num]
@@ -506,7 +508,6 @@ def get_agents_future(
         future_sec: Tuple[Optional[float], Optional[float]],
         nearby_agents: List[AgentMetadata],
     ) -> Tuple[List[StateArray], List[np.ndarray], np.ndarray]:
-
         (
             agent_futures,
            agent_future_extents,
diff --git a/src/trajdata/dataset_specific/scene_records.py b/src/trajdata/dataset_specific/scene_records.py
index 68bd5b9..36a83e4 100644
--- a/src/trajdata/dataset_specific/scene_records.py
+++ b/src/trajdata/dataset_specific/scene_records.py
@@ -48,3 +48,14 @@ class NuPlanSceneRecord(NamedTuple):
     split: str
     # desc: str
     data_idx: int
+
+
+class YandexShiftsSceneRecord(NamedTuple):
+    name: str
+    length: str
+    data_idx: int
+    day_time: str
+    season: str
+    track: str
+    sun_phase: str
+    precipitation: str
diff --git a/src/trajdata/dataset_specific/yandex_shifts/__init__.py b/src/trajdata/dataset_specific/yandex_shifts/__init__.py
new file mode 100644
index 0000000..dc1fde6
--- /dev/null
+++ b/src/trajdata/dataset_specific/yandex_shifts/__init__.py
@@ -0,0 +1 @@
+from .yandex_shifts_dataset import YandexShiftsDataset
diff --git a/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_dataset.py b/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_dataset.py
new file mode 100644
index 0000000..041cf43
--- /dev/null
+++ b/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_dataset.py
@@ -0,0 +1,213 @@
+from collections import defaultdict
+from functools import partial
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+import pandas as pd
+import tqdm
+
+from ysdc_dataset_api.utils import get_file_paths, scenes_generator
+from ysdc_dataset_api.proto import Scene as YSDCScene
+from trajdata.caching import EnvCache, SceneCache
+from trajdata.data_structures import EnvMetadata, Scene, SceneMetadata, SceneTag
+from trajdata.data_structures.agent import AgentMetadata
+from trajdata.dataset_specific.raw_dataset import RawDataset
+from trajdata.dataset_specific.scene_records import YandexShiftsSceneRecord
+from trajdata.dataset_specific.yandex_shifts import yandex_shifts_utils
+from trajdata.maps import VectorMap
+from trajdata.utils.parallel_utils import parallel_apply
+from trajdata.dataset_specific.yandex_shifts.yandex_shifts_utils import (
+    read_scene_from_original_proto,
+    get_scene_path,
+    extract_vectorized,
+    extract_traffic_light_status,
+    extract_agent_data_from_ysdc_scene,
+)
+
+
+def const_lambda(const_val: Any) -> Any:
+    return const_val
+
+
+class YandexShiftsDataset(RawDataset):
+    def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata:
+        if env_name == "ysdc_train":
+            dataset_parts = [("train",)]
+            scene_split_map = defaultdict(partial(const_lambda, const_val="train"))
+
+        elif env_name == "ysdc_development":
+            dataset_parts = [("development",)]
+            scene_split_map = defaultdict(
+                partial(const_lambda, const_val="development")
+            )
+
+        elif env_name == "ysdc_eval":
+            dataset_parts = [("eval",)]
+            scene_split_map = defaultdict(partial(const_lambda, const_val="eval"))
+
+        elif env_name == "ysdc_full":
+            dataset_parts = [("full",)]
+            scene_split_map = defaultdict(partial(const_lambda, const_val="full"))
+
+        return EnvMetadata(
+            name=env_name,
+            data_dir=data_dir,
+            dt=yandex_shifts_utils.YSDC_DT,
+            parts=dataset_parts,
+            scene_split_map=scene_split_map,
+        )
+
+    def load_dataset_obj(self, verbose: bool = False) -> None:
+        if verbose:
+            print(f"Loading {self.name} dataset...", flush=True)
+
+        self.dataset_obj = scenes_generator(get_file_paths(self.metadata.data_dir))
+
+    def _get_matching_scenes_from_obj(
+        self,
+        scene_tag: SceneTag,
+        scene_desc_contains: Optional[List[str]],
+        env_cache: EnvCache,
+    ) -> List[SceneMetadata]:
+        all_scenes_list: List[YandexShiftsSceneRecord] = list()
+        scenes_list: List[SceneMetadata] = list()
+        for idx, scene in tqdm.tqdm(
+            enumerate(self.dataset_obj), desc="Processing scenes from proto files"
+        ):
+            scene_name: str = scene.id
+            scene_split: str = self.metadata.scene_split_map[scene_name]
+            scene_length: int = yandex_shifts_utils.YSDC_LENGTH
+
+            # Saving all scene records for later caching.
+            all_scenes_list.append(
+                YandexShiftsSceneRecord(
+                    scene_name,
+                    str(scene_length),
+                    idx,
+                    scene.scene_tags.day_time,
+                    scene.scene_tags.season,
+                    scene.scene_tags.track,
+                    scene.scene_tags.sun_phase,
+                    scene.scene_tags.precipitation,
+                )
+            )
+
+            if scene_split in scene_tag and scene_desc_contains is None:
+                scene_metadata = SceneMetadata(
+                    env_name=self.metadata.name,
+                    name=scene_name,
+                    dt=self.metadata.dt,
+                    raw_data_idx=idx,
+                )
+                scenes_list.append(scene_metadata)
+
+        self.cache_all_scenes_list(env_cache, all_scenes_list)
+        return scenes_list
+
+    def _get_matching_scenes_from_cache(
+        self,
+        scene_tag: SceneTag,
+        scene_desc_contains: Optional[List[str]],
+        env_cache: EnvCache,
+    ) -> List[Scene]:
+        all_scenes_list: List[YandexShiftsSceneRecord] = env_cache.load_env_scenes_list(
+            self.name
+        )
+        scenes_list: List[Scene] = list()
+        for scene_record in all_scenes_list:
+            scene_split: str = self.metadata.scene_split_map[scene_record.name]
+
+            if scene_split in scene_tag and scene_desc_contains is None:
+                scene_metadata = Scene(
+                    self.metadata,
+                    scene_record.name,
+                    scene_record.data_idx,
+                    scene_split,
+                    scene_record.length,
+                    scene_record.data_idx,
+                    None,  # This isn't used if everything is already cached.
+                )
+                scenes_list.append(scene_metadata)
+
+        return scenes_list
+
+    def get_scene(self, scene_info: SceneMetadata) -> Scene:
+        _, _, _, data_idx = scene_info
+
+        scene_data_from_proto: YSDCScene = read_scene_from_original_proto(
+            get_scene_path(self.metadata.data_dir, data_idx)
+        )
+        num_history_timestamps = len(scene_data_from_proto.past_vehicle_tracks)
+        num_future_timestamps = len(scene_data_from_proto.future_vehicle_tracks)
+        scene_name: str = scene_data_from_proto.id
+        scene_split: str = self.metadata.scene_split_map[scene_name]
+        scene_length: int = num_history_timestamps + num_future_timestamps
+
+        return Scene(
+            self.metadata,
+            scene_data_from_proto.id,
+            data_idx,
+            scene_split,
+            scene_length,
+            data_idx,
+            {
+                "day_time": scene_data_from_proto.scene_tags.day_time,
+                "season": scene_data_from_proto.scene_tags.season,
+                "track_location": scene_data_from_proto.scene_tags.track,
+                "sun_phase": scene_data_from_proto.scene_tags.sun_phase,
+                "precipitation": scene_data_from_proto.scene_tags.precipitation,
+            },
+        )
+
+    def get_agent_info(
+        self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache]
+    ) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]:
+        scene_data_from_proto = read_scene_from_original_proto(
+            get_scene_path(self.metadata.data_dir, scene.raw_data_idx)
+        )
+        (
+            scene_agents_data_df,
+            agent_list,
+            agent_presence,
+        ) = extract_agent_data_from_ysdc_scene(scene_data_from_proto, scene)
+        cache_class.save_agent_data(scene_agents_data_df, cache_path, scene)
+
+        tls_dict = extract_traffic_light_status(scene_data_from_proto)
+        tls_df = pd.DataFrame(
+            tls_dict.values(),
+            index=pd.MultiIndex.from_tuples(
+                tls_dict.keys(), names=["lane_id", "scene_ts"]
+            ),
+            columns=["status"],
+        )
+        cache_class.save_traffic_light_data(tls_df, cache_path, scene)
+
+        return agent_list, agent_presence
+
+    def cache_map(
+        self,
+        data_idx: int,
+        cache_path: Path,
+        map_cache_class: Type[SceneCache],
+        map_params: Dict[str, Any],
+    ):
+        scene_data_from_proto = read_scene_from_original_proto(
+            get_scene_path(self.metadata.data_dir, data_idx)
+        )
+        vector_map: VectorMap = extract_vectorized(
+            scene_data_from_proto.path_graph, map_name=f"{self.name}:{data_idx}"
+        )
+        map_cache_class.finalize_and_cache_map(cache_path, vector_map, map_params)
+
+    def cache_maps(
+        self,
+        cache_path: Path,
+        map_cache_class: Type[SceneCache],
+        map_params: Dict[str, Any],
+    ) -> None:
+        num_workers: int = map_params.get("num_workers", 0)
+        if num_workers > 1:
+            parallel_apply(
+                partial(
+                    self.cache_map,
+                    cache_path=cache_path,
+                    map_cache_class=map_cache_class,
+                    map_params=map_params,
+                ),
+                range(len(get_file_paths(self.metadata.data_dir))),
+                num_workers=num_workers,
+            )
+
+        else:
+            for i in tqdm.trange(len(get_file_paths(self.metadata.data_dir))):
+                self.cache_map(i, cache_path, map_cache_class, map_params)
diff --git a/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_utils.py b/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_utils.py
new file mode 100644
index 0000000..ef38430
--- /dev/null
+++ b/src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_utils.py
@@ -0,0 +1,327 @@
+import os
+from collections import defaultdict
+from typing import Dict, List, Tuple, Union, Any
+
+import numpy as np
+import pandas as pd
+
+from ysdc_dataset_api.proto import Scene as YSDCScene
+from ysdc_dataset_api.proto.map_pb2 import PathGraph as YSDCPathGraph
+from ysdc_dataset_api.proto.dataset_pb2 import VehicleTrack as YSDCVehicleTrack
+from trajdata.maps import TrafficLightStatus, VectorMap
+from trajdata.maps.vec_map_elements import PedCrosswalk, Polyline, RoadLane, RoadArea
+from trajdata.data_structures.batch_element import AgentBatchElement, SceneBatchElement
+from trajdata.data_structures.agent import AgentMetadata, AgentType, VariableExtent
+from trajdata.data_structures import Scene as TRAJScene
+
+YSDC_DT = 0.2
+YSDC_LENGTH = 50
+
+
+def fetch_season_info(
+    batch_element: Union[AgentBatchElement, SceneBatchElement]
+) -> int:
+    return batch_element.track_info["season"]
+
+
+def fetch_day_time_info(
+    batch_element: Union[AgentBatchElement, SceneBatchElement]
+) -> int:
+    return batch_element.track_info["day_time"]
+
+
+def fetch_track_location_info(
+    batch_element: Union[AgentBatchElement, SceneBatchElement]
+) -> int:
+    return batch_element.track_info["track_location"]
+
+
+def fetch_sun_phase_info(
+    batch_element: Union[AgentBatchElement, SceneBatchElement]
+) -> int:
+    return batch_element.track_info["sun_phase"]
+
+
+def fetch_precipitation_info(
+    batch_element: Union[AgentBatchElement, SceneBatchElement]
+) -> int:
+    return batch_element.track_info["precipitation"]
+
+
+def read_scene_from_original_proto(path: str) -> YSDCScene:
+    with open(path, "rb") as f:
+        scene = YSDCScene()
+        scene.ParseFromString(f.read())
+    return scene
+
+
+def get_scene_path(data_dir: str, scene_idx: int) -> str:
+    # Scenes are sharded 1000 per directory, e.g. scene 12345 lives at
+    # <data_dir>/012/012345.pb
+    return (
+        os.path.join(data_dir, str(scene_idx // 1000).zfill(3), str(scene_idx).zfill(6))
+        + ".pb"
+    )
+
+
+def fix_headings(agents_data_df: pd.DataFrame) -> pd.DataFrame:
+    # Rows are assumed to be sorted by (agent_id, scene_ts); headings that
+    # jump by more than pi/2 between consecutive rows are treated as sign
+    # flips and rotated by pi.
+    headings = agents_data_df["heading"].values
+    previous_headings = np.roll(headings, 1)
+    previous_headings[0] = headings[0]
+    normalized_angle_diff = np.abs(headings - previous_headings) % (2 * np.pi)
+    fixed_headings = np.where(
+        np.minimum(normalized_angle_diff, 2 * np.pi - normalized_angle_diff)
+        > np.pi / 2,
+        headings + np.pi,
+        headings,
+    )
+    agents_data_df["heading"] = fixed_headings
+    return agents_data_df
+
+
+def fill_missing_timestamps(
+    agents_data_df: pd.DataFrame, agent_id_to_time_range: dict
+) -> pd.DataFrame:
+    filled_agents_data_df = []
+    for agent_id, agent_df in agents_data_df.groupby("agent_id"):
+        state_idx = 0
+        for ts in range(
+            agent_id_to_time_range[agent_id][0], agent_id_to_time_range[agent_id][1] + 1
+        ):
+            if (
+                state_idx < agent_df.shape[0]
+                and agent_df.iloc[state_idx]["scene_ts"] == ts
+            ):
+                d = agent_df.iloc[state_idx].to_dict()
+                d["agent_id"] = agent_id
+                filled_agents_data_df.append(d)
+                state_idx += 1
+            else:
+                # Insert an all-None row for this timestep; the gaps are
+                # interpolated later in extract_agent_data_from_ysdc_scene.
+                filled_agents_data_df.append(
+                    {
+                        "agent_id": agent_id,
+                        "scene_ts": ts,
+                        "x": None,
+                        "y": None,
+                        "z": None,
+                        "vx": None,
+                        "vy": None,
+                        "ax": None,
+                        "ay": None,
+                        "heading": None,
+                        "length": None,
+                        "width": None,
+                        "height": None,
+                    }
+                )
+    return pd.DataFrame(filled_agents_data_df).sort_values(by=["agent_id", "scene_ts"])
+
+
+def map_ysdc_to_trajdata_traffic_light_status(
+    ysdc_tl_status: int,
+) -> TrafficLightStatus:
+    mapping = {
+        -1: TrafficLightStatus.NO_DATA,
+        0: TrafficLightStatus.UNKNOWN,
+        1: TrafficLightStatus.GREEN,
+        2: TrafficLightStatus.GREEN,
+        3: TrafficLightStatus.RED,
+        4: TrafficLightStatus.RED,
+        5: TrafficLightStatus.RED,
+        6: TrafficLightStatus.UNKNOWN,
+        7: TrafficLightStatus.UNKNOWN,
+        8: TrafficLightStatus.UNKNOWN,
+        9: TrafficLightStatus.UNKNOWN,
+        10: TrafficLightStatus.UNKNOWN,
+        11: TrafficLightStatus.RED,
+    }
+    return mapping[ysdc_tl_status]
+
+
+def extract_traffic_light_status(
+    ysdc_scene: YSDCScene,
+) -> Dict[Tuple[str, int], TrafficLightStatus]:
+    traffic_light_data = {}
+    n_states = len(ysdc_scene.past_vehicle_tracks) + len(
+        ysdc_scene.future_vehicle_tracks
+    )
+    traffic_light_section_id_to_state = {}
+    for traffic_light in ysdc_scene.traffic_lights:
+        for traffic_light_section in traffic_light.sections:
+            traffic_light_section_id_to_state[
+                traffic_light_section.id
+            ] = traffic_light_section.state
+
+    for lane_idx, lane in enumerate(ysdc_scene.path_graph.lanes):
+        # The YSDC dataset also provides left_section_id and right_section_id;
+        # only the main section is used here.
+        conventional_lane_id = f"lane_{lane_idx}"
+        lane_main_section_id = lane.traffic_light_section_ids.main_section_id
+        if lane_main_section_id not in traffic_light_section_id_to_state:
+            traffic_light_section_id_to_state[lane_main_section_id] = -1
+        ysdc_traffic_light_state = traffic_light_section_id_to_state[
+            lane_main_section_id
+        ]
+        conventional_traffic_light_state = map_ysdc_to_trajdata_traffic_light_status(
+            ysdc_traffic_light_state
+        )
+        for ts in range(n_states):
+            traffic_light_data[
+                (conventional_lane_id, ts)
+            ] = conventional_traffic_light_state
+    return traffic_light_data
+
+
+def extract_vectorized(map_features: YSDCPathGraph, map_name: str) -> VectorMap:
+    vec_map = VectorMap(map_id=map_name)
+    max_pt = np.array([np.nan, np.nan])
+    min_pt = np.array([np.nan, np.nan])
+
+    for lane_idx, lane in enumerate(map_features.lanes):
+        lane_centers = np.array([(node.x, node.y) for node in lane.centers])
+        max_pt = np.nanmax(np.vstack([max_pt, lane_centers]), axis=0)
+        min_pt = np.nanmin(np.vstack([min_pt, lane_centers]), axis=0)
+        vec_map.add_map_element(
+            RoadLane(
+                # YSDC only provides lane centerlines (no edge polylines).
+                id=f"lane_{lane_idx}",
+                center=Polyline(lane_centers),
+            )
+        )
+
+    for crosswalk_idx, crosswalk in enumerate(map_features.crosswalks):
+        crosswalk_points = np.array(
+            [(node.x, node.y) for node in crosswalk.geometry.points]
+        )
+        max_pt = np.nanmax(np.vstack([max_pt, crosswalk_points]), axis=0)
+        min_pt = np.nanmin(np.vstack([min_pt, crosswalk_points]), axis=0)
+        vec_map.add_map_element(
+            PedCrosswalk(
+                id=f"crosswalk_{crosswalk_idx}", polygon=Polyline(crosswalk_points)
+            )
+        )
+
+    for road_polygon_idx, road_polygon in enumerate(map_features.road_polygons):
+        road_polygon_points = np.array(
+            [(node.x, node.y) for node in road_polygon.geometry.points]
+        )
+        max_pt = np.nanmax(np.vstack([max_pt, road_polygon_points]), axis=0)
+        min_pt = np.nanmin(np.vstack([min_pt, road_polygon_points]), axis=0)
+        vec_map.add_map_element(
+            RoadArea(
+                id=f"road_polygon_{road_polygon_idx}",
+                exterior_polygon=Polyline(road_polygon_points),
+            )
+        )
+
+    vec_map.extent = np.array([*min_pt, 0, *max_pt, 0])
+    return vec_map
+
+
+def prepare_agent_info_dict_from_track(
+    track: YSDCVehicleTrack, scene_ts: int, entity: AgentType, is_ego: bool = False
+) -> Dict[str, Any]:
+    assert entity in [AgentType.VEHICLE, AgentType.PEDESTRIAN]
+    return {
+        "agent_id": "ego" if is_ego else str(track.track_id),
+        "scene_ts": scene_ts,
+        "x": track.position.x,
+        "y": track.position.y,
+        "z": track.position.z,
+        "vx": track.linear_velocity.x,
+        "vy": track.linear_velocity.y,
+        "ax": 0 if entity == AgentType.PEDESTRIAN else track.linear_acceleration.x,
+        "ay": 0 if entity == AgentType.PEDESTRIAN else track.linear_acceleration.y,
+        "heading": np.arctan2(track.linear_velocity.y, track.linear_velocity.x)
+        if entity == AgentType.PEDESTRIAN
+        else track.yaw,
+        "length": track.dimensions.x,
+        "width": track.dimensions.y,
+        "height": track.dimensions.z,
+    }
+
+
+def update_time_range(
+    agent_id: str,
+    timestamp: int,
+    agent_id_to_time_range: Dict[str, Tuple[float, float]],
+) -> None:
+    agent_id_to_time_range[agent_id] = (
+        min(agent_id_to_time_range[agent_id][0], timestamp),
+        max(agent_id_to_time_range[agent_id][1], timestamp),
+    )
+
+
+def extract_agent_data_from_ysdc_scene(
+    ysdc_scene: YSDCScene, trajdata_scene: TRAJScene
+) -> Tuple[pd.DataFrame, List[AgentMetadata], List[List[AgentMetadata]]]:
+    agent_list: List[AgentMetadata] = []
+    agent_presence: List[List[AgentMetadata]] = [
+        [] for _ in range(trajdata_scene.length_timesteps)
+    ]
+    scene_agents_data = defaultdict(list)
+    agent_id_to_time_range = defaultdict(lambda: (np.inf, -np.inf))
+    agent_id_to_type = {"ego": AgentType.VEHICLE}
+
+    agents_types_data = [
+        (
+            AgentType.VEHICLE,
+            list(ysdc_scene.past_vehicle_tracks)
+            + list(ysdc_scene.future_vehicle_tracks),
+        ),
+        (
+            AgentType.PEDESTRIAN,
+            list(ysdc_scene.past_pedestrian_tracks)
+            + list(ysdc_scene.future_pedestrian_tracks),
+        ),
+    ]
+    ego_agent_data = list(ysdc_scene.past_ego_track) + list(ysdc_scene.future_ego_track)
+
+    for agent_type, scene_moment_states in agents_types_data:
+        for timestamp, scene_moment_state in enumerate(scene_moment_states):
+            for agent_moment_state in scene_moment_state.tracks:
+                agent_info_dict = prepare_agent_info_dict_from_track(
+                    agent_moment_state, timestamp, agent_type
+                )
+                scene_agents_data[agent_info_dict["agent_id"]].append(agent_info_dict)
+                update_time_range(
+                    agent_info_dict["agent_id"], timestamp, agent_id_to_time_range
+                )
+                agent_id_to_type[str(agent_moment_state.track_id)] = agent_type
+
+    for timestamp, ego_agent_moment_state in enumerate(ego_agent_data):
+        agent_info_dict = prepare_agent_info_dict_from_track(
+            ego_agent_moment_state, timestamp, AgentType.VEHICLE, True
+        )
+        scene_agents_data[agent_info_dict["agent_id"]].append(agent_info_dict)
+        update_time_range(
+            agent_info_dict["agent_id"], timestamp, agent_id_to_time_range
+        )
+
+    scene_agents_data_df = pd.DataFrame(
+        [item for sublist in scene_agents_data.values() for item in sublist]
+    ).sort_values(by=["agent_id", "scene_ts"])
+    scene_agents_data_df = fix_headings(scene_agents_data_df)
+    scene_agents_data_df = fill_missing_timestamps(
+        scene_agents_data_df, agent_id_to_time_range
+    )
+    scene_agents_data_df = (
+        scene_agents_data_df.groupby("agent_id", group_keys=True)
+        .apply(lambda group: group.interpolate(limit_area="inside"))
+        .reset_index(drop=True)
+        .set_index(["agent_id", "scene_ts"])
+    )
+
+    for agent_id in agent_id_to_type.keys():
+        agent_list.append(
+            AgentMetadata(
+                name=agent_id,
+                agent_type=agent_id_to_type[agent_id],
+                first_timestep=agent_id_to_time_range[agent_id][0],
+                last_timestep=agent_id_to_time_range[agent_id][1],
+                extent=VariableExtent(),
+            )
+        )
+        for ts in range(
+            agent_id_to_time_range[agent_id][0], agent_id_to_time_range[agent_id][1] + 1
+        ):
+            agent_presence[ts].append(
+                AgentMetadata(
+                    name=agent_id,
+                    agent_type=agent_id_to_type[agent_id],
+                    first_timestep=agent_id_to_time_range[agent_id][0],
+                    last_timestep=agent_id_to_time_range[agent_id][1],
+                    extent=VariableExtent(),
+                )
+            )
+
+    return scene_agents_data_df, agent_list, agent_presence
diff --git a/src/trajdata/utils/env_utils.py b/src/trajdata/utils/env_utils.py
index 4726537..bee3ef6 100644
--- a/src/trajdata/utils/env_utils.py
+++ b/src/trajdata/utils/env_utils.py
@@ -41,6 +41,13 @@
     # with the "trajdata[waymo]" option.
     pass
 
+try:
+    from trajdata.dataset_specific.yandex_shifts import YandexShiftsDataset
+except ModuleNotFoundError:
+    # This can happen if the user did not install trajdata
+    # with the "trajdata[ysdc]" option.
+    pass
+
 
 def get_raw_dataset(dataset_name: str, data_dir: str) -> RawDataset:
     if "nusc" in dataset_name:
@@ -65,6 +72,11 @@ def get_raw_dataset(dataset_name: str, data_dir: str) -> RawDataset:
     if "waymo" in dataset_name:
         return WaymoDataset(dataset_name, data_dir, parallelizable=True, has_maps=True)
 
+    if "ysdc" in dataset_name:
+        return YandexShiftsDataset(
+            dataset_name, data_dir, parallelizable=True, has_maps=True
+        )
+
     if "interaction" in dataset_name:
         return InteractionDataset(
             dataset_name, data_dir, parallelizable=True, has_maps=True
diff --git a/src/trajdata/utils/raster_utils.py b/src/trajdata/utils/raster_utils.py
index 120981f..cf47669 100644
--- a/src/trajdata/utils/raster_utils.py
+++ b/src/trajdata/utils/raster_utils.py
@@ -199,22 +199,22 @@ def rasterize_map(
             line_color=(0, 255, 0),
         )
 
-        # # This code helps visualize centerlines to check if the inferred headings are correct.
-        # center_pts = cv2_subpixel(
-        #     transform_points(
-        #         proto_to_np(map_elem.road_lane.center, incl_heading=False),
-        #         raster_from_world,
-        #     )
-        # )[..., :2]
-
-        # # Drawing lane centerlines.
-        # cv2.polylines(
-        #     img=lane_line_img,
-        #     pts=center_pts[None, :, :],
-        #     isClosed=False,
-        #     color=(255, 0, 0),
-        #     **CV2_SUB_VALUES,
-        # )
+        # This code helps visualize centerlines to check if the inferred headings are correct.
+        center_pts = cv2_subpixel(
+            map_utils.transform_points(
+                map_elem.center.xyz,
+                raster_from_world,
+            )
+        )[..., :2]
+
+        # Drawing lane centerlines.
+        cv2.polylines(
+            img=lane_line_img,
+            pts=center_pts[None, :, :],
+            isClosed=False,
+            color=(255, 0, 0),
+            **CV2_SUB_VALUES,
+        )
 
         # headings = np.asarray(map_elem.road_lane.center.h_rad)
         # delta = cv2_subpixel(30*np.array([np.cos(headings[0]), np.sin(headings[0])]))
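As a sanity check on the on-disk layout that `get_scene_path` assumes (1000 scenes per zero-padded shard directory), a single scene proto can be inspected directly with the `ysdc_dataset_api` protobuf bindings. A minimal sketch, with a placeholder path:

```python
from ysdc_dataset_api.proto import Scene

# Read one serialized scene, mirroring read_scene_from_original_proto above.
with open("/path/to/ysdc/train/000/000000.pb", "rb") as f:
    scene = Scene()
    scene.ParseFromString(f.read())

# Past + future vehicle-track frame counts should sum to the YSDC_LENGTH (50)
# used above, at the dataset's 5 Hz (0.2s) timestep.
print(
    scene.id,
    len(scene.past_vehicle_tracks),
    len(scene.future_vehicle_tracks),
    scene.scene_tags.precipitation,
)
```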
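Finally, the `fetch_*_info` helpers in `yandex_shifts_utils.py` read the scene tags stashed on `batch_element.track_info`, which is what the new `track_info` attribute on the batch elements exists for. Assuming trajdata's standard `extras` callback mechanism (each callable receives a batch element and its result is collated into `batch.extras`), they could be wired in like this — a sketch, not part of this diff:

```python
from trajdata import UnifiedDataset
from trajdata.dataset_specific.yandex_shifts.yandex_shifts_utils import (
    fetch_day_time_info,
    fetch_precipitation_info,
    fetch_season_info,
)

# Scene tags (season, time of day, precipitation, ...) become batch extras,
# e.g. for building the distribution-shift partitions Shifts was designed for.
dataset = UnifiedDataset(
    desired_data=["ysdc_train"],
    data_dirs={"ysdc_train": "/path/to/ysdc/train"},
    extras={
        "season": fetch_season_info,
        "day_time": fetch_day_time_info,
        "precipitation": fetch_precipitation_info,
    },
)
```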