Adding Yandex Shifts Motion Dataset #33

Open · wants to merge 4 commits into `main`
26 changes: 26 additions & 0 deletions DATASETS.md
@@ -76,6 +76,32 @@ It should look like this after downloading:

**Note**: Not all dataset parts need to be downloaded; only the necessary directories in [the Google Cloud Bucket](https://console.cloud.google.com/storage/browser/waymo_open_dataset_motion_v_1_1_0/uncompressed/scenario) are required (e.g., `validation` for the validation dataset).

## Yandex Shifts Motion Prediction Dataset
Nothing special needs to be done for the Yandex Shifts Motion Prediction Dataset, simply download it as per [the instructions on the dataset website](https://github.com/Shifts-Project/shifts#motion-prediction-1).

It should look like this after downloading:
```
/path/to/ysdc/
├── train/
| ├── 000
| | ├── 000000.pb
| | └── ...
| └── ...
├── development/
| ├── 000
| | ├── 000000.pb
| | └── ...
| └── ...
└── eval/
├── 000
| ├── 000000.pb
| └── ...
└── ...
```

**Note**: You may also download the complete, unpartitioned dataset. The dataset also contains prerendered examples, which `trajdata` does not require.
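As a quick sanity check after downloading, you can stream a few scenes with the same `ysdc-dataset-api` helpers this PR uses (`get_file_paths` and `scenes_generator`). A minimal sketch, assuming the layout above and that `get_file_paths` accepts a split subdirectory as well as the dataset root:

```python
from ysdc_dataset_api.utils import get_file_paths, scenes_generator

# Point at a single split, e.g. the training partition.
file_paths = get_file_paths("/path/to/ysdc/train")

# Scenes are streamed lazily from the protobuf files; peek at the first one.
scene = next(scenes_generator(file_paths))
print(scene.id, scene.scene_tags.season, len(scene.past_vehicle_tracks))
```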

## Lyft Level 5
Nothing special needs to be done for the Lyft Level 5 dataset, simply download it as per [the instructions on the dataset website](https://woven-planet.github.io/l5kit/dataset.html).

8 changes: 7 additions & 1 deletion README.md
@@ -24,11 +24,14 @@ pip install "trajdata[lyft]"
# For Waymo
pip install "trajdata[waymo]"

# For Yandex Shifts Motion Dataset
pip install "trajdata[ysdc]"

# For INTERACTION
pip install "trajdata[interaction]"

# All
pip install "trajdata[nusc,lyft,waymo,interaction]"
pip install "trajdata[nusc,lyft,waymo,interaction,ysdc]"
```
Then, download the raw datasets (nuScenes, Lyft Level 5, ETH/UCY, etc.) in case you do not already have them. For more information about how to structure dataset folders/files, please see [`DATASETS.md`](./DATASETS.md).
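For example, once the `ysdc` extra is installed and the dataset is downloaded, the new splits can be requested like any other environment. A minimal sketch, assuming a `/path/to/ysdc` download location (`UnifiedDataset` with `desired_data`/`data_dirs` is the existing `trajdata` entry point):

```python
from trajdata import UnifiedDataset

# Agent-centric dataset over the Yandex Shifts training split.
dataset = UnifiedDataset(
    desired_data=["ysdc_train"],
    data_dirs={"ysdc_train": "/path/to/ysdc"},
)
print(f"Total data samples: {len(dataset):,}")
```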

@@ -99,6 +102,9 @@ Currently, the dataloader supports interfacing with the following datasets:
| Waymo Open Motion Training | `waymo_train` | `train` | N/A | Waymo Open Motion Dataset `training` split | 0.1s (10Hz) | :white_check_mark: |
| Waymo Open Motion Validation | `waymo_val` | `val` | N/A | Waymo Open Motion Dataset `validation` split | 0.1s (10Hz) | :white_check_mark: |
| Waymo Open Motion Testing | `waymo_test` | `test` | N/A | Waymo Open Motion Dataset `testing` split | 0.1s (10Hz) | :white_check_mark: |
| Yandex Shifts Motion Dataset Training | `ysdc_train` | `train` | N/A | Yandex Shifts Motion Dataset `training` split | 0.2s (5Hz) | :white_check_mark: |
| Yandex Shifts Motion Dataset Development | `ysdc_development` | `development` | N/A | Yandex Shifts Motion Dataset `development` split | 0.2s (5Hz) | :white_check_mark: |
| Yandex Shifts Motion Dataset Evaluation | `ysdc_eval` | `eval` | N/A | Yandex Shifts Motion Dataset `eval` split | 0.2s (5Hz) | :white_check_mark: |
| Lyft Level 5 Train | `lyft_train` | `train` | `palo_alto` | Lyft Level 5 training data - part 1/2 (8.4 GB) | 0.1s (10Hz) | :white_check_mark: |
| Lyft Level 5 Train Full | `lyft_train_full` | `train` | `palo_alto` | Lyft Level 5 training data - part 2/2 (70 GB) | 0.1s (10Hz) | :white_check_mark: |
| Lyft Level 5 Validation | `lyft_val` | `val` | `palo_alto` | Lyft Level 5 validation data (8.2 GB) | 0.1s (10Hz) | :white_check_mark: |
1 change: 1 addition & 0 deletions pyproject.toml
@@ -42,6 +42,7 @@ interaction = ["lanelet2==1.2.1"]
lyft = ["l5kit==1.5.0"]
nusc = ["nuscenes-devkit==1.1.9"]
waymo = ["tensorflow==2.11.0", "waymo-open-dataset-tf-2-11-0", "intervaltree"]
ysdc = ["ysdc-dataset-api @ git+https://github.com/yandex-research/shifts.git#subdirectory=sdc"]

[project.urls]
"Homepage" = "https://github.com/nvr-avg/trajdata"
7 changes: 5 additions & 2 deletions src/trajdata/augmentation/noise_histories.py
@@ -23,8 +23,11 @@ def apply_agent(self, agent_batch: AgentBatch) -> None:
)

        if agent_batch.history_pad_dir == PadDirection.BEFORE:
            # With PadDirection.BEFORE the current timestep sits at the end of
            # the history, so keep it noise-free. Zero-length tensors (e.g., a
            # batch with no neighbors) raise IndexError on the -1 index.
            try:
                agent_hist_noise[..., -1, :] = 0
                neigh_hist_noise[..., -1, :] = 0
            except IndexError:
                pass
        else:
            len_mask = ~mask_up_to(
                agent_batch.agent_hist_len,
3 changes: 2 additions & 1 deletion src/trajdata/data_structures/batch_element.py
@@ -39,6 +39,7 @@ def __init__(
self.cache: SceneCache = cache
self.data_index: int = data_index
self.dt: float = scene_time_agent.scene.dt
        # Per-scene metadata from the raw dataset (for YSDC, the scene-tags dict).
        self.track_info = scene_time_agent.scene.data_access_info
self.scene_ts: int = scene_time_agent.ts
self.history_sec = history_sec
self.future_sec = future_sec
@@ -341,6 +342,7 @@ def __init__(
self.data_index = data_index
self.dt: float = scene_time.scene.dt
self.scene_ts: int = scene_time.ts
        # Per-scene metadata from the raw dataset, as in the agent-centric case.
        self.track_info = scene_time.scene.data_access_info

if max_agent_num is not None:
scene_time.agents = scene_time.agents[:max_agent_num]
@@ -506,7 +508,6 @@ def get_agents_future(
future_sec: Tuple[Optional[float], Optional[float]],
nearby_agents: List[AgentMetadata],
) -> Tuple[List[StateArray], List[np.ndarray], np.ndarray]:

(
agent_futures,
agent_future_extents,
11 changes: 11 additions & 0 deletions src/trajdata/dataset_specific/scene_records.py
@@ -48,3 +48,14 @@ class NuPlanSceneRecord(NamedTuple):
split: str
# desc: str
data_idx: int


class YandexShiftsSceneRecord(NamedTuple):
    """Caching record for a Yandex Shifts scene, including its scene tags."""

name: str
length: str
data_idx: int
day_time: str
season: str
track: str
sun_phase: str
precipitation: str
1 change: 1 addition & 0 deletions src/trajdata/dataset_specific/yandex_shifts/__init__.py
@@ -0,0 +1 @@
from .yandex_shifts_dataset import YandexShiftsDataset
213 changes: 213 additions & 0 deletions src/trajdata/dataset_specific/yandex_shifts/yandex_shifts_dataset.py
@@ -0,0 +1,213 @@
from collections import defaultdict
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type

import pandas as pd
import tqdm

from ysdc_dataset_api.utils import get_file_paths, scenes_generator
from ysdc_dataset_api.proto import Scene as YSDCScene
from trajdata.caching import EnvCache, SceneCache
from trajdata.data_structures import EnvMetadata, Scene, SceneMetadata, SceneTag
from trajdata.data_structures.agent import AgentMetadata
from trajdata.dataset_specific.raw_dataset import RawDataset
from trajdata.dataset_specific.scene_records import YandexShiftsSceneRecord
from trajdata.dataset_specific.yandex_shifts import yandex_shifts_utils
from trajdata.maps import VectorMap
from trajdata.utils.parallel_utils import parallel_apply
from trajdata.dataset_specific.yandex_shifts.yandex_shifts_utils import (
read_scene_from_original_proto,
get_scene_path,
extract_vectorized,
extract_traffic_light_status,
extract_agent_data_from_ysdc_scene,
)


def const_lambda(const_val: Any) -> Any:
    # Module-level helper for use with functools.partial: gives the
    # scene_split_map defaultdicts a constant default value while remaining
    # picklable (a plain lambda would not be).
    return const_val


class YandexShiftsDataset(RawDataset):
def compute_metadata(self, env_name: str, data_dir: str) -> EnvMetadata:
if env_name == "ysdc_train":
dataset_parts = [("train",)]
scene_split_map = defaultdict(partial(const_lambda, const_val="train"))

elif env_name == "ysdc_development":
dataset_parts = [("development",)]
scene_split_map = defaultdict(
partial(const_lambda, const_val="development")
)

elif env_name == "ysdc_eval":
dataset_parts = [("eval",)]
scene_split_map = defaultdict(partial(const_lambda, const_val="eval"))

        elif env_name == "ysdc_full":
            dataset_parts = [("full",)]
            scene_split_map = defaultdict(partial(const_lambda, const_val="full"))

        else:
            # Guard against falling through with the locals below undefined.
            raise ValueError(f"Unknown Yandex Shifts environment name: {env_name}")

return EnvMetadata(
name=env_name,
data_dir=data_dir,
dt=yandex_shifts_utils.YSDC_DT,
parts=dataset_parts,
scene_split_map=scene_split_map,
)

def load_dataset_obj(self, verbose: bool = False) -> None:
if verbose:
print(f"Loading {self.name} dataset...", flush=True)
self.dataset_obj = scenes_generator(get_file_paths(self.metadata.data_dir))

def _get_matching_scenes_from_obj(
self,
scene_tag: SceneTag,
scene_desc_contains: Optional[List[str]],
env_cache: EnvCache,
) -> List[SceneMetadata]:
all_scenes_list: List[YandexShiftsSceneRecord] = list()
scenes_list: List[SceneMetadata] = list()
for idx, scene in tqdm.tqdm(
enumerate(self.dataset_obj), desc="Processing scenes from proto files"
):
scene_name: str = scene.id
scene_split: str = self.metadata.scene_split_map[scene_name]
scene_length: int = yandex_shifts_utils.YSDC_LENGTH
# Saving all scene records for later caching.
all_scenes_list.append(
YandexShiftsSceneRecord(
scene_name,
str(scene_length),
idx,
scene.scene_tags.day_time,
scene.scene_tags.season,
scene.scene_tags.track,
scene.scene_tags.sun_phase,
scene.scene_tags.precipitation,
)
)
if scene_split in scene_tag and scene_desc_contains is None:
scene_metadata = SceneMetadata(
env_name=self.metadata.name,
name=scene_name,
dt=self.metadata.dt,
raw_data_idx=idx,
)
scenes_list.append(scene_metadata)
self.cache_all_scenes_list(env_cache, all_scenes_list)
return scenes_list

def _get_matching_scenes_from_cache(
self,
scene_tag: SceneTag,
scene_desc_contains: Optional[List[str]],
env_cache: EnvCache,
) -> List[Scene]:
all_scenes_list: List[YandexShiftsSceneRecord] = env_cache.load_env_scenes_list(
self.name
)
        scenes_list: List[Scene] = list()
for scene_record in all_scenes_list:
scene_split: str = self.metadata.scene_split_map[scene_record.name]
if scene_split in scene_tag and scene_desc_contains is None:
scene_metadata = Scene(
self.metadata,
scene_record.name,
scene_record.data_idx,
scene_split,
scene_record.length,
scene_record.data_idx,
None, # This isn't used if everything is already cached.
)
scenes_list.append(scene_metadata)
return scenes_list

def get_scene(self, scene_info: SceneMetadata) -> Scene:
_, _, _, data_idx = scene_info
scene_data_from_proto: YSDCScene = read_scene_from_original_proto(
get_scene_path(self.metadata.data_dir, data_idx)
)
num_history_timestamps = len(scene_data_from_proto.past_vehicle_tracks)
num_future_timestamps = len(scene_data_from_proto.future_vehicle_tracks)
scene_name: str = scene_data_from_proto.id
scene_split: str = self.metadata.scene_split_map[scene_name]
scene_length: int = num_history_timestamps + num_future_timestamps
return Scene(
self.metadata,
scene_data_from_proto.id,
data_idx,
scene_split,
scene_length,
data_idx,
{
"day_time": scene_data_from_proto.scene_tags.day_time,
"season": scene_data_from_proto.scene_tags.season,
"track_location": scene_data_from_proto.scene_tags.track,
"sun_phase": scene_data_from_proto.scene_tags.sun_phase,
"precipitation": scene_data_from_proto.scene_tags.precipitation,
},
)
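        # The scene-tags dict above becomes the Scene's data_access_info, which
        # batch_element.py surfaces to users as `track_info`.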

def get_agent_info(
self, scene: Scene, cache_path: Path, cache_class: Type[SceneCache]
) -> Tuple[List[AgentMetadata], List[List[AgentMetadata]]]:
scene_data_from_proto = read_scene_from_original_proto(
get_scene_path(self.metadata.data_dir, scene.raw_data_idx)
)
(
scene_agents_data_df,
agent_list,
agent_presence,
) = extract_agent_data_from_ysdc_scene(scene_data_from_proto, scene)
cache_class.save_agent_data(scene_agents_data_df, cache_path, scene)
tls_dict = extract_traffic_light_status(scene_data_from_proto)
tls_df = pd.DataFrame(
tls_dict.values(),
index=pd.MultiIndex.from_tuples(
tls_dict.keys(), names=["lane_id", "scene_ts"]
),
columns=["status"],
)
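        # Shape illustration (hypothetical lane ids and statuses): tls_dict maps
        # (lane_id, scene_ts) tuples to a status value, so tls_df looks like
        #
        #     lane_id  scene_ts | status
        #     100      0        | <status at t=0>
        #     100      1        | <status at t=1>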
cache_class.save_traffic_light_data(tls_df, cache_path, scene)
return agent_list, agent_presence

def cache_map(
self,
data_idx: int,
cache_path: Path,
map_cache_class: Type[SceneCache],
map_params: Dict[str, Any],
):
scene_data_from_proto = read_scene_from_original_proto(
get_scene_path(self.metadata.data_dir, data_idx)
)
vector_map: VectorMap = extract_vectorized(
scene_data_from_proto.path_graph, map_name=f"{self.name}:{data_idx}"
)
map_cache_class.finalize_and_cache_map(cache_path, vector_map, map_params)

def cache_maps(
self,
cache_path: Path,
map_cache_class: Type[SceneCache],
map_params: Dict[str, Any],
) -> None:
num_workers: int = map_params.get("num_workers", 0)
if num_workers > 1:
parallel_apply(
partial(
self.cache_map,
cache_path=cache_path,
map_cache_class=map_cache_class,
map_params=map_params,
),
range(len(get_file_paths(self.metadata.data_dir))),
num_workers=num_workers,
)

else:
for i in tqdm.trange(len(get_file_paths(self.metadata.data_dir))):
self.cache_map(i, cache_path, map_cache_class, map_params)