diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..e6f2ec5
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,18 @@
+__pycache__/
+**/__pycache__
+*.py[cod]
+*$py.class
+
+logs/
+outputs/
+wandb/
+*.pyc
+
+.vscode/settings.json
+carla_gym/envs/__pycache__/__init__.cpython-38.pyc
+
+.idea/
+.ipynb_checkpoints/
+*.ipynb
+tensorboard_logs/
+env/
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..7c97348
--- /dev/null
+++ b/README.md
@@ -0,0 +1,110 @@
+# MUVO
+This is the PyTorch implementation for the paper
+> MUVO: A Multimodal World Model with Spatial Representations for Autonomous Driving.
+
+## Requirements
+The simplest way to install all required dependencies is to create
+a [conda](https://docs.conda.io/projects/miniconda/en/latest/) environment by running
+```
+conda env create -f conda_env.yml
+```
+Then activate the conda environment with
+```
+conda activate muvo
+```
+or create your own venv and install the requirements by running
+```
+pip install -r requirements.txt
+```
+
+
+## Dataset
+Use [CARLA](http://carla.org/) to collect data.
+First install CARLA by following its [documentation](https://carla.readthedocs.io/en/latest/).
+
+### Dataset Collection
+Change the settings in `config/`,
+then run `bash run/data_collect.sh ${PORT}`,
+where `${PORT}` is the port CARLA runs on (usually `2000`).
+The data collection code is adapted from
+[CARLA-Roach](https://github.com/zhejz/carla-roach) and [MILE](https://github.com/wayveai/mile);
+refer to those projects for some of the config settings.
+
+### Voxelization
+After collecting data with CARLA, generate voxel data by running `data/generate_voxels.py`;
+the voxel settings can be changed in `data_preprocess.yaml`.
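+
+For intuition, below is a minimal sketch of the kind of voxelization step such a script performs (illustrative only; the grid parameters are made up, and the actual `data/generate_voxels.py` and the field names in `data_preprocess.yaml` may differ):
+```
+import numpy as np
+
+def voxelize(points, grid_size=(192, 192, 64), voxel_size=0.2, origin=(-19.2, -19.2, -3.2)):
+    """Map an (N, 3) point cloud to a binary occupancy grid."""
+    idx = np.floor((points - np.asarray(origin)) / voxel_size).astype(np.int64)
+    # keep only points that fall inside the grid
+    valid = np.all((idx >= 0) & (idx < np.asarray(grid_size)), axis=1)
+    voxels = np.zeros(grid_size, dtype=np.uint8)
+    voxels[tuple(idx[valid].T)] = 1
+    return voxels
+
+# usage (hypothetical file name): voxels = voxelize(np.load('points_000000000.npy'))
+```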
+
+### Folder Structure
+After completing the above steps, or otherwise obtaining the dataset,
+arrange the dataset files into the structure described below.
+
+The main branch includes most of the results presented in the paper. In the 2D branch, you can find 2D latent states, perceptual losses, and a new transformer backbone. The data is organized in the following format:
+```
+/carla_dataset/trainval/
+ ├── train/
+ │ ├── Town01/
+ │ │ ├── 0000/
+ │ │ │ ├── birdview/
+ │ │ │ │ ├ birdview_000000000.png
+ │ │ │ │ .
+ │ │ │ ├── depth_semantic/
+ │ │ │ │ ├ depth_semantic_000000000.png
+ │ │ │ │ .
+ │ │ │ ├── image/
+ │ │ │ │ ├ image_000000000.png
+ │ │ │ │ .
+ │ │ │ ├── points/
+ │ │ │ │ ├ points_000000000.png
+ │ │ │ │ .
+ │ │ │ ├── points_semantic/
+ │ │ │ │ ├ points_semantic_000000000.png
+ │ │ │ │ .
+ │ │ │ ├── routemap/
+ │ │ │ │ ├ routemap_000000000.png
+ │ │ │ │ .
+ │ │ │ ├── voxel/
+ │ │ │ │ ├ voxel_000000000.png
+ │ │ │ │ .
+ │ │ │ └── pd_dataframe.pkl
+ │ │ ├── 0001/
+ │ │ ├── 0002/
+ │ │ .
+ │ │ └── 0024/
+ │ ├── Town03/
+ │ ├── Town04/
+ │ .
+ │ └── Town06/
+ ├── val0/
+ .
+ └── val1/
+```
+
+## Training
+Run
+```
+python train.py --config-file muvo/configs/your_config.yml
+```
+You can use the default config file `muvo/configs/muvo.yml` or create your own config file in `muvo/configs/`.
+In the config file (`*.yml`), you can set all the options listed in `muvo/config.py`.
+Before training, make sure that the required input/output data as well as the model structure/dimensions are correctly set in `muvo/configs/your_config.yml`.
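+
+As a rough illustration of how such a YAML file overrides the defaults defined in `muvo/config.py` (a minimal sketch assuming a yacs-style `CfgNode`; the option names below are placeholders, not the real ones):
+```
+from yacs.config import CfgNode as CN
+
+_C = CN()                 # hypothetical defaults, standing in for muvo/config.py
+_C.BATCHSIZE = 8
+_C.RECEPTIVE_FIELD = 6
+
+def get_cfg(config_file):
+    cfg = _C.clone()
+    cfg.merge_from_file(config_file)  # values in your_config.yml override the defaults
+    cfg.freeze()
+    return cfg
+
+# cfg = get_cfg('muvo/configs/your_config.yml')
+```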
+
+## Test
+
+### Weights
+
+We provide weights for pre-trained models, each trained for around 100,000 steps: [weights](https://github.com/daniel-bogdoll/MUVO/releases/tag/1.0) is for a 1D latent space, [weights_2D](https://github.com/daniel-bogdoll/MUVO/releases/tag/2.0) for a 2D latent space. We provide a config file for each:
+- 'basic_voxel' in [weights_2D](https://github.com/daniel-bogdoll/MUVO/releases/tag/2.0) is the basic 2D latent space model: ResNet-18 backbone, no BEV mapping for image features, a range-view representation for the point cloud, and a transformer to fuse features. The corresponding config file is '[test_base_2d.yml](https://github.com/daniel-bogdoll/MUVO/blob/main/muvo/configs/test_base_2d.yml)'.
+- 'mobilevit' only changes the backbone compared to the 'basic_voxel' weights; the corresponding config file is '[test_mobilevit_2d.yml](https://github.com/daniel-bogdoll/MUVO/blob/main/muvo/configs/test_mobilevit_2d.yml)'.
+- 'RV_WOB_TR_1d_Voxel' and 'RV_WOB_TR_1d_no_Voxel' in [weights](https://github.com/daniel-bogdoll/MUVO/releases/tag/1.0) use the basic setting but with a 1D latent space; the corresponding config files are '[test_base_1d.yml](https://github.com/daniel-bogdoll/MUVO/blob/main/muvo/configs/test_base_1d.yml)' and '[test_base_1d_without_voxel.yml](https://github.com/daniel-bogdoll/MUVO/blob/main/muvo/configs/test_base_1d_without_voxel.yml)'.
+
+### Execute
+Run
+```
+python prediction.py --config-file muvo/configs/test.yml
+```
+The config file is structured the same way as in training.\
+In `muvo/data/dataset.py` (class `DataModule`, method `setup`), you can change the test dataset/sampler type.
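+
+For orientation, a minimal sketch of what switching the test dataset/sampler in such a `setup` method can look like (assumed structure; the real `DataModule` in `muvo/data/dataset.py` differs):
+```
+import pytorch_lightning as pl
+from torch.utils.data import DataLoader, SequentialSampler
+
+class DataModule(pl.LightningDataModule):
+    def __init__(self, cfg, dataset_cls):
+        super().__init__()
+        self.cfg = cfg
+        self.dataset_cls = dataset_cls  # hypothetical dataset class
+
+    def setup(self, stage=None):
+        # choose the split and the sampler used for testing here
+        self.test_dataset = self.dataset_cls(self.cfg, split='val0')
+        self.test_sampler = SequentialSampler(self.test_dataset)
+
+    def test_dataloader(self):
+        return DataLoader(self.test_dataset, batch_size=1, sampler=self.test_sampler)
+```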
+
+## Related Projects
+Our code is based on [MILE](https://github.com/wayveai/mile).
+Thanks also to [CARLA-Roach](https://github.com/zhejz/carla-roach) for providing a gym wrapper around CARLA.
diff --git a/carla_env.yml b/carla_env.yml
new file mode 100644
index 0000000..859d469
--- /dev/null
+++ b/carla_env.yml
@@ -0,0 +1,301 @@
+name: muvo
+channels:
+ - pytorch
+ - nvidia
+ - defaults
+dependencies:
+ - _libgcc_mutex=0.1=main
+ - _openmp_mutex=5.1=1_gnu
+ - blas=1.0=mkl
+ - blosc=1.21.3=h6a678d5_0
+ - brotli=1.0.9=h5eee18b_7
+ - brotli-bin=1.0.9=h5eee18b_7
+ - brotlipy=0.7.0=py38h27cfd23_1003
+ - brunsli=0.1=h2531618_0
+ - bzip2=1.0.8=h7b6447c_0
+ - c-ares=1.19.0=h5eee18b_0
+ - ca-certificates=2023.01.10=h06a4308_0
+ - certifi=2022.12.7=py38h06a4308_0
+ - cffi=1.15.1=py38h5eee18b_3
+ - cfitsio=3.470=h5893167_7
+ - charls=2.2.0=h2531618_0
+ - charset-normalizer=2.0.4=pyhd3eb1b0_0
+ - cryptography=39.0.1=py38h9ce1e76_0
+ - cuda-cudart=11.8.89=0
+ - cuda-cupti=11.8.87=0
+ - cuda-libraries=11.8.0=0
+ - cuda-nvrtc=11.8.89=0
+ - cuda-nvtx=11.8.86=0
+ - cuda-runtime=11.8.0=0
+ - cytoolz=0.12.0=py38h5eee18b_0
+ - dask-core=2023.3.2=py38h06a4308_0
+ - ffmpeg=4.3=hf484d3e_0
+ - filelock=3.9.0=py38h06a4308_0
+ - flit-core=3.8.0=py38h06a4308_0
+ - freetype=2.12.1=h4a9f257_0
+ - giflib=5.2.1=h5eee18b_3
+ - gmp=6.2.1=h295c915_3
+ - gmpy2=2.1.2=py38heeb90bb_0
+ - gnutls=3.6.15=he1e5248_0
+ - idna=3.4=py38h06a4308_0
+ - imagecodecs=2021.8.26=py38hfcb8610_2
+ - intel-openmp=2021.4.0=h06a4308_3561
+ - jinja2=3.1.2=py38h06a4308_0
+ - jpeg=9e=h5eee18b_1
+ - jxrlib=1.1=h7b6447c_2
+ - krb5=1.19.4=h568e23c_0
+ - lame=3.100=h7b6447c_0
+ - lcms2=2.12=h3be6417_0
+ - ld_impl_linux-64=2.38=h1181459_1
+ - lerc=3.0=h295c915_0
+ - libaec=1.0.4=he6710b0_1
+ - libbrotlicommon=1.0.9=h5eee18b_7
+ - libbrotlidec=1.0.9=h5eee18b_7
+ - libbrotlienc=1.0.9=h5eee18b_7
+ - libcublas=11.11.3.6=0
+ - libcufft=10.9.0.58=0
+ - libcufile=1.6.0.25=0
+ - libcurand=10.3.2.56=0
+ - libcurl=7.88.1=h91b91d3_0
+ - libcusolver=11.4.1.48=0
+ - libcusparse=11.7.5.86=0
+ - libdeflate=1.17=h5eee18b_0
+ - libedit=3.1.20221030=h5eee18b_0
+ - libev=4.33=h7f8727e_1
+ - libffi=3.4.2=h6a678d5_6
+ - libgcc-ng=11.2.0=h1234567_1
+ - libgfortran-ng=11.2.0=h00389a5_1
+ - libgfortran5=11.2.0=h1234567_1
+ - libgomp=11.2.0=h1234567_1
+ - libiconv=1.16=h7f8727e_2
+ - libidn2=2.3.2=h7f8727e_0
+ - libnghttp2=1.46.0=hce63b2e_0
+ - libnpp=11.8.0.86=0
+ - libnvjpeg=11.9.0.86=0
+ - libpng=1.6.39=h5eee18b_0
+ - libssh2=1.10.0=h8f2d780_0
+ - libstdcxx-ng=11.2.0=h1234567_1
+ - libtasn1=4.19.0=h5eee18b_0
+ - libtiff=4.5.0=h6a678d5_2
+ - libunistring=0.9.10=h27cfd23_0
+ - libwebp=1.2.4=h11a3e52_1
+ - libwebp-base=1.2.4=h5eee18b_1
+ - libzopfli=1.0.3=he6710b0_0
+ - locket=1.0.0=py38h06a4308_0
+ - lz4-c=1.9.4=h6a678d5_0
+ - markupsafe=2.1.1=py38h7f8727e_0
+ - mkl=2021.4.0=h06a4308_640
+ - mkl-service=2.4.0=py38h7f8727e_0
+ - mkl_fft=1.3.1=py38hd3c417c_0
+ - mkl_random=1.2.2=py38h51133e4_0
+ - mpc=1.1.0=h10f8cd9_1
+ - mpfr=4.0.2=hb69a4c5_1
+ - mpmath=1.2.1=py38h06a4308_0
+ - ncurses=6.4=h6a678d5_0
+ - nettle=3.7.3=hbbd107a_1
+ - networkx=2.8.4=py38h06a4308_1
+ - numpy=1.23.5=py38h14f4228_0
+ - numpy-base=1.23.5=py38h31eccc5_0
+ - openh264=2.1.1=h4ff587b_0
+ - openjpeg=2.4.0=h3ad879b_0
+ - openssl=1.1.1t=h7f8727e_0
+ - packaging=23.0=py38h06a4308_0
+ - partd=1.2.0=pyhd3eb1b0_1
+ - pillow=9.4.0=py38h6a678d5_0
+ - pip=23.0.1=py38h06a4308_0
+ - pooch=1.4.0=pyhd3eb1b0_0
+ - pycparser=2.21=pyhd3eb1b0_0
+ - pyopenssl=23.0.0=py38h06a4308_0
+ - pysocks=1.7.1=py38h06a4308_0
+ - python=3.8.16=h7a1cb2a_3
+ - pytorch=2.0.0=py3.8_cuda11.8_cudnn8.7.0_0
+ - pytorch-cuda=11.8=h7e8668a_3
+ - pytorch-mutex=1.0=cuda
+ - pywavelets=1.4.1=py38h5eee18b_0
+ - pyyaml=6.0=py38h5eee18b_1
+ - readline=8.2=h5eee18b_0
+ - requests=2.28.1=py38h06a4308_1
+ - scikit-image=0.19.3=py38h6a678d5_1
+ - setuptools=65.6.3=py38h06a4308_0
+ - six=1.16.0=pyhd3eb1b0_1
+ - snappy=1.1.9=h295c915_0
+ - sqlite=3.41.1=h5eee18b_0
+ - sympy=1.11.1=py38h06a4308_0
+ - tifffile=2021.7.2=pyhd3eb1b0_2
+ - tk=8.6.12=h1ccaba5_0
+ - toolz=0.12.0=py38h06a4308_0
+ - torchaudio=2.0.0=py38_cu118
+ - torchtriton=2.0.0=py38
+ - torchvision=0.15.0=py38_cu118
+ - typing_extensions=4.4.0=py38h06a4308_0
+ - urllib3=1.26.15=py38h06a4308_0
+ - wheel=0.38.4=py38h06a4308_0
+ - xz=5.2.10=h5eee18b_1
+ - yaml=0.2.5=h7b6447c_0
+ - zfp=0.5.5=h295c915_6
+ - zlib=1.2.13=h5eee18b_0
+ - zstd=1.5.4=hc292b87_0
+ - pip:
+ - absl-py==1.4.0
+ - addict==2.4.0
+ - aiohttp==3.8.4
+ - aiosignal==1.3.1
+ - antlr4-python3-runtime==4.9.3
+ - anyio==3.6.2
+ - appdirs==1.4.4
+ - arrow==1.2.3
+ - asttokens==2.2.1
+ - async-timeout==4.0.2
+ - attrs==22.2.0
+ - backcall==0.2.0
+ - beautifulsoup4==4.12.2
+ - blessed==1.20.0
+ - boto3==1.26.124
+ - botocore==1.29.124
+ - cachetools==5.3.0
+ - carla==0.9.13
+ - clearml==1.10.3
+ - click==8.1.3
+ - cloudpickle==2.2.1
+ - comm==0.1.3
+ - configargparse==1.5.3
+ - contourpy==1.0.7
+ - croniter==1.3.14
+ - cycler==0.11.0
+ - dash==2.9.2
+ - dash-core-components==2.0.0
+ - dash-html-components==2.0.0
+ - dash-table==5.0.0
+ - dateutils==0.6.12
+ - debugpy==1.6.7
+ - decorator==4.4.2
+ - deepdiff==6.3.0
+ - docker-pycreds==0.4.0
+ - executing==1.2.0
+ - fastapi==0.88.0
+ - fastjsonschema==2.16.3
+ - flask==2.2.3
+ - fonttools==4.39.3
+ - frozenlist==1.3.3
+ - fsspec==2023.4.0
+ - furl==2.1.3
+ - fvcore==0.1.5.post20221221
+ - geos==0.2.3
+ - gitdb==4.0.10
+ - gitpython==3.1.31
+ - google-auth==2.17.3
+ - google-auth-oauthlib==1.0.0
+ - grpcio==1.54.0
+ - gym==0.21.0
+ - gym-notices==0.0.8
+ - h11==0.14.0
+ - h5py==3.8.0
+ - huggingface-hub==0.13.4
+ - hydra-core==1.3.2
+ - imageio==2.27.0
+ - imageio-ffmpeg==0.4.8
+ - importlib-metadata==4.13.0
+ - importlib-resources==5.12.0
+ - inquirer==3.1.3
+ - iopath==0.1.10
+ - ipykernel==6.22.0
+ - ipython==8.12.0
+ - ipywidgets==8.0.6
+ - itsdangerous==2.1.2
+ - jedi==0.18.2
+ - jmespath==1.0.1
+ - joblib==1.2.0
+ - jsonschema==4.17.3
+ - jupyter-client==8.1.0
+ - jupyter-core==5.3.0
+ - jupyterlab-widgets==3.0.7
+ - kiwisolver==1.4.4
+ - lightning==2.0.1.post0
+ - lightning-cloud==0.5.33
+ - lightning-utilities==0.8.0
+ - lxml==4.9.2
+ - markdown==3.4.3
+ - markdown-it-py==2.2.0
+ - matplotlib==3.7.1
+ - matplotlib-inline==0.1.6
+ - mdurl==0.1.2
+ - moviepy==1.0.3
+ - multidict==6.0.4
+ - nbformat==5.7.0
+ - nest-asyncio==1.5.6
+ - oauthlib==3.2.2
+ - omegaconf==2.3.0
+ - open3d==0.17.0
+ - opencv-python==4.7.0.72
+ - ordered-set==4.1.0
+ - orderedmultidict==1.0.1
+ - pandas==2.0.0
+ - parso==0.8.3
+ - pathlib2==2.3.7.post1
+ - pathtools==0.1.2
+ - pexpect==4.8.0
+ - pickleshare==0.7.5
+ - pkgutil-resolve-name==1.3.10
+ - platformdirs==3.2.0
+ - plotly==5.14.1
+ - portalocker==2.7.0
+ - proglog==0.1.10
+ - prompt-toolkit==3.0.38
+ - protobuf==4.22.1
+ - psutil==5.9.4
+ - ptyprocess==0.7.0
+ - pure-eval==0.2.2
+ - pyasn1==0.5.0
+ - pyasn1-modules==0.3.0
+ - pydantic==1.10.7
+ - pygments==2.14.0
+ - pyjwt==2.4.0
+ - pyparsing==3.0.9
+ - pyquaternion==0.9.9
+ - pyrsistent==0.19.3
+ - python-dateutil==2.8.2
+ - python-editor==1.0.4
+ - python-multipart==0.0.6
+ - pytorch-lightning==2.0.1.post0
+ - pytz==2023.3
+ - pyzmq==25.0.2
+ - readchar==4.0.5
+ - requests-oauthlib==1.3.1
+ - rich==13.3.4
+ - rsa==4.9
+ - s3transfer==0.6.0
+ - scikit-learn==1.2.2
+ - scipy==1.10.1
+ - sentry-sdk==1.19.1
+ - setproctitle==1.3.2
+ - shapely==2.0.1
+ - smmap==5.0.0
+ - sniffio==1.3.0
+ - soupsieve==2.4.1
+ - stable-baselines3==1.8.0
+ - stack-data==0.6.2
+ - starlette==0.22.0
+ - starsessions==1.3.0
+ - tabulate==0.9.0
+ - tenacity==8.2.2
+ - tensorboard==2.12.2
+ - tensorboard-data-server==0.7.0
+ - tensorboard-plugin-wit==1.8.1
+ - termcolor==2.3.0
+ - threadpoolctl==3.1.0
+ - timm==0.6.13
+ - torchmetrics==0.11.4
+ - tornado==6.2
+ - tqdm==4.65.0
+ - traitlets==5.9.0
+ - tzdata==2023.3
+ - uvicorn==0.21.1
+ - wandb==0.14.2
+ - wcwidth==0.2.6
+ - websocket-client==1.5.1
+ - websockets==11.0.2
+ - werkzeug==2.2.3
+ - widgetsnbextension==4.0.7
+ - yacs==0.1.8
+ - yarl==1.9.1
+ - zipp==3.15.0
diff --git a/carla_gym/__init__.py b/carla_gym/__init__.py
new file mode 100644
index 0000000..56e8404
--- /dev/null
+++ b/carla_gym/__init__.py
@@ -0,0 +1,33 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+from pathlib import Path
+from gym.envs.registration import register
+
+CARLA_GYM_ROOT_DIR = Path(__file__).resolve().parent
+
+# Declare available environments with a brief description
+_AVAILABLE_ENVS = {
+ 'Endless-v0': {
+ 'entry_point': 'carla_gym.envs:EndlessEnv',
+ 'description': 'endless env for rl training and testing',
+ 'kwargs': {}
+ },
+ 'LeaderBoard-v0': {
+ 'entry_point': 'carla_gym.envs:LeaderboardEnv',
+        'description': 'leaderboard route with not-so-dense background traffic',
+ 'kwargs': {}
+ }
+}
+
+
+for env_id, val in _AVAILABLE_ENVS.items():
+ register(id=env_id, entry_point=val.get('entry_point'), kwargs=val.get('kwargs'))
+
+
+def list_available_envs():
+ print('Environment-ID: Short-description')
+ import pprint
+ available_envs = {}
+ for env_id, val in _AVAILABLE_ENVS.items():
+ available_envs[env_id] = val.get('description')
+ pprint.pprint(available_envs)
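+
+
+# Illustrative usage (not part of the original module): importing carla_gym runs the
+# register() calls above, after which an environment can be created via the gym API,
+# e.g. gym.make('Endless-v0', **kwargs), where the keyword arguments are forwarded to
+# the entry point class (e.g. carla_gym.envs.EndlessEnv).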
diff --git a/carla_gym/carla_multi_agent_env.py b/carla_gym/carla_multi_agent_env.py
new file mode 100644
index 0000000..d69c93a
--- /dev/null
+++ b/carla_gym/carla_multi_agent_env.py
@@ -0,0 +1,214 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import logging
+import gym
+import numpy as np
+import carla
+
+from .core.zombie_walker.zombie_walker_handler import ZombieWalkerHandler
+from .core.zombie_vehicle.zombie_vehicle_handler import ZombieVehicleHandler
+from .core.obs_manager.obs_manager_handler import ObsManagerHandler
+from .core.task_actor.ego_vehicle.ego_vehicle_handler import EgoVehicleHandler
+from .core.task_actor.scenario_actor.scenario_actor_handler import ScenarioActorHandler
+from .utils.traffic_light import TrafficLightHandler
+from .utils.dynamic_weather import WeatherHandler
+from stable_baselines3.common.utils import set_random_seed
+from constants import CARLA_FPS
+
+logger = logging.getLogger(__name__)
+
+
+class CarlaMultiAgentEnv(gym.Env):
+ def __init__(self, carla_map, host, port, seed, no_rendering,
+ obs_configs, reward_configs, terminal_configs, all_tasks):
+ self._all_tasks = all_tasks
+ self._obs_configs = obs_configs
+ self._carla_map = carla_map
+ self._seed = seed
+
+ self.name = self.__class__.__name__
+
+ self._init_client(carla_map, host, port, seed=seed, no_rendering=no_rendering)
+
+ # define observation spaces exposed to agent
+ self._om_handler = ObsManagerHandler(obs_configs)
+        # this contains all info related to reward, traffic light violations, etc.
+ self._ev_handler = EgoVehicleHandler(self._client, reward_configs, terminal_configs)
+ self._zw_handler = ZombieWalkerHandler(self._client)
+ self._zv_handler = ZombieVehicleHandler(self._client, tm_port=self._tm.get_port())
+ self._sa_handler = ScenarioActorHandler(self._client)
+ self._wt_handler = WeatherHandler(self._world)
+
+ # observation spaces
+ self.observation_space = self._om_handler.observation_space
+ # define action spaces exposed to agent
+ # throttle, steer, brake
+ self.action_space = gym.spaces.Dict({ego_vehicle_id: gym.spaces.Box(
+ low=np.array([0.0, -1.0, 0.0]),
+ high=np.array([1.0, 1.0, 1.0]),
+ dtype=np.float32)
+ for ego_vehicle_id in obs_configs.keys()})
+
+ self._task_idx = 0
+ self._shuffle_task = True
+ self._task = self._all_tasks[self._task_idx].copy()
+
+ def set_task_idx(self, task_idx):
+ self._task_idx = task_idx
+ self._shuffle_task = False
+ self._task = self._all_tasks[self._task_idx].copy()
+
+ @property
+ def num_tasks(self):
+ return len(self._all_tasks)
+
+ @property
+ def task(self):
+ return self._task
+
+ @property
+ def world(self):
+ return self._world
+
+ def reset(self):
+ if self._shuffle_task:
+ self._task_idx = np.random.choice(self.num_tasks)
+ self._task = self._all_tasks[self._task_idx].copy()
+ self.clean()
+
+ self._wt_handler.reset(self._task['weather'])
+ logger.debug("_wt_handler reset done!!")
+
+ ev_spawn_locations = self._ev_handler.reset(self._task['ego_vehicles'])
+ logger.debug("_ev_handler reset done!!")
+
+ self._sa_handler.reset(self._task['scenario_actors'], self._ev_handler.ego_vehicles)
+ logger.debug("_sa_handler reset done!!")
+
+ self._zw_handler.reset(self._task['num_zombie_walkers'], ev_spawn_locations)
+ logger.debug("_zw_handler reset done!!")
+
+ self._zv_handler.reset(self._task['num_zombie_vehicles'], ev_spawn_locations)
+ logger.debug("_zv_handler reset done!!")
+
+ self._om_handler.reset(self._ev_handler.ego_vehicles)
+ logger.debug("_om_handler reset done!!")
+
+ self._world.tick()
+
+ snap_shot = self._world.get_snapshot()
+ self._timestamp = {
+ 'step': 0,
+ 'frame': snap_shot.timestamp.frame,
+ 'relative_wall_time': 0.0,
+ 'wall_time': snap_shot.timestamp.platform_timestamp,
+ 'relative_simulation_time': 0.0,
+ 'simulation_time': snap_shot.timestamp.elapsed_seconds,
+ 'start_frame': snap_shot.timestamp.frame,
+ 'start_wall_time': snap_shot.timestamp.platform_timestamp,
+ 'start_simulation_time': snap_shot.timestamp.elapsed_seconds
+ }
+
+ _, _, _ = self._ev_handler.tick(self.timestamp)
+        # get observations
+ obs_dict = self._om_handler.get_observation(self.timestamp)
+ return obs_dict
+
+ def step(self, control_dict):
+ self._ev_handler.apply_control(control_dict)
+ self._sa_handler.tick()
+ # tick world
+ self._world.tick()
+
+ # update timestamp
+ snap_shot = self._world.get_snapshot()
+ self._timestamp['step'] = snap_shot.timestamp.frame-self._timestamp['start_frame']
+ self._timestamp['frame'] = snap_shot.timestamp.frame
+ self._timestamp['wall_time'] = snap_shot.timestamp.platform_timestamp
+ self._timestamp['relative_wall_time'] = self._timestamp['wall_time'] - self._timestamp['start_wall_time']
+ self._timestamp['simulation_time'] = snap_shot.timestamp.elapsed_seconds
+ self._timestamp['relative_simulation_time'] = self._timestamp['simulation_time'] \
+ - self._timestamp['start_simulation_time']
+
+ reward_dict, done_dict, info_dict = self._ev_handler.tick(self.timestamp)
+
+ # get observations
+ obs_dict = self._om_handler.get_observation(self.timestamp)
+
+ # update weather
+ self._wt_handler.tick(snap_shot.timestamp.delta_seconds)
+
+ # num_walkers = len(self._world.get_actors().filter("*walker.pedestrian*"))
+ # num_vehicles = len(self._world.get_actors().filter("vehicle*"))
+ # logger.debug(f"num_walkers: {num_walkers}, num_vehicles: {num_vehicles}, ")
+
+ return obs_dict, reward_dict, done_dict, info_dict
+
+ def _init_client(self, carla_map, host, port, seed=2021, no_rendering=False):
+ client = None
+ while client is None:
+ try:
+ client = carla.Client(host, port)
+ client.set_timeout(60.0)
+ except RuntimeError as re:
+ if "timeout" not in str(re) and "time-out" not in str(re):
+ print("Could not connect to Carla server because:", re)
+ client = None
+
+ self._client = client
+ self._world = client.load_world(carla_map)
+ self._tm = client.get_trafficmanager(port+6000)
+
+ self.set_sync_mode(True)
+ self.set_no_rendering_mode(self._world, no_rendering)
+
+ # self._tm.set_hybrid_physics_mode(True)
+
+ # self._tm.set_global_distance_to_leading_vehicle(5.0)
+ # logger.debug("trafficmanager set_global_distance_to_leading_vehicle")
+
+ set_random_seed(self._seed, using_cuda=True)
+ self._tm.set_random_device_seed(self._seed)
+
+ self._world.tick()
+
+ # register traffic lights
+ TrafficLightHandler.reset(self._world)
+
+ def set_sync_mode(self, sync):
+ settings = self._world.get_settings()
+ settings.synchronous_mode = sync
+ settings.fixed_delta_seconds = 1.0 / CARLA_FPS
+ settings.deterministic_ragdolls = True
+ self._world.apply_settings(settings)
+ self._tm.set_synchronous_mode(sync)
+
+ @staticmethod
+ def set_no_rendering_mode(world, no_rendering):
+ settings = world.get_settings()
+ settings.no_rendering_mode = no_rendering
+ world.apply_settings(settings)
+
+ @property
+ def timestamp(self):
+ return self._timestamp.copy()
+
+ def __exit__(self, exception_type, exception_value, traceback):
+ self.close()
+ logger.debug("env __exit__!")
+
+ def close(self):
+ self.clean()
+ self.set_sync_mode(False)
+ self._client = None
+ self._world = None
+ self._tm = None
+
+ def clean(self):
+ self._sa_handler.clean()
+ self._zw_handler.clean()
+ self._zv_handler.clean()
+ self._om_handler.clean()
+ self._ev_handler.clean()
+ self._wt_handler.clean()
+ self._world.tick()
diff --git a/carla_gym/core/obs_manager/actor_state/control.py b/carla_gym/core/obs_manager/actor_state/control.py
new file mode 100644
index 0000000..617b692
--- /dev/null
+++ b/carla_gym/core/obs_manager/actor_state/control.py
@@ -0,0 +1,40 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import numpy as np
+from gym import spaces
+
+from carla_gym.core.obs_manager.obs_manager import ObsManagerBase
+
+
+class ObsManager(ObsManagerBase):
+
+ def __init__(self, obs_configs):
+ self._parent_actor = None
+ super(ObsManager, self).__init__()
+
+ def _define_obs_space(self):
+ self.obs_space = spaces.Dict({
+ 'throttle': spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32),
+ 'steer': spaces.Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32),
+ 'brake': spaces.Box(low=0.0, high=1.0, shape=(1,), dtype=np.float32),
+ 'gear': spaces.Box(low=0.0, high=5.0, shape=(1,), dtype=np.float32), # 0-5
+ 'speed_limit': spaces.Box(low=0.0, high=50.0, shape=(1,), dtype=np.float32)
+ })
+
+ def attach_ego_vehicle(self, parent_actor):
+ self._parent_actor = parent_actor
+
+ def get_observation(self):
+ control = self._parent_actor.vehicle.get_control()
+ speed_limit = self._parent_actor.vehicle.get_speed_limit() / 3.6 * 0.8
+ obs = {
+ 'throttle': np.array([control.throttle], dtype=np.float32),
+ 'steer': np.array([control.steer], dtype=np.float32),
+ 'brake': np.array([control.brake], dtype=np.float32),
+ 'gear': np.array([control.gear], dtype=np.float32),
+ 'speed_limit': np.array([speed_limit], dtype=np.float32),
+ }
+ return obs
+
+ def clean(self):
+ self._parent_actor = None
diff --git a/carla_gym/core/obs_manager/actor_state/route.py b/carla_gym/core/obs_manager/actor_state/route.py
new file mode 100644
index 0000000..d2124e9
--- /dev/null
+++ b/carla_gym/core/obs_manager/actor_state/route.py
@@ -0,0 +1,73 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import numpy as np
+from gym import spaces
+
+from carla_gym.core.obs_manager.obs_manager import ObsManagerBase
+import carla_gym.utils.transforms as trans_utils
+
+
+class ObsManager(ObsManagerBase):
+
+ def __init__(self, obs_configs):
+ self._parent_actor = None
+ self._route_steps = 5
+ super(ObsManager, self).__init__()
+
+ def _define_obs_space(self):
+ self.obs_space = spaces.Dict({
+ 'lateral_dist': spaces.Box(low=0.0, high=2.0, shape=(1,), dtype=np.float32),
+ 'angle_diff': spaces.Box(low=-2.0, high=2.0, shape=(1,), dtype=np.float32),
+ 'route_locs': spaces.Box(low=-5.0, high=5.0, shape=(self._route_steps*2,), dtype=np.float32),
+ 'dist_remaining': spaces.Box(low=0.0, high=100, shape=(1,), dtype=np.float32)
+ })
+
+ def attach_ego_vehicle(self, parent_actor):
+ self._parent_actor = parent_actor
+
+ def get_observation(self):
+ ev_transform = self._parent_actor.vehicle.get_transform()
+ route_plan = self._parent_actor.route_plan
+
+ # lateral_dist
+ waypoint, road_option = route_plan[0]
+ wp_transform = waypoint.transform
+
+ d_vec = ev_transform.location - wp_transform.location
+ np_d_vec = np.array([d_vec.x, d_vec.y], dtype=np.float32)
+ wp_unit_forward = wp_transform.rotation.get_forward_vector()
+ np_wp_unit_right = np.array([-wp_unit_forward.y, wp_unit_forward.x], dtype=np.float32)
+
+ lateral_dist = np.abs(np.dot(np_wp_unit_right, np_d_vec))
+ lateral_dist = np.clip(lateral_dist, 0, 2)
+
+ # angle_diff
+ angle_diff = np.deg2rad(np.abs(trans_utils.cast_angle(ev_transform.rotation.yaw - wp_transform.rotation.yaw)))
+ angle_diff = np.clip(angle_diff, -2, 2)
+
+ # route_locs
+ location_list = []
+ route_length = len(route_plan)
+ for i in range(self._route_steps):
+ if i < route_length:
+ waypoint, road_option = route_plan[i]
+ else:
+ waypoint, road_option = route_plan[-1]
+
+ wp_location_world_coord = waypoint.transform.location
+ wp_location_actor_coord = trans_utils.loc_global_to_ref(wp_location_world_coord, ev_transform)
+ location_list += [wp_location_actor_coord.x, wp_location_actor_coord.y]
+
+ # dist_remaining_in_km
+ dist_remaining_in_km = (self._parent_actor.route_length - self._parent_actor.route_completed) / 1000.0
+
+ obs = {
+ 'lateral_dist': np.array([lateral_dist], dtype=np.float32),
+ 'angle_diff': np.array([angle_diff], dtype=np.float32),
+ 'route_locs': np.array(location_list, dtype=np.float32),
+ 'dist_remaining': np.array([dist_remaining_in_km], dtype=np.float32)
+ }
+ return obs
+
+ def clean(self):
+ self._parent_actor = None
diff --git a/carla_gym/core/obs_manager/actor_state/speed.py b/carla_gym/core/obs_manager/actor_state/speed.py
new file mode 100644
index 0000000..4f579de
--- /dev/null
+++ b/carla_gym/core/obs_manager/actor_state/speed.py
@@ -0,0 +1,51 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import numpy as np
+from gym import spaces
+
+from carla_gym.core.obs_manager.obs_manager import ObsManagerBase
+
+
+class ObsManager(ObsManagerBase):
+ """
+ in m/s
+ """
+
+ def __init__(self, obs_configs):
+ self._parent_actor = None
+ super(ObsManager, self).__init__()
+
+ def _define_obs_space(self):
+ self.obs_space = spaces.Dict({
+ 'speed': spaces.Box(low=-10.0, high=30.0, shape=(1,), dtype=np.float32),
+ 'speed_xy': spaces.Box(low=-10.0, high=30.0, shape=(1,), dtype=np.float32),
+ 'forward_speed': spaces.Box(low=-10.0, high=30.0, shape=(1,), dtype=np.float32)
+ })
+
+ def attach_ego_vehicle(self, parent_actor):
+ self._parent_actor = parent_actor
+
+ def get_observation(self):
+ velocity = self._parent_actor.vehicle.get_velocity()
+ transform = self._parent_actor.vehicle.get_transform()
+
+ np_vel = np.array([velocity.x, velocity.y, velocity.z])
+
+ # See https://github.com/carla-simulator/leaderboard/blob/8956c4e0c53bfa24e2bd0ccb1a5269ce47770a57/leaderboard/envs/sensor_interface.py#L90
+ pitch = np.deg2rad(transform.rotation.pitch)
+ yaw = np.deg2rad(transform.rotation.yaw)
+ orientation = np.array([np.cos(pitch) * np.cos(yaw), np.cos(pitch) * np.sin(yaw), np.sin(pitch)])
+ forward_speed = np.dot(np_vel, orientation)
+
+ speed = np.linalg.norm(np_vel)
+ speed_xy = np.linalg.norm(np_vel[0:2])
+
+ obs = {
+ 'speed': np.array([speed], dtype=np.float32),
+ 'speed_xy': np.array([speed_xy], dtype=np.float32),
+ 'forward_speed': np.array([forward_speed], dtype=np.float32)
+ }
+ return obs
+
+ def clean(self):
+ self._parent_actor = None
diff --git a/carla_gym/core/obs_manager/actor_state/velocity.py b/carla_gym/core/obs_manager/actor_state/velocity.py
new file mode 100644
index 0000000..1b7a0c7
--- /dev/null
+++ b/carla_gym/core/obs_manager/actor_state/velocity.py
@@ -0,0 +1,45 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import numpy as np
+from gym import spaces
+
+from carla_gym.core.obs_manager.obs_manager import ObsManagerBase
+import carla_gym.utils.transforms as trans_utils
+
+
+class ObsManager(ObsManagerBase):
+
+ def __init__(self, obs_configs):
+ super(ObsManager, self).__init__()
+
+ def _define_obs_space(self):
+ # acc_x, acc_y: m/s2
+ # vel_x, vel_y: m/s
+ # vel_angular z: rad/s
+ self.obs_space = spaces.Dict({
+ 'acc_xy': spaces.Box(low=-1e3, high=1e3, shape=(2,), dtype=np.float32),
+ 'vel_xy': spaces.Box(low=-1e2, high=1e2, shape=(2,), dtype=np.float32),
+ 'vel_ang_z': spaces.Box(low=-1e3, high=1e3, shape=(1,), dtype=np.float32)
+ })
+
+ def attach_ego_vehicle(self, parent_actor):
+ self._parent_actor = parent_actor
+
+ def get_observation(self):
+ ev_transform = self._parent_actor.vehicle.get_transform()
+ acc_w = self._parent_actor.vehicle.get_acceleration()
+ vel_w = self._parent_actor.vehicle.get_velocity()
+ ang_w = self._parent_actor.vehicle.get_angular_velocity()
+
+ acc_ev = trans_utils.vec_global_to_ref(acc_w, ev_transform.rotation)
+ vel_ev = trans_utils.vec_global_to_ref(vel_w, ev_transform.rotation)
+
+ obs = {
+ 'acc_xy': np.array([acc_ev.x, acc_ev.y], dtype=np.float32),
+ 'vel_xy': np.array([vel_ev.x, vel_ev.y], dtype=np.float32),
+ 'vel_ang_z': np.array([ang_w.z], dtype=np.float32)
+ }
+ return obs
+
+ def clean(self):
+ self._parent_actor = None
diff --git a/carla_gym/core/obs_manager/birdview/chauffeurnet.py b/carla_gym/core/obs_manager/birdview/chauffeurnet.py
new file mode 100644
index 0000000..001e613
--- /dev/null
+++ b/carla_gym/core/obs_manager/birdview/chauffeurnet.py
@@ -0,0 +1,310 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import numpy as np
+import carla
+from gym import spaces
+import cv2 as cv
+from collections import deque
+from pathlib import Path
+import h5py
+
+from carla_gym.core.obs_manager.obs_manager import ObsManagerBase
+from carla_gym.utils.traffic_light import TrafficLightHandler
+
+
+COLOR_BLACK = (0, 0, 0)
+COLOR_RED = (255, 0, 0)
+COLOR_GREEN = (0, 255, 0)
+COLOR_BLUE = (0, 0, 255)
+COLOR_CYAN = (0, 255, 255)
+COLOR_MAGENTA = (255, 0, 255)
+COLOR_MAGENTA_2 = (255, 140, 255)
+COLOR_YELLOW = (255, 255, 0)
+COLOR_YELLOW_2 = (160, 160, 0)
+COLOR_WHITE = (255, 255, 255)
+COLOR_ALUMINIUM_0 = (238, 238, 236)
+COLOR_ALUMINIUM_3 = (136, 138, 133)
+COLOR_ALUMINIUM_5 = (46, 52, 54)
+
+
+def tint(color, factor):
+ r, g, b = color
+ r = int(r + (255-r) * factor)
+ g = int(g + (255-g) * factor)
+ b = int(b + (255-b) * factor)
+ r = min(r, 255)
+ g = min(g, 255)
+ b = min(b, 255)
+ return (r, g, b)
+
+
+class ObsManager(ObsManagerBase):
+ def __init__(self, obs_configs):
+ self._width = int(obs_configs['width_in_pixels'])
+ self._pixels_ev_to_bottom = obs_configs['pixels_ev_to_bottom']
+ self._pixels_per_meter = obs_configs['pixels_per_meter']
+ self._history_idx = obs_configs['history_idx']
+ self._scale_bbox = obs_configs.get('scale_bbox', True)
+ self._scale_mask_col = obs_configs.get('scale_mask_col', 1.1)
+
+ self._history_queue = deque(maxlen=20)
+
+ self._image_channels = 3
+ self._masks_channels = 3 + 3*len(self._history_idx)
+ self._parent_actor = None
+ self._world = None
+
+ self._map_dir = Path(__file__).resolve().parent / 'maps'
+
+ super(ObsManager, self).__init__()
+
+ def _define_obs_space(self):
+ self.obs_space = spaces.Dict(
+ {'rendered': spaces.Box(
+ low=0, high=255, shape=(self._width, self._width, self._image_channels),
+ dtype=np.uint8),
+ 'masks': spaces.Box(
+ low=0, high=255, shape=(self._masks_channels, self._width, self._width),
+ dtype=np.uint8)})
+
+ def attach_ego_vehicle(self, parent_actor):
+ self._parent_actor = parent_actor
+ self._world = self._parent_actor.vehicle.get_world()
+
+ maps_h5_path = self._map_dir / (self._world.get_map().name.split('/')[-1] + '.h5')
+ with h5py.File(maps_h5_path, 'r', libver='latest', swmr=True) as hf:
+ self._road = np.array(hf['road'], dtype=np.uint8)
+ self._lane_marking_all = np.array(hf['lane_marking_all'], dtype=np.uint8)
+ self._lane_marking_white_broken = np.array(hf['lane_marking_white_broken'], dtype=np.uint8)
+ # self._shoulder = np.array(hf['shoulder'], dtype=np.uint8)
+ # self._parking = np.array(hf['parking'], dtype=np.uint8)
+ # self._sidewalk = np.array(hf['sidewalk'], dtype=np.uint8)
+ # self._lane_marking_yellow_broken = np.array(hf['lane_marking_yellow_broken'], dtype=np.uint8)
+ # self._lane_marking_yellow_solid = np.array(hf['lane_marking_yellow_solid'], dtype=np.uint8)
+ # self._lane_marking_white_solid = np.array(hf['lane_marking_white_solid'], dtype=np.uint8)
+
+ self._world_offset = np.array(hf.attrs['world_offset_in_meters'], dtype=np.float32)
+ assert np.isclose(self._pixels_per_meter, float(hf.attrs['pixels_per_meter']))
+
+ self._distance_threshold = np.ceil(self._width / self._pixels_per_meter)
+        # dilate road mask; lbc draws the road polygon with a 10px border
+ # kernel = np.ones((11, 11), np.uint8)
+ # self._road = cv.dilate(self._road, kernel, iterations=1)
+
+ @staticmethod
+ def _get_stops(criteria_stop):
+ stop_sign = criteria_stop._target_stop_sign
+ stops = []
+ if (stop_sign is not None) and (not criteria_stop._stop_completed):
+ bb_loc = carla.Location(stop_sign.trigger_volume.location)
+ bb_ext = carla.Vector3D(stop_sign.trigger_volume.extent)
+ bb_ext.x = max(bb_ext.x, bb_ext.y)
+ bb_ext.y = max(bb_ext.x, bb_ext.y)
+ trans = stop_sign.get_transform()
+ stops = [(carla.Transform(trans.location, trans.rotation), bb_loc, bb_ext)]
+ return stops
+
+ def get_observation(self):
+ ev_transform = self._parent_actor.vehicle.get_transform()
+ ev_loc = ev_transform.location
+ ev_rot = ev_transform.rotation
+ ev_bbox = self._parent_actor.vehicle.bounding_box
+ snap_shot = self._world.get_snapshot()
+
+ def is_within_distance(w):
+ c_distance = abs(ev_loc.x - w.location.x) < self._distance_threshold \
+ and abs(ev_loc.y - w.location.y) < self._distance_threshold \
+ and abs(ev_loc.z - w.location.z) < 8.0
+ c_ev = abs(ev_loc.x - w.location.x) < 1.0 and abs(ev_loc.y - w.location.y) < 1.0
+ return c_distance and (not c_ev)
+
+ vehicle_bbox_list = self._world.get_level_bbs(carla.CityObjectLabel.Car)
+ walker_bbox_list = self._world.get_level_bbs(carla.CityObjectLabel.Pedestrians)
+ if self._scale_bbox:
+ vehicles = self._get_surrounding_actors(vehicle_bbox_list, is_within_distance, 1.0)
+ walkers = self._get_surrounding_actors(walker_bbox_list, is_within_distance, 2.0)
+ else:
+ vehicles = self._get_surrounding_actors(vehicle_bbox_list, is_within_distance)
+ walkers = self._get_surrounding_actors(walker_bbox_list, is_within_distance)
+
+ tl_green = TrafficLightHandler.get_stopline_vtx(ev_loc, 0)
+ tl_yellow = TrafficLightHandler.get_stopline_vtx(ev_loc, 1)
+ tl_red = TrafficLightHandler.get_stopline_vtx(ev_loc, 2)
+ stops = self._get_stops(self._parent_actor.criteria_stop)
+
+ self._history_queue.append((vehicles, walkers, tl_green, tl_yellow, tl_red, stops))
+
+ M_warp = self._get_warp_transform(ev_loc, ev_rot)
+
+ # objects with history
+ vehicle_masks, walker_masks, tl_green_masks, tl_yellow_masks, tl_red_masks, stop_masks \
+ = self._get_history_masks(M_warp)
+
+ # road_mask, lane_mask
+ road_mask = cv.warpAffine(self._road, M_warp, (self._width, self._width)).astype(np.bool)
+ lane_mask_all = cv.warpAffine(self._lane_marking_all, M_warp, (self._width, self._width)).astype(np.bool)
+ lane_mask_broken = cv.warpAffine(self._lane_marking_white_broken, M_warp,
+ (self._width, self._width)).astype(np.bool)
+
+ # route_mask
+ route_mask = np.zeros([self._width, self._width], dtype=np.uint8)
+ route_in_pixel = np.array([[self._world_to_pixel(wp.transform.location)]
+ for wp, _ in self._parent_actor.route_plan[0:80]])
+ route_warped = cv.transform(route_in_pixel, M_warp)
+ cv.polylines(route_mask, [np.round(route_warped).astype(np.int32)], False, 1, thickness=16)
+ route_mask = route_mask.astype(np.bool)
+
+ # ev_mask
+ ev_mask = self._get_mask_from_actor_list([(ev_transform, ev_bbox.location, ev_bbox.extent)], M_warp)
+ ev_mask_col = self._get_mask_from_actor_list([(ev_transform, ev_bbox.location,
+ ev_bbox.extent*self._scale_mask_col)], M_warp)
+
+ # render
+ image = np.zeros([self._width, self._width, 3], dtype=np.uint8)
+ image[road_mask] = COLOR_ALUMINIUM_5
+ image[route_mask] = COLOR_ALUMINIUM_3
+ image[lane_mask_all] = COLOR_MAGENTA
+ image[lane_mask_broken] = COLOR_MAGENTA_2
+
+ h_len = len(self._history_idx)-1
+ for i, mask in enumerate(stop_masks):
+ image[mask] = tint(COLOR_YELLOW_2, (h_len-i)*0.2)
+ for i, mask in enumerate(tl_green_masks):
+ image[mask] = tint(COLOR_GREEN, (h_len-i)*0.2)
+ for i, mask in enumerate(tl_yellow_masks):
+ image[mask] = tint(COLOR_YELLOW, (h_len-i)*0.2)
+ for i, mask in enumerate(tl_red_masks):
+ image[mask] = tint(COLOR_RED, (h_len-i)*0.2)
+
+ for i, mask in enumerate(vehicle_masks):
+ image[mask] = tint(COLOR_BLUE, (h_len-i)*0.2)
+ for i, mask in enumerate(walker_masks):
+ image[mask] = tint(COLOR_CYAN, (h_len-i)*0.2)
+
+ image[ev_mask] = COLOR_WHITE
+ # image[obstacle_mask] = COLOR_BLUE
+
+ # masks
+ c_road = road_mask * 255
+ c_route = route_mask * 255
+ c_lane = lane_mask_all * 255
+ c_lane[lane_mask_broken] = 120
+
+ # masks with history
+ c_tl_history = []
+ for i in range(len(self._history_idx)):
+ c_tl = np.zeros([self._width, self._width], dtype=np.uint8)
+ c_tl[tl_green_masks[i]] = 80
+ c_tl[tl_yellow_masks[i]] = 170
+ c_tl[tl_red_masks[i]] = 255
+ c_tl[stop_masks[i]] = 255
+ c_tl_history.append(c_tl)
+
+ c_vehicle_history = [m*255 for m in vehicle_masks]
+ c_walker_history = [m*255 for m in walker_masks]
+
+ masks = np.stack((c_road, c_route, c_lane, *c_vehicle_history, *c_walker_history, *c_tl_history), axis=2)
+ masks = np.transpose(masks, [2, 0, 1])
+
+ obs_dict = {'rendered': image, 'masks': masks}
+
+ self._parent_actor.collision_px = np.any(ev_mask_col & walker_masks[-1])
+
+ return obs_dict
+
+ def _get_history_masks(self, M_warp):
+ qsize = len(self._history_queue)
+ vehicle_masks, walker_masks, tl_green_masks, tl_yellow_masks, tl_red_masks, stop_masks = [], [], [], [], [], []
+ for idx in self._history_idx:
+ idx = max(idx, -1 * qsize)
+
+ vehicles, walkers, tl_green, tl_yellow, tl_red, stops = self._history_queue[idx]
+
+ vehicle_masks.append(self._get_mask_from_actor_list(vehicles, M_warp))
+ walker_masks.append(self._get_mask_from_actor_list(walkers, M_warp))
+ tl_green_masks.append(self._get_mask_from_stopline_vtx(tl_green, M_warp))
+ tl_yellow_masks.append(self._get_mask_from_stopline_vtx(tl_yellow, M_warp))
+ tl_red_masks.append(self._get_mask_from_stopline_vtx(tl_red, M_warp))
+ stop_masks.append(self._get_mask_from_actor_list(stops, M_warp))
+
+ return vehicle_masks, walker_masks, tl_green_masks, tl_yellow_masks, tl_red_masks, stop_masks
+
+ def _get_mask_from_stopline_vtx(self, stopline_vtx, M_warp):
+ mask = np.zeros([self._width, self._width], dtype=np.uint8)
+ for sp_locs in stopline_vtx:
+ stopline_in_pixel = np.array([[self._world_to_pixel(x)] for x in sp_locs])
+ stopline_warped = cv.transform(stopline_in_pixel, M_warp)
+ cv.line(mask, np.array(stopline_warped[0, 0], dtype=int), np.array(stopline_warped[1, 0], dtype=int),
+ color=1, thickness=6)
+ return mask.astype(np.bool)
+
+ def _get_mask_from_actor_list(self, actor_list, M_warp):
+ mask = np.zeros([self._width, self._width], dtype=np.uint8)
+ for actor_transform, bb_loc, bb_ext in actor_list:
+
+ corners = [carla.Location(x=-bb_ext.x, y=-bb_ext.y),
+ carla.Location(x=bb_ext.x, y=-bb_ext.y),
+ carla.Location(x=bb_ext.x, y=0),
+ carla.Location(x=bb_ext.x, y=bb_ext.y),
+ carla.Location(x=-bb_ext.x, y=bb_ext.y)]
+ corners = [bb_loc + corner for corner in corners]
+
+ corners = [actor_transform.transform(corner) for corner in corners]
+ corners_in_pixel = np.array([[self._world_to_pixel(corner)] for corner in corners])
+ corners_warped = cv.transform(corners_in_pixel, M_warp)
+
+ cv.fillConvexPoly(mask, np.round(corners_warped).astype(np.int32), 1)
+ return mask.astype(np.bool)
+
+ @staticmethod
+ def _get_surrounding_actors(bbox_list, criterium, scale=None):
+ actors = []
+ for bbox in bbox_list:
+ is_within_distance = criterium(bbox)
+ if is_within_distance:
+ bb_loc = carla.Location()
+ bb_ext = carla.Vector3D(bbox.extent)
+ if scale is not None:
+ bb_ext = bb_ext * scale
+ bb_ext.x = max(bb_ext.x, 0.8)
+ bb_ext.y = max(bb_ext.y, 0.8)
+
+ actors.append((carla.Transform(bbox.location, bbox.rotation), bb_loc, bb_ext))
+ return actors
+
+ def _get_warp_transform(self, ev_loc, ev_rot):
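+        # Build an affine transform mapping three reference points of the desired BEV
+        # window (bottom-left, top-left, top-right, in global map pixels) onto the
+        # corners of the output image, so that the ego vehicle faces "up" and sits
+        # self._pixels_ev_to_bottom pixels above the bottom edge, centered horizontally.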
+ ev_loc_in_px = self._world_to_pixel(ev_loc)
+ yaw = np.deg2rad(ev_rot.yaw)
+
+ forward_vec = np.array([np.cos(yaw), np.sin(yaw)])
+ right_vec = np.array([np.cos(yaw + 0.5*np.pi), np.sin(yaw + 0.5*np.pi)])
+
+ bottom_left = ev_loc_in_px - self._pixels_ev_to_bottom * forward_vec - (0.5*self._width) * right_vec
+ top_left = ev_loc_in_px + (self._width-self._pixels_ev_to_bottom) * forward_vec - (0.5*self._width) * right_vec
+ top_right = ev_loc_in_px + (self._width-self._pixels_ev_to_bottom) * forward_vec + (0.5*self._width) * right_vec
+
+ src_pts = np.stack((bottom_left, top_left, top_right), axis=0).astype(np.float32)
+ dst_pts = np.array([[0, self._width-1],
+ [0, 0],
+ [self._width-1, 0]], dtype=np.float32)
+ return cv.getAffineTransform(src_pts, dst_pts)
+
+ def _world_to_pixel(self, location, projective=False):
+ """Converts the world coordinates to pixel coordinates"""
+ x = self._pixels_per_meter * (location.x - self._world_offset[0])
+ y = self._pixels_per_meter * (location.y - self._world_offset[1])
+
+ if projective:
+ p = np.array([x, y, 1], dtype=np.float32)
+ else:
+ p = np.array([x, y], dtype=np.float32)
+ return p
+
+ def _world_to_pixel_width(self, width):
+ """Converts the world units to pixel units"""
+ return self._pixels_per_meter * width
+
+ def clean(self):
+ self._parent_actor = None
+ self._world = None
+ self._history_queue.clear()
diff --git a/carla_gym/core/obs_manager/birdview/chauffeurnet_label.py b/carla_gym/core/obs_manager/birdview/chauffeurnet_label.py
new file mode 100644
index 0000000..9f960bf
--- /dev/null
+++ b/carla_gym/core/obs_manager/birdview/chauffeurnet_label.py
@@ -0,0 +1,311 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import numpy as np
+import carla
+from gym import spaces
+import cv2 as cv
+from collections import deque
+from pathlib import Path
+import h5py
+
+from carla_gym.core.obs_manager.obs_manager import ObsManagerBase
+from carla_gym.utils.traffic_light import TrafficLightHandler
+
+
+COLOR_BLACK = (0, 0, 0)
+COLOR_RED = (255, 0, 0)
+COLOR_GREEN = (0, 255, 0)
+COLOR_BLUE = (0, 0, 255)
+COLOR_CYAN = (0, 255, 255)
+COLOR_MAGENTA = (255, 0, 255)
+COLOR_MAGENTA_2 = (255, 140, 255)
+COLOR_YELLOW = (255, 255, 0)
+COLOR_YELLOW_2 = (160, 160, 0)
+COLOR_WHITE = (255, 255, 255)
+COLOR_ALUMINIUM_0 = (238, 238, 236)
+COLOR_ALUMINIUM_3 = (136, 138, 133)
+COLOR_ALUMINIUM_5 = (46, 52, 54)
+
+
+def tint(color, factor):
+ r, g, b = color
+ r = int(r + (255-r) * factor)
+ g = int(g + (255-g) * factor)
+ b = int(b + (255-b) * factor)
+ r = min(r, 255)
+ g = min(g, 255)
+ b = min(b, 255)
+ return (r, g, b)
+
+
+class ObsManager(ObsManagerBase):
+ def __init__(self, obs_configs):
+ self._width = int(obs_configs['width_in_pixels'])
+ # This is the distance to the center of the ego-vehicle.
+ self._pixels_ev_to_bottom = obs_configs['pixels_ev_to_bottom']
+ self._pixels_per_meter = obs_configs['pixels_per_meter']
+ self._history_idx = obs_configs['history_idx']
+ self._scale_bbox = obs_configs.get('scale_bbox', True)
+ self._scale_mask_col = obs_configs.get('scale_mask_col', 1.1)
+
+ self._history_queue = deque(maxlen=20)
+
+ self._image_channels = 3
+ self._masks_channels = 3 + 3*len(self._history_idx)
+ self._parent_actor = None
+ self._world = None
+
+ self._map_dir = Path(__file__).resolve().parent / 'maps'
+
+ super(ObsManager, self).__init__()
+
+ def _define_obs_space(self):
+ self.obs_space = spaces.Dict(
+ {'rendered': spaces.Box(
+ low=0, high=255, shape=(self._width, self._width, self._image_channels),
+ dtype=np.uint8),
+ 'masks': spaces.Box(
+ low=0, high=255, shape=(self._masks_channels, self._width, self._width),
+ dtype=np.uint8)})
+
+ def attach_ego_vehicle(self, parent_actor):
+ self._parent_actor = parent_actor
+ self._world = self._parent_actor.vehicle.get_world()
+
+ maps_h5_path = self._map_dir / (self._world.get_map().name.split('/')[-1] + '.h5')
+ with h5py.File(maps_h5_path, 'r', libver='latest', swmr=True) as hf:
+ self._road = np.array(hf['road'], dtype=np.uint8)
+ self._lane_marking_all = np.array(hf['lane_marking_all'], dtype=np.uint8)
+ self._lane_marking_white_broken = np.array(hf['lane_marking_white_broken'], dtype=np.uint8)
+ # self._shoulder = np.array(hf['shoulder'], dtype=np.uint8)
+ # self._parking = np.array(hf['parking'], dtype=np.uint8)
+ # self._sidewalk = np.array(hf['sidewalk'], dtype=np.uint8)
+ # self._lane_marking_yellow_broken = np.array(hf['lane_marking_yellow_broken'], dtype=np.uint8)
+ # self._lane_marking_yellow_solid = np.array(hf['lane_marking_yellow_solid'], dtype=np.uint8)
+ # self._lane_marking_white_solid = np.array(hf['lane_marking_white_solid'], dtype=np.uint8)
+
+ self._world_offset = np.array(hf.attrs['world_offset_in_meters'], dtype=np.float32)
+ assert np.isclose(self._pixels_per_meter, float(hf.attrs['pixels_per_meter']))
+
+ self._distance_threshold = np.ceil(self._width / self._pixels_per_meter)
+        # dilate road mask; lbc draws the road polygon with a 10px border
+ # kernel = np.ones((11, 11), np.uint8)
+ # self._road = cv.dilate(self._road, kernel, iterations=1)
+
+ @staticmethod
+ def _get_stops(criteria_stop):
+ stop_sign = criteria_stop._target_stop_sign
+ stops = []
+ if (stop_sign is not None) and (not criteria_stop._stop_completed):
+ bb_loc = carla.Location(stop_sign.trigger_volume.location)
+ bb_ext = carla.Vector3D(stop_sign.trigger_volume.extent)
+ bb_ext.x = max(bb_ext.x, bb_ext.y)
+ bb_ext.y = max(bb_ext.x, bb_ext.y)
+ trans = stop_sign.get_transform()
+ stops = [(carla.Transform(trans.location, trans.rotation), bb_loc, bb_ext)]
+ return stops
+
+ def get_observation(self):
+ ev_transform = self._parent_actor.vehicle.get_transform()
+ ev_loc = ev_transform.location
+ ev_rot = ev_transform.rotation
+ ev_bbox = self._parent_actor.vehicle.bounding_box
+ snap_shot = self._world.get_snapshot()
+
+ def is_within_distance(w):
+ c_distance = abs(ev_loc.x - w.location.x) < self._distance_threshold \
+ and abs(ev_loc.y - w.location.y) < self._distance_threshold \
+ and abs(ev_loc.z - w.location.z) < 8.0
+ c_ev = abs(ev_loc.x - w.location.x) < 1.0 and abs(ev_loc.y - w.location.y) < 1.0
+ return c_distance and (not c_ev)
+
+ vehicle_bbox_list = self._world.get_level_bbs(carla.CityObjectLabel.Car)
+ walker_bbox_list = self._world.get_level_bbs(carla.CityObjectLabel.Pedestrians)
+ if self._scale_bbox:
+ vehicles = self._get_surrounding_actors(vehicle_bbox_list, is_within_distance, 1.0)
+ walkers = self._get_surrounding_actors(walker_bbox_list, is_within_distance, 2.0)
+ else:
+ vehicles = self._get_surrounding_actors(vehicle_bbox_list, is_within_distance)
+ walkers = self._get_surrounding_actors(walker_bbox_list, is_within_distance)
+
+ tl_green = TrafficLightHandler.get_stopline_vtx(ev_loc, 0)
+ tl_yellow = TrafficLightHandler.get_stopline_vtx(ev_loc, 1)
+ tl_red = TrafficLightHandler.get_stopline_vtx(ev_loc, 2)
+ stops = self._get_stops(self._parent_actor.criteria_stop)
+
+ self._history_queue.append((vehicles, walkers, tl_green, tl_yellow, tl_red, stops))
+
+ M_warp = self._get_warp_transform(ev_loc, ev_rot)
+
+ # objects with history
+ vehicle_masks, walker_masks, tl_green_masks, tl_yellow_masks, tl_red_masks, stop_masks \
+ = self._get_history_masks(M_warp)
+
+ # road_mask, lane_mask
+ road_mask = cv.warpAffine(self._road, M_warp, (self._width, self._width)).astype(np.bool)
+ lane_mask_all = cv.warpAffine(self._lane_marking_all, M_warp, (self._width, self._width)).astype(np.bool)
+ lane_mask_broken = cv.warpAffine(self._lane_marking_white_broken, M_warp,
+ (self._width, self._width)).astype(np.bool)
+
+ # route_mask
+ route_mask = np.zeros([self._width, self._width], dtype=np.uint8)
+ route_in_pixel = np.array([[self._world_to_pixel(wp.transform.location)]
+ for wp, _ in self._parent_actor.route_plan[0:80]])
+ route_warped = cv.transform(route_in_pixel, M_warp)
+ cv.polylines(route_mask, [np.round(route_warped).astype(np.int32)], False, 1, thickness=16)
+ route_mask = route_mask.astype(np.bool)
+
+ # ev_mask
+ ev_mask = self._get_mask_from_actor_list([(ev_transform, ev_bbox.location, ev_bbox.extent)], M_warp)
+ ev_mask_col = self._get_mask_from_actor_list([(ev_transform, ev_bbox.location,
+ ev_bbox.extent*self._scale_mask_col)], M_warp)
+
+ # render
+ image = np.zeros([self._width, self._width, 3], dtype=np.uint8)
+ image[road_mask] = COLOR_ALUMINIUM_5
+ image[route_mask] = COLOR_ALUMINIUM_3
+ image[lane_mask_all] = COLOR_MAGENTA
+ image[lane_mask_broken] = COLOR_MAGENTA_2
+
+ h_len = len(self._history_idx)-1
+ for i, mask in enumerate(stop_masks):
+ image[mask] = tint(COLOR_YELLOW_2, (h_len-i)*0.2)
+ for i, mask in enumerate(tl_green_masks):
+ image[mask] = tint(COLOR_GREEN, (h_len-i)*0.2)
+ for i, mask in enumerate(tl_yellow_masks):
+ image[mask] = tint(COLOR_YELLOW, (h_len-i)*0.2)
+ for i, mask in enumerate(tl_red_masks):
+ image[mask] = tint(COLOR_RED, (h_len-i)*0.2)
+
+ for i, mask in enumerate(vehicle_masks):
+ image[mask] = tint(COLOR_BLUE, (h_len-i)*0.2)
+ for i, mask in enumerate(walker_masks):
+ image[mask] = tint(COLOR_CYAN, (h_len-i)*0.2)
+
+ image[ev_mask] = COLOR_WHITE
+ # image[obstacle_mask] = COLOR_BLUE
+
+ # masks
+ c_road = road_mask * 255
+ c_route = route_mask * 255
+ c_lane = lane_mask_all * 255
+ c_lane[lane_mask_broken] = 120
+
+ # masks with history
+ c_tl_history = []
+ for i in range(len(self._history_idx)):
+ c_tl = np.zeros([self._width, self._width], dtype=np.uint8)
+ c_tl[tl_green_masks[i]] = 80
+ c_tl[tl_yellow_masks[i]] = 170
+ c_tl[tl_red_masks[i]] = 255
+ c_tl[stop_masks[i]] = 255
+ c_tl_history.append(c_tl)
+
+ c_vehicle_history = [m*255 for m in vehicle_masks]
+ c_walker_history = [m*255 for m in walker_masks]
+
+ masks = np.stack((c_road, c_route, c_lane, *c_vehicle_history, *c_walker_history, *c_tl_history), axis=2)
+ masks = np.transpose(masks, [2, 0, 1])
+
+ obs_dict = {'rendered': image, 'masks': masks}
+
+ self._parent_actor.collision_px = np.any(ev_mask_col & walker_masks[-1])
+
+ return obs_dict
+
+ def _get_history_masks(self, M_warp):
+ qsize = len(self._history_queue)
+ vehicle_masks, walker_masks, tl_green_masks, tl_yellow_masks, tl_red_masks, stop_masks = [], [], [], [], [], []
+ for idx in self._history_idx:
+ idx = max(idx, -1 * qsize)
+
+ vehicles, walkers, tl_green, tl_yellow, tl_red, stops = self._history_queue[idx]
+
+ vehicle_masks.append(self._get_mask_from_actor_list(vehicles, M_warp))
+ walker_masks.append(self._get_mask_from_actor_list(walkers, M_warp))
+ tl_green_masks.append(self._get_mask_from_stopline_vtx(tl_green, M_warp))
+ tl_yellow_masks.append(self._get_mask_from_stopline_vtx(tl_yellow, M_warp))
+ tl_red_masks.append(self._get_mask_from_stopline_vtx(tl_red, M_warp))
+ stop_masks.append(self._get_mask_from_actor_list(stops, M_warp))
+
+ return vehicle_masks, walker_masks, tl_green_masks, tl_yellow_masks, tl_red_masks, stop_masks
+
+ def _get_mask_from_stopline_vtx(self, stopline_vtx, M_warp):
+ mask = np.zeros([self._width, self._width], dtype=np.uint8)
+ for sp_locs in stopline_vtx:
+ stopline_in_pixel = np.array([[self._world_to_pixel(x)] for x in sp_locs])
+ stopline_warped = cv.transform(stopline_in_pixel, M_warp)
+ cv.line(mask, np.array(stopline_warped[0, 0], dtype=int), np.array(stopline_warped[1, 0], dtype=int),
+ color=1, thickness=6)
+ return mask.astype(np.bool)
+
+ def _get_mask_from_actor_list(self, actor_list, M_warp):
+ mask = np.zeros([self._width, self._width], dtype=np.uint8)
+ for actor_transform, bb_loc, bb_ext in actor_list:
+
+ corners = [carla.Location(x=-bb_ext.x, y=-bb_ext.y),
+ carla.Location(x=bb_ext.x, y=-bb_ext.y),
+ carla.Location(x=bb_ext.x, y=0),
+ carla.Location(x=bb_ext.x, y=bb_ext.y),
+ carla.Location(x=-bb_ext.x, y=bb_ext.y)]
+ corners = [bb_loc + corner for corner in corners]
+
+ corners = [actor_transform.transform(corner) for corner in corners]
+ corners_in_pixel = np.array([[self._world_to_pixel(corner)] for corner in corners])
+ corners_warped = cv.transform(corners_in_pixel, M_warp)
+
+ cv.fillConvexPoly(mask, np.round(corners_warped).astype(np.int32), 1)
+ return mask.astype(np.bool)
+
+ @staticmethod
+ def _get_surrounding_actors(bbox_list, criterium, scale=None):
+ actors = []
+ for bbox in bbox_list:
+ is_within_distance = criterium(bbox)
+ if is_within_distance:
+ bb_loc = carla.Location()
+ bb_ext = carla.Vector3D(bbox.extent)
+ if scale is not None:
+ bb_ext = bb_ext * scale
+ bb_ext.x = max(bb_ext.x, 0.8)
+ bb_ext.y = max(bb_ext.y, 0.8)
+
+ actors.append((carla.Transform(bbox.location, bbox.rotation), bb_loc, bb_ext))
+ return actors
+
+ def _get_warp_transform(self, ev_loc, ev_rot):
+ ev_loc_in_px = self._world_to_pixel(ev_loc)
+ yaw = np.deg2rad(ev_rot.yaw)
+
+ forward_vec = np.array([np.cos(yaw), np.sin(yaw)])
+ right_vec = np.array([np.cos(yaw + 0.5*np.pi), np.sin(yaw + 0.5*np.pi)])
+
+ bottom_left = ev_loc_in_px - self._pixels_ev_to_bottom * forward_vec - (0.5*self._width) * right_vec
+ top_left = ev_loc_in_px + (self._width-self._pixels_ev_to_bottom) * forward_vec - (0.5*self._width) * right_vec
+ top_right = ev_loc_in_px + (self._width-self._pixels_ev_to_bottom) * forward_vec + (0.5*self._width) * right_vec
+
+ src_pts = np.stack((bottom_left, top_left, top_right), axis=0).astype(np.float32)
+ dst_pts = np.array([[0, self._width-1],
+ [0, 0],
+ [self._width-1, 0]], dtype=np.float32)
+ return cv.getAffineTransform(src_pts, dst_pts)
+
+ def _world_to_pixel(self, location, projective=False):
+ """Converts the world coordinates to pixel coordinates"""
+ x = self._pixels_per_meter * (location.x - self._world_offset[0])
+ y = self._pixels_per_meter * (location.y - self._world_offset[1])
+
+ if projective:
+ p = np.array([x, y, 1], dtype=np.float32)
+ else:
+ p = np.array([x, y], dtype=np.float32)
+ return p
+
+ def _world_to_pixel_width(self, width):
+ """Converts the world units to pixel units"""
+ return self._pixels_per_meter * width
+
+ def clean(self):
+ self._parent_actor = None
+ self._world = None
+ self._history_queue.clear()
diff --git a/carla_gym/core/obs_manager/birdview/maps/Town01.h5 b/carla_gym/core/obs_manager/birdview/maps/Town01.h5
new file mode 100644
index 0000000..c752c1a
Binary files /dev/null and b/carla_gym/core/obs_manager/birdview/maps/Town01.h5 differ
diff --git a/carla_gym/core/obs_manager/birdview/maps/Town02.h5 b/carla_gym/core/obs_manager/birdview/maps/Town02.h5
new file mode 100644
index 0000000..ca7f2fb
Binary files /dev/null and b/carla_gym/core/obs_manager/birdview/maps/Town02.h5 differ
diff --git a/carla_gym/core/obs_manager/birdview/maps/Town03.h5 b/carla_gym/core/obs_manager/birdview/maps/Town03.h5
new file mode 100644
index 0000000..8bb4fd8
Binary files /dev/null and b/carla_gym/core/obs_manager/birdview/maps/Town03.h5 differ
diff --git a/carla_gym/core/obs_manager/birdview/maps/Town04.h5 b/carla_gym/core/obs_manager/birdview/maps/Town04.h5
new file mode 100644
index 0000000..de25dba
Binary files /dev/null and b/carla_gym/core/obs_manager/birdview/maps/Town04.h5 differ
diff --git a/carla_gym/core/obs_manager/birdview/maps/Town05.h5 b/carla_gym/core/obs_manager/birdview/maps/Town05.h5
new file mode 100644
index 0000000..2a511a5
Binary files /dev/null and b/carla_gym/core/obs_manager/birdview/maps/Town05.h5 differ
diff --git a/carla_gym/core/obs_manager/birdview/maps/Town06.h5 b/carla_gym/core/obs_manager/birdview/maps/Town06.h5
new file mode 100644
index 0000000..62ee8b4
Binary files /dev/null and b/carla_gym/core/obs_manager/birdview/maps/Town06.h5 differ
diff --git a/carla_gym/core/obs_manager/camera/depth_semantic.py b/carla_gym/core/obs_manager/camera/depth_semantic.py
new file mode 100644
index 0000000..6a5f088
--- /dev/null
+++ b/carla_gym/core/obs_manager/camera/depth_semantic.py
@@ -0,0 +1,112 @@
+# import time
+
+import numpy as np
+import weakref
+import copy
+import carla
+from queue import Queue, Empty
+from gym import spaces
+# from matplotlib import cm
+# import open3d as o3d
+
+from carla_gym.core.obs_manager.obs_manager import ObsManagerBase
+
+
+class ObsManager(ObsManagerBase):
+ def __init__(self, obs_configs):
+ self._sensor_types = ('camera.depth', 'camera.semantic_segmentation')
+ self._height = obs_configs['height']
+ self._width = obs_configs['width']
+ self._fov = obs_configs['fov']
+ self._channels = 4
+
+ # Coordinates are forward-right-up https://carla.readthedocs.io/en/latest/ref_sensors/
+ location = carla.Location(
+ x=float(obs_configs['location'][0]),
+ y=float(obs_configs['location'][1]),
+ z=float(obs_configs['location'][2]))
+ rotation = carla.Rotation(
+ roll=float(obs_configs['rotation'][0]),
+ pitch=float(obs_configs['rotation'][1]),
+ yaw=float(obs_configs['rotation'][2]))
+
+ self._camera_transform = carla.Transform(location, rotation)
+
+ self._sensor_list = []
+ self._data_queue = None
+ self._queue_timeout = 10.0
+
+ super(ObsManager, self).__init__()
+
+ def _define_obs_space(self):
+
+ self.obs_space = spaces.Dict({
+ 'frame': spaces.Discrete(2**32-1),
+ 'data': spaces.Box(
+ low=0, high=255,
+ shape=(self._height, self._width, self._channels),
+ dtype=np.uint8)
+ })
+
+ def create_sensor(self, world, bp, transform, vehicle):
+ sensor_type = bp.tags[0]
+ sensor = world.spawn_actor(bp, transform, attach_to=vehicle)
+ weak_self = weakref.ref(self)
+ sensor.listen(lambda data: self._parse_points(weak_self, data, sensor_type))
+ self._sensor_list.append(sensor)
+
+ def attach_ego_vehicle(self, parent_actor):
+ self._data_queue = Queue()
+ self._world = parent_actor.vehicle.get_world()
+ bps = [self._world.get_blueprint_library().find("sensor." + sensor) for sensor in self._sensor_types]
+ for bp in bps:
+ bp.set_attribute('image_size_x', str(self._width))
+ bp.set_attribute('image_size_y', str(self._height))
+ bp.set_attribute('fov', str(self._fov))
+
+ self.create_sensor(self._world, bp, self._camera_transform, parent_actor.vehicle)
+
+ def get_observation(self):
+ snap_shot = self._world.get_snapshot()
+
+ assert self._data_queue.qsize() <= len(self._sensor_types)
+ datas = {}
+
+ try:
+ for _ in range(len(self._sensor_types)):
+ frame, sensor_type, data = self._data_queue.get(True, self._queue_timeout)
+ assert snap_shot.frame == frame
+ datas[sensor_type] = data
+ except Empty:
+            raise Exception('depth/semantic camera sensor took too long!')
+
+ data = np.concatenate([datas['depth'], datas['semantic_segmentation']], axis=2)
+
+ obs = {'frame': snap_shot.frame,
+ 'data': data}
+
+ return obs
+
+ def clean(self):
+ for sensor in self._sensor_list:
+ if sensor and sensor.is_alive:
+ sensor.stop()
+ sensor.destroy()
+        self._sensor_list = []
+ self._world = None
+
+ self._data_queue = None
+
+ @staticmethod
+ def _parse_points(weak_self, data, sensor_type):
+ self = weak_self()
+
+ np_img = np.frombuffer(data.raw_data, dtype=np.dtype('uint8'))
+ np_img = np.reshape(copy.deepcopy(np_img), (data.height, data.width, 4))
+ assert (sensor_type == 'depth' or sensor_type == 'semantic_segmentation'), 'sensor_type error'
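+        # CARLA delivers camera images as BGRA: keep the three BGR bytes for the depth camera
+        # (24-bit encoded depth) and only the R byte for semantic segmentation (the class label)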
+ if sensor_type == 'depth':
+ np_img = np_img[..., :3]
+ elif sensor_type == 'semantic_segmentation':
+ np_img = np_img[..., 2][..., None]
+
+ self._data_queue.put((data.frame, sensor_type, np_img))
diff --git a/carla_gym/core/obs_manager/camera/depth_semantic_m.py b/carla_gym/core/obs_manager/camera/depth_semantic_m.py
new file mode 100644
index 0000000..6593646
--- /dev/null
+++ b/carla_gym/core/obs_manager/camera/depth_semantic_m.py
@@ -0,0 +1,136 @@
+# import time
+
+import numpy as np
+import weakref
+import copy
+import carla
+from queue import Queue, Empty
+from gym import spaces
+# from matplotlib import cm
+# import open3d as o3d
+
+from carla_gym.core.obs_manager.obs_manager import ObsManagerBase
+
+
+class ObsManager(ObsManagerBase):
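+    """
+    Grid of top-down depth + semantic cameras (pitch = -90) arranged around the ego vehicle.
+
+    Template configs (keys as read in __init__; the values below are only illustrative):
+    obs_configs = {
+        "module": "camera.depth_semantic_m",
+        "width": 192,
+        "height": 192,
+        "fov": 90,
+        "sensor_num": [1, 1],           # grid of (2*1+1) x (2*1+1) = 3 x 3 cameras
+        "box_size": [16.0, 16.0, 24.0]  # grid spacing (x, y) and mounting height (z) in metres
+    }
+    """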
+ def __init__(self, obs_configs):
+ self._sensor_types = ('camera.depth', 'camera.semantic_segmentation')
+ self._height = obs_configs['height']
+ self._width = obs_configs['width']
+ self._fov = obs_configs['fov']
+ self._channels = 4
+
+ self._camera_transform_list = []
+ # self._depth_queue_list = []
+ # self._semantic_queue_list = []
+ self._data_queue_list = []
+ self._sensor_list = []
+ self._rotation = carla.Rotation(roll=0, pitch=-90, yaw=0) # roll, pitch, yaw
+ # self._scale = ((2, -1, 1), (2, 0, 1), (2, 1, 1),
+ # (1, -1, 1), (1, 0, 1), (1, 1, 1),
+ # (0, -1, 1), (0, 0, 1), (0, 1, 1),
+ # (-1, -1, 1), (-1, 0, 1), (-1, 1, 1),
+ # (-2, -1, 1), (-2, 0, 1), (-2, 1, 1))
+ self._scale = []
+ self._hw = obs_configs['sensor_num']
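+        # build a (2*h+1) x (2*w+1) grid of offsets, row-major from front-left to rear-right,
+        # where sensor_num = [h, w]; each offset is later scaled by box_size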
+ for i in range(2 * self._hw[0] + 1):
+ for j in range(2 * self._hw[1] + 1):
+ self._scale.append((-i + self._hw[0], j - self._hw[1], 1))
+ # self._scale = ((1, 1, 1), (-1, 1, 1), (-1, -1, 1), (1, -1, 1), (0, 0, 1))
+ # self._scale = ((1, 0, 1), (-1, 0, 1), (0, 0, 1),
+ # (0.5, 1, 1), (-0.5, 1, 1), (-0.5, -1, 1), (0.5, -1, 1))
+ self._box_size = (float(obs_configs['box_size'][0]),
+ float(obs_configs['box_size'][1]),
+ float(obs_configs['box_size'][2]))
+ x, y, z = self._box_size
+ for x_scale, y_scale, z_scale in self._scale:
+ location = carla.Location(
+ x=x * x_scale,
+ y=y * y_scale,
+ z=z * z_scale
+ )
+ self._camera_transform_list.append((carla.Transform(location, self._rotation)))
+
+ self._queue_timeout = 10.0
+
+ super(ObsManager, self).__init__()
+
+ def _define_obs_space(self):
+
+ self.obs_space = spaces.Dict({
+ 'frame': spaces.Discrete(2**32-1),
+ 'data': spaces.Box(
+ low=0, high=255,
+ shape=((2 * self._hw[0] + 1) * self._height, (2 * self._hw[1] + 1) * self._width, self._channels),
+ dtype=np.uint8)
+ })
+
+ def create_sensor(self, world, bp, transform, vehicle, i):
+ self._data_queue_list.append(Queue())
+ sensor_type = bp.tags[0]
+ sensor = world.spawn_actor(bp, transform, attach_to=vehicle)
+ weak_self = weakref.ref(self)
+ sensor.listen(lambda data: self._parse_points_m(weak_self, data, sensor_type, i))
+ self._sensor_list.append(sensor)
+
+ def attach_ego_vehicle(self, parent_actor):
+ self._world = parent_actor.vehicle.get_world()
+ bps = [self._world.get_blueprint_library().find("sensor." + sensor) for sensor in self._sensor_types]
+ for bp in bps:
+ bp.set_attribute('image_size_x', str(self._width))
+ bp.set_attribute('image_size_y', str(self._height))
+ bp.set_attribute('fov', str(self._fov))
+
+            for i, camera_transform in enumerate(self._camera_transform_list):
+                self.create_sensor(self._world, bp, camera_transform, parent_actor.vehicle, i)
+
+ def get_observation(self):
+ snap_shot = self._world.get_snapshot()
+ data_all = []
+ for transf, data_queue in zip(self._camera_transform_list, self._data_queue_list):
+ assert data_queue.qsize() <= 2
+ datas = {}
+
+ try:
+ for _ in range(len(self._sensor_types)):
+ frame, sensor_type, data = data_queue.get(True, self._queue_timeout)
+ assert snap_shot.frame == frame
+ datas[sensor_type] = data
+ except Empty:
+                raise Exception('depth/semantic camera sensor took too long!')
+
+ data_all.append(np.concatenate([datas['depth'], datas['semantic_segmentation']], axis=2))
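+        # stitch the per-camera images into a single mosaic, row-major, in the same order in
+        # which the camera transforms (and their queues) were created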
+ h_ = 2 * self._hw[0] + 1
+ w_ = 2 * self._hw[1] + 1
+ data = np.concatenate([np.concatenate(
+ [data_all[j] for j in range(w_*i, w_*i+w_)], axis=1) for i in range(h_)], axis=0)
+
+ obs = {'frame': snap_shot.frame,
+ 'data': data,
+ 'trans': self._box_size}
+
+ return obs
+
+ def clean(self):
+ for sensor in self._sensor_list:
+ if sensor and sensor.is_alive:
+ sensor.stop()
+ sensor.destroy()
+        self._sensor_list = []
+        self._world = None
+
+        self._data_queue_list = []
+
+ @staticmethod
+ def _parse_points_m(weak_self, data, sensor_type, i):
+ self = weak_self()
+
+ np_img = np.frombuffer(data.raw_data, dtype=np.dtype('uint8'))
+ np_img = np.reshape(copy.deepcopy(np_img), (data.height, data.width, 4))
+ assert (sensor_type == 'depth' or sensor_type == 'semantic_segmentation'), 'sensor_type error'
+ if sensor_type == 'depth':
+ np_img = np_img[..., :3]
+ elif sensor_type == 'semantic_segmentation':
+ np_img = np_img[..., 2][..., None]
+
+ self._data_queue_list[i].put((data.frame, sensor_type, np_img))
diff --git a/carla_gym/core/obs_manager/camera/rgb.py b/carla_gym/core/obs_manager/camera/rgb.py
new file mode 100644
index 0000000..03a2428
--- /dev/null
+++ b/carla_gym/core/obs_manager/camera/rgb.py
@@ -0,0 +1,128 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import numpy as np
+import copy
+import weakref
+import carla
+from queue import Queue, Empty
+from gym import spaces
+
+from carla_gym.core.obs_manager.obs_manager import ObsManagerBase
+
+from constants import DISTORT_IMAGES
+
+
+class ObsManager(ObsManagerBase):
+ """
+ Template configs:
+ obs_configs = {
+ "module": "camera.rgb",
+ "location": [-5.5, 0, 2.8],
+ "rotation": [0, -15, 0],
+ "frame_stack": 1,
+ "width": 1920,
+ "height": 1080
+ }
+ frame_stack: [Image(t-2), Image(t-1), Image(t)]
+ """
+
+ def __init__(self, obs_configs):
+
+ self._sensor_type = 'camera.rgb'
+
+ self._height = obs_configs['height']
+ self._width = obs_configs['width']
+ self._fov = obs_configs['fov']
+ self._channels = 4
+
+ # Coordinates are forward-right-up https://carla.readthedocs.io/en/latest/ref_sensors/
+ location = carla.Location(
+ x=float(obs_configs['location'][0]),
+ y=float(obs_configs['location'][1]),
+ z=float(obs_configs['location'][2]))
+ rotation = carla.Rotation(
+ roll=float(obs_configs['rotation'][0]),
+ pitch=float(obs_configs['rotation'][1]),
+ yaw=float(obs_configs['rotation'][2]))
+
+ self._camera_transform = carla.Transform(location, rotation)
+
+ self._sensor = None
+ self._queue_timeout = 10.0
+ self._image_queue = None
+
+ super(ObsManager, self).__init__()
+
+ def _define_obs_space(self):
+
+ self.obs_space = spaces.Dict({
+ 'frame': spaces.Discrete(2**32-1),
+ 'data': spaces.Box(
+ low=0, high=255, shape=(self._height, self._width, self._channels), dtype=np.uint8)
+ })
+
+ def attach_ego_vehicle(self, parent_actor):
+ init_obs = np.zeros([self._height, self._width, self._channels], dtype=np.uint8)
+ self._image_queue = Queue()
+
+ self._world = parent_actor.vehicle.get_world()
+
+ bp = self._world.get_blueprint_library().find("sensor."+self._sensor_type)
+ bp.set_attribute('image_size_x', str(self._width))
+ bp.set_attribute('image_size_y', str(self._height))
+ bp.set_attribute('fov', str(self._fov))
+
+ if DISTORT_IMAGES:
+ # set in leaderboard
+ # https://github.com/carla-simulator/leaderboard/blob/8956c4e0c53bfa24e2bd0ccb1a5269ce47770a57/leaderboard/autoagents/agent_wrapper.py#L100
+ bp.set_attribute('lens_circle_multiplier', str(3.0))
+ bp.set_attribute('lens_circle_falloff', str(3.0))
+ bp.set_attribute('chromatic_aberration_intensity', str(0.5))
+ bp.set_attribute('chromatic_aberration_offset', str(0))
+
+ self._sensor = self._world.spawn_actor(bp, self._camera_transform, attach_to=parent_actor.vehicle)
+ weak_self = weakref.ref(self)
+ self._sensor.listen(lambda image: self._parse_image(weak_self, image))
+
+ def get_observation(self):
+ snap_shot = self._world.get_snapshot()
+ assert self._image_queue.qsize() <= 1
+
+ try:
+ frame, data = self._image_queue.get(True, self._queue_timeout)
+ assert snap_shot.frame == frame
+ except Empty:
+ raise Exception('RGB sensor took too long!')
+
+ obs = {'frame': frame,
+ 'data': data}
+
+ return obs
+
+ def clean(self):
+ if self._sensor and self._sensor.is_alive:
+ self._sensor.stop()
+ self._sensor.destroy()
+ self._sensor = None
+ self._world = None
+
+ self._image_queue = None
+
+ @staticmethod
+ def _parse_image(weak_self, carla_image):
+ self = weak_self()
+
+ np_img = np.frombuffer(carla_image.raw_data, dtype=np.dtype("uint8"))
+
+ np_img = copy.deepcopy(np_img)
+
+ np_img = np.reshape(np_img, (carla_image.height, carla_image.width, 4))
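+        # BGRA -> drop the alpha channel -> reverse channel order to obtain RGB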
+ np_img = np_img[:, :, :3]
+ np_img = np_img[:, :, ::-1]
+
+ # np_img = np.moveaxis(np_img, -1, 0)
+ # image = cv2.resize(image, (self._res_x, self._res_y), interpolation=cv2.INTER_AREA)
+ # image = np.float32
+ # image = (image.astype(np.float32) - 128) / 128
+
+ self._image_queue.put((carla_image.frame, np_img))
diff --git a/carla_gym/core/obs_manager/lidar/ray_cast.py b/carla_gym/core/obs_manager/lidar/ray_cast.py
new file mode 100644
index 0000000..5e06044
--- /dev/null
+++ b/carla_gym/core/obs_manager/lidar/ray_cast.py
@@ -0,0 +1,201 @@
+import time
+
+import numpy as np
+import weakref
+import carla
+from queue import Queue, Empty
+from gym import spaces
+from matplotlib import cm
+import open3d as o3d
+
+from carla_gym.core.obs_manager.obs_manager import ObsManagerBase
+from constants import CARLA_FPS
+
+VIRIDIS = np.array(cm.get_cmap('plasma').colors)
+VID_RANGE = np.linspace(0.0, 1.0, VIRIDIS.shape[0])
+
+
+def add_open3d_axis(vis):
+ """Add a small 3D axis on Open3D Visualizer"""
+ axis = o3d.geometry.LineSet()
+ axis.points = o3d.utility.Vector3dVector(np.array([
+ [0.0, 0.0, 0.0],
+ [1.0, 0.0, 0.0],
+ [0.0, 1.0, 0.0],
+ [0.0, 0.0, 1.0]]))
+ axis.lines = o3d.utility.Vector2iVector(np.array([
+ [0, 1],
+ [0, 2],
+ [0, 3]]))
+ axis.colors = o3d.utility.Vector3dVector(np.array([
+ [1.0, 0.0, 0.0],
+ [0.0, 1.0, 0.0],
+ [0.0, 0.0, 1.0]]))
+ vis.add_geometry(axis)
+
+
+class ObsManager(ObsManagerBase):
+ """
+ Template configs:
+ obs_configs = {
+ "module": "lidar.ray_cast",
+ "location": [-5.5, 0, 2.8],
+ "rotation": [0, 0, 0],
+ "frame_stack": 1,
+ "render_o3d": False,
+ "show_axis": False,
+ "no_noise": False,
+ "lidar_options": {
+ "width": 1920,
+ "height": 1080,
+ # https://github.com/carla-simulator/leaderboard/blob/master/leaderboard/autoagents/agent_wrapper.py
+ "channels": 64,
+ "range": 100,
+            "rotation_frequency": 20,
+            "points_per_second": 100000,
+ "upper_fov": 15.0,
+ "lower_fov": 25.0, # -30.0
+ "atmosphere_attenuation_rate": 0.004,
+ # if no_noise
+ "dropoff_general_rate": 0.45,
+ "dropoff_intensity_limit": 0.8,
+ "dropoff_zero_intensity": 0.4,
+ },
+ }
+ frame_stack: [Image(t-2), Image(t-1), Image(t)]
+ """
+
+ def __init__(self, obs_configs):
+
+ self._sensor_type = 'lidar.ray_cast'
+
+ self._lidar_options = obs_configs['lidar_options']
+ self._no_noise = obs_configs['no_noise']
+ self._render_o3d = obs_configs["render_o3d"]
+ self._show_axis = obs_configs["show_axis"]
+
+        # override 'rotation_frequency' to match the simulator frame rate (CARLA_FPS)
+ self._lidar_options['rotation_frequency'] = CARLA_FPS
+
+ # Coordinates are forward-right-up https://carla.readthedocs.io/en/latest/ref_sensors/
+ location = carla.Location(
+ x=float(obs_configs['location'][0]),
+ y=float(obs_configs['location'][1]),
+ z=float(obs_configs['location'][2]))
+ rotation = carla.Rotation(
+ roll=float(obs_configs['rotation'][0]),
+ pitch=float(obs_configs['rotation'][1]),
+ yaw=float(obs_configs['rotation'][2]))
+
+ self._camera_transform = carla.Transform(location, rotation)
+
+ self._sensor = None
+ self._queue_timeout = 10.0
+ self._points_queue = None
+ if self._render_o3d:
+ self._point_list = o3d.geometry.PointCloud()
+ self._point_list.points = o3d.utility.Vector3dVector(10 * np.random.randn(1000, 3))
+
+ self.vis = o3d.visualization.Visualizer()
+ self.vis.create_window(
+ window_name='Carla Lidar',
+ width=960,
+ height=540,
+ left=480,
+ top=270)
+ self.vis.get_render_option().background_color = [0.05, 0.05, 0.05]
+ self.vis.get_render_option().point_size = 1
+ self.vis.get_render_option().show_coordinate_frame = True
+ if self._show_axis:
+ add_open3d_axis(self.vis)
+ self.vis.add_geometry(self._point_list)
+
+ super(ObsManager, self).__init__()
+
+ def _define_obs_space(self):
+
+ self.obs_space = spaces.Dict({
+ 'frame': spaces.Discrete(2**32-1),
+ 'data': spaces.Dict({
+ 'x': spaces.Box(low=-np.inf, high=np.inf, shape=(1, ), dtype=np.float32),
+ 'y': spaces.Box(low=-np.inf, high=np.inf, shape=(1, ), dtype=np.float32),
+ 'z': spaces.Box(low=-np.inf, high=np.inf, shape=(1, ), dtype=np.float32),
+ 'i': spaces.Box(low=0, high=1, shape=(1, ), dtype=np.float32),
+ })
+ })
+
+ def attach_ego_vehicle(self, parent_actor):
+ self._points_queue = Queue()
+
+ self._world = parent_actor.vehicle.get_world()
+
+ bp = self._world.get_blueprint_library().find("sensor."+self._sensor_type)
+ for key, value in self._lidar_options.items():
+ bp.set_attribute(key, str(value))
+ if self._no_noise:
+ bp.set_attribute('dropoff_general_rate', '0.0')
+ bp.set_attribute('dropoff_intensity_limit', '1.0')
+ bp.set_attribute('dropoff_zero_intensity', '0.0')
+ else:
+ bp.set_attribute('noise_stddev', '0.2')
+
+ self._sensor = self._world.spawn_actor(bp, self._camera_transform, attach_to=parent_actor.vehicle)
+ weak_self = weakref.ref(self)
+ self._sensor.listen(lambda data: self._parse_points(weak_self, data))
+
+ def get_observation(self):
+ snap_shot = self._world.get_snapshot()
+ assert self._points_queue.qsize() <= 1
+
+ try:
+ frame, data = self._points_queue.get(True, self._queue_timeout)
+ assert snap_shot.frame == frame
+ except Empty:
+            raise Exception('Lidar sensor took too long!')
+
+ if self._render_o3d:
+ self.vis.update_geometry(self._point_list)
+ self.vis.poll_events()
+ self.vis.update_renderer()
+ time.sleep(0.005)
+
+ obs = {'frame': frame,
+ 'data': data}
+
+ return obs
+
+ def clean(self):
+ if self._sensor and self._sensor.is_alive:
+ self._sensor.stop()
+ self._sensor.destroy()
+ self._sensor = None
+ self._world = None
+
+ self._points_queue = None
+
+ @staticmethod
+ def _parse_points(weak_self, data):
+ self = weak_self()
+
+ # get 4D points data
+ point_cloud = np.copy(np.frombuffer(data.raw_data, dtype=np.dtype('f4')))
+ point_cloud = np.reshape(point_cloud, (int(point_cloud.shape[0] / 4), 4))
+
+ # Isolate the intensity
+ intensity = point_cloud[:, -1]
+
+ # Isolate the 3D points data
+ points = point_cloud[:, :-1]
+ # points[:, :1] = -points[:, :1]
+
+ if self._render_o3d:
+ intensity_col = 1.0 - np.log(intensity) / np.log(np.exp(-0.004 * 100))
+ int_color = np.c_[
+ np.interp(intensity_col, VID_RANGE, VIRIDIS[:, 0]),
+ np.interp(intensity_col, VID_RANGE, VIRIDIS[:, 1]),
+ np.interp(intensity_col, VID_RANGE, VIRIDIS[:, 2])]
+
+ self._point_list.points = o3d.utility.Vector3dVector(points)
+ self._point_list.colors = o3d.utility.Vector3dVector(int_color)
+
+ self._points_queue.put((data.frame, {"points_xyz": points, "intensity": intensity}))
diff --git a/carla_gym/core/obs_manager/lidar/ray_cast_multi.py b/carla_gym/core/obs_manager/lidar/ray_cast_multi.py
new file mode 100644
index 0000000..ef9bc2b
--- /dev/null
+++ b/carla_gym/core/obs_manager/lidar/ray_cast_multi.py
@@ -0,0 +1,141 @@
+import time
+
+import numpy as np
+import weakref
+import carla
+from queue import Queue, Empty
+from gym import spaces
+from matplotlib import cm
+import open3d as o3d
+
+from carla_gym.core.obs_manager.lidar.ray_cast_semantic import ObsManager as OM
+from carla_gym.core.obs_manager.lidar.ray_cast_semantic import LABEL_COLORS
+
+
+class ObsManager(OM):
+ def __init__(self, obs_configs):
+ super(ObsManager, self).__init__(obs_configs)
+ self._camera_transform_list = []
+ self._points_queue_list = {}
+ self._sensor_list = {}
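+        # one semantic lidar per (offset, rotation) pair: front, rear and top units rolled 90 deg,
+        # plus right, left and top units pitched 90 deg; offsets are the scale factors times box_size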
+ self._rotations = ((90, 0, 0), (90, 0, 0), (90, 0, 0), (0, 90, 0), (0, 90, 0), (0, 90, 0))
+ # self._scale = ((1, 0, 1), (0, 1, 1), (-1, 0, 1), (0, -1, 1), (0, 0, 1),
+ # (0.7, 0.7, 0.7), (-0.7, 0.7, 0.7), (-0.7, -0.7, 0.7), (0.7, -0.7, 0.7))
+ self._scale = ((1, 0, 0.8), (-1, 0, 0.8), (0, 0, 1), (0, 1, 0.8), (0, -1, 0.8), (0, 0, 1))
+ # self._scale = ((1, 1, 1), (-1, 1, 1), (-1, -1, 1), (1, -1, 1), (0, 0, 1))
+ # self._scale = ((1, 0, 1), (-1, 0, 1), (0, 0, 1),
+ # (0.5, 1, 1), (-0.5, 1, 1), (-0.5, -1, 1), (0.5, -1, 1))
+ self._box_size = (float(obs_configs['box_size'][0]),
+ float(obs_configs['box_size'][1]),
+ float(obs_configs['box_size'][2]))
+ x, y, z = self._box_size
+ for (x_scale, y_scale, z_scale), (roll, pitch, yaw) in zip(self._scale, self._rotations):
+ location = carla.Location(
+ x=x * x_scale,
+ y=y * y_scale,
+ z=z * z_scale
+ )
+ rotation = carla.Rotation(
+ roll=roll,
+ pitch=pitch,
+ yaw=yaw
+ )
+ self._camera_transform_list.append((carla.Transform(location, rotation)))
+
+ def create_sensor(self, world, bp, transform, vehicle, i):
+ self._points_queue_list[i] = Queue()
+ sensor = world.spawn_actor(bp, transform, attach_to=vehicle)
+ weak_self = weakref.ref(self)
+ sensor.listen(lambda data: self._parse_points_m(weak_self, data, i))
+ self._sensor_list[i] = sensor
+
+ def attach_ego_vehicle(self, parent_actor):
+ self._world = parent_actor.vehicle.get_world()
+ bp = self._world.get_blueprint_library().find("sensor." + self._sensor_type)
+ for key, value in self._lidar_options.items():
+ bp.set_attribute(key, str(value))
+
+ for i, camera_transform in enumerate(self._camera_transform_list):
+ self.create_sensor(self._world, bp, camera_transform, parent_actor.vehicle, i)
+
+ def get_observation(self):
+ snap_shot = self._world.get_snapshot()
+ points = []
+ for transf, points_queue_key in zip(self._camera_transform_list, self._points_queue_list):
+ points_queue = self._points_queue_list[points_queue_key]
+ assert points_queue.qsize() <= 1
+ # transf_matrix = transf.get_matrix()
+
+ try:
+ frame, data = points_queue.get(True, self._queue_timeout)
+ assert snap_shot.frame == frame
+ point_cloud = data['points_xyz']
+ # point_cloud = np.append(point_cloud, np.ones((point_cloud.shape[0], 1)), axis=1)
+ # point_cloud = np.dot(transf_matrix, point_cloud.T).T
+ # point_cloud = point_cloud[:, :-1]
+ ObjTag = data['ObjTag']
+ CosAngel = data['CosAngel']
+ ObjIdx = data['ObjIdx']
+ points.append(
+ np.concatenate([point_cloud, CosAngel[:, None], ObjIdx[:, None], ObjTag[:, None]], axis=1))
+ # obs.append({'frame': frame,
+ # 'data': data,
+ # 'transformation': transf_matrix})
+ assert snap_shot.frame == frame
+ except Empty:
+                raise Exception('Semantic lidar sensor took too long!')
+ points = np.concatenate(points, axis=0)
+ data_all = {
+ 'points_xyz': points[:, :3],
+ 'CosAngel': points[:, 3],
+ 'ObjIdx': points[:, 4].astype(np.uint32),
+ 'ObjTag': points[:, 5].astype(np.uint32),
+ }
+ obs = {'frame': snap_shot.frame,
+ 'data': data_all}
+
+ if self._render_o3d:
+ self._point_list.points = o3d.utility.Vector3dVector(points[:, :3])
+            self._point_list.colors = o3d.utility.Vector3dVector(
+                LABEL_COLORS[points[:, 5].astype(np.uint32)])
+ self.vis.update_geometry(self._point_list)
+ self.vis.poll_events()
+ self.vis.update_renderer()
+ time.sleep(0.005)
+
+ return obs
+
+ def clean(self):
+ for key in self._sensor_list:
+ sensor = self._sensor_list[key]
+ if sensor and sensor.is_alive:
+ sensor.stop()
+ sensor.destroy()
+ self._sensor_list = {}
+ self._world = None
+
+ self._points_queue_list = {}
+
+ @staticmethod
+ def _parse_points_m(weak_self, data, i):
+ self = weak_self()
+
+ # get 4D points data
+ point_cloud = np.frombuffer(data.raw_data, dtype=np.dtype([
+ ('x', np.float32), ('y', np.float32), ('z', np.float32),
+ ('CosAngle', np.float32), ('ObjIdx', np.uint32), ('ObjTag', np.uint32)]))
+
+ # Isolate the 3D points data
+ points = np.array([point_cloud['x'], point_cloud['y'], point_cloud['z']]).T
+ transf_matrix = self._camera_transform_list[i].get_matrix()
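+        # transform the points from the sensor frame into the ego-vehicle frame via the mounting
+        # transform, then keep only points inside a box around the ego (|x|<=50, |y|<=40, |z|<=20 m)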
+
+ points = np.append(points, np.ones((points.shape[0], 1)), axis=1)
+ points = np.dot(transf_matrix, points.T).T
+ points = points[:, :-1]
+ idx = (-50 <= points[:, 0]) & (points[:, 0] <= 50) \
+ & (-40 <= points[:, 1]) & (points[:, 1] <= 40) \
+ & (-20 <= points[:, 2]) & (points[:, 2] <= 20)
+
+ self._points_queue_list[i].put((data.frame, {"points_xyz": points[idx],
+ "CosAngel": np.array(point_cloud['CosAngle'])[idx],
+ "ObjIdx": np.array(point_cloud['ObjIdx'])[idx],
+ "ObjTag": np.array(point_cloud['ObjTag'])[idx]}))
diff --git a/carla_gym/core/obs_manager/lidar/ray_cast_semantic.py b/carla_gym/core/obs_manager/lidar/ray_cast_semantic.py
new file mode 100644
index 0000000..e5d2223
--- /dev/null
+++ b/carla_gym/core/obs_manager/lidar/ray_cast_semantic.py
@@ -0,0 +1,219 @@
+import time
+
+import numpy as np
+import weakref
+import carla
+from queue import Queue, Empty
+from gym import spaces
+from matplotlib import cm
+import open3d as o3d
+
+from carla_gym.core.obs_manager.obs_manager import ObsManagerBase
+from constants import CARLA_FPS
+
+LABEL_COLORS = np.array([
+ (0, 0, 0), # unlabeled
+ # cityscape
+ (128, 64, 128), # road = 1
+ (244, 35, 232), # sidewalk = 2
+ (70, 70, 70), # building = 3
+ (102, 102, 156), # wall = 4
+ (190, 153, 153), # fence = 5
+ (153, 153, 153), # pole = 6
+ (250, 170, 30), # traffic light = 7
+ (220, 220, 0), # traffic sign = 8
+ (107, 142, 35), # vegetation = 9
+ (152, 251, 152), # terrain = 10
+ (70, 130, 180), # sky = 11
+ (220, 20, 60), # pedestrian = 12
+ (255, 0, 0), # rider = 13
+ (0, 0, 142), # Car = 14
+ (0, 0, 70), # truck = 15
+    (0, 60, 100),    # bus = 16
+ (0, 80, 100), # train = 17
+ (0, 0, 230), # motorcycle = 18
+ (119, 11, 32), # bicycle = 19
+ # custom
+ (110, 190, 160), # static = 20
+ (170, 120, 50), # dynamic = 21
+ (55, 90, 80), # other = 22
+ (45, 60, 150), # water = 23
+ (157, 234, 50), # road line = 24
+    (81, 0, 81),       # ground = 25
+    (150, 100, 100),   # bridge = 26
+    (230, 150, 140),   # rail track = 27
+    (180, 165, 180)    # guard rail = 28
+]) / 255.0  # normalize each channel to [0, 1], which is what Open3D expects
+
+
+def add_open3d_axis(vis):
+ """Add a small 3D axis on Open3D Visualizer"""
+ axis = o3d.geometry.LineSet()
+ axis.points = o3d.utility.Vector3dVector(np.array([
+ [0.0, 0.0, 0.0],
+ [1.0, 0.0, 0.0],
+ [0.0, 1.0, 0.0],
+ [0.0, 0.0, 1.0]]))
+ axis.lines = o3d.utility.Vector2iVector(np.array([
+ [0, 1],
+ [0, 2],
+ [0, 3]]))
+ axis.colors = o3d.utility.Vector3dVector(np.array([
+ [1.0, 0.0, 0.0],
+ [0.0, 1.0, 0.0],
+ [0.0, 0.0, 1.0]]))
+ vis.add_geometry(axis)
+
+
+class ObsManager(ObsManagerBase):
+ """
+ Template configs:
+ obs_configs = {
+        "module": "lidar.ray_cast_semantic",
+ "location": [-5.5, 0, 2.8],
+ "rotation": [0, 0, 0],
+ "frame_stack": 1,
+ "render_o3d": False,
+ "show_axis": False,
+ "lidar_options": {
+ "width": 1920,
+ "height": 1080,
+ # https://github.com/carla-simulator/leaderboard/blob/master/leaderboard/autoagents/agent_wrapper.py
+ "channels": 64,
+ "range": 100,
+            "rotation_frequency": 20,
+            "points_per_second": 100000,
+ "upper_fov": 15.0,
+ "lower_fov": 25.0, # -30.0
+ },
+ }
+ frame_stack: [Image(t-2), Image(t-1), Image(t)]
+ """
+
+ def __init__(self, obs_configs):
+
+ self._sensor_type = 'lidar.ray_cast_semantic'
+
+ self._lidar_options = obs_configs['lidar_options']
+ self._render_o3d = obs_configs["render_o3d"]
+ self._show_axis = obs_configs["show_axis"]
+
+        # override 'rotation_frequency' to match the simulator frame rate (CARLA_FPS)
+ self._lidar_options['rotation_frequency'] = CARLA_FPS
+
+ # Coordinates are forward-right-up https://carla.readthedocs.io/en/latest/ref_sensors/
+ location = carla.Location(
+ x=float(obs_configs['location'][0]),
+ y=float(obs_configs['location'][1]),
+ z=float(obs_configs['location'][2]))
+ rotation = carla.Rotation(
+ roll=float(obs_configs['rotation'][0]),
+ pitch=float(obs_configs['rotation'][1]),
+ yaw=float(obs_configs['rotation'][2]))
+
+ self._camera_transform = carla.Transform(location, rotation)
+
+ self._world = None
+ self._sensor = None
+ self._queue_timeout = 10.0
+ self._points_queue = None
+ if self._render_o3d:
+ self._point_list = o3d.geometry.PointCloud()
+ self._point_list.points = o3d.utility.Vector3dVector(10 * np.random.randn(1000, 3))
+
+ self.vis = o3d.visualization.Visualizer()
+ self.vis.create_window(
+ window_name='Carla Lidar',
+ width=960,
+ height=540,
+ left=480,
+ top=270)
+ self.vis.get_render_option().background_color = [0.05, 0.05, 0.05]
+ self.vis.get_render_option().point_size = 1
+ self.vis.get_render_option().show_coordinate_frame = True
+ if self._show_axis:
+ add_open3d_axis(self.vis)
+ self.vis.add_geometry(self._point_list)
+
+ super(ObsManager, self).__init__()
+
+ def _define_obs_space(self):
+
+ self.obs_space = spaces.Dict({
+ 'frame': spaces.Discrete(2**32-1),
+ 'data': spaces.Dict({
+ 'x': spaces.Box(low=-np.inf, high=np.inf, shape=(1, ), dtype=np.float32),
+ 'y': spaces.Box(low=-np.inf, high=np.inf, shape=(1, ), dtype=np.float32),
+ 'z': spaces.Box(low=-np.inf, high=np.inf, shape=(1, ), dtype=np.float32),
+ 'CosAngle': spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32),
+ 'ObjIdx': spaces.Box(low=0, high=28, shape=(1,), dtype=np.uint32),
+ 'ObjTag': spaces.Box(low=0, high=28, shape=(1,), dtype=np.uint32),
+ })
+ })
+
+ def attach_ego_vehicle(self, parent_actor):
+ self._points_queue = Queue()
+
+ self._world = parent_actor.vehicle.get_world()
+
+ bp = self._world.get_blueprint_library().find("sensor."+self._sensor_type)
+ for key, value in self._lidar_options.items():
+ bp.set_attribute(key, str(value))
+
+ self._sensor = self._world.spawn_actor(bp, self._camera_transform, attach_to=parent_actor.vehicle)
+ weak_self = weakref.ref(self)
+ self._sensor.listen(lambda data: self._parse_points(weak_self, data))
+
+ def get_observation(self):
+ snap_shot = self._world.get_snapshot()
+ assert self._points_queue.qsize() <= 1
+
+ try:
+ frame, data = self._points_queue.get(True, self._queue_timeout)
+ assert snap_shot.frame == frame
+ except Empty:
+            raise Exception('Semantic lidar sensor took too long!')
+
+ if self._render_o3d:
+ self.vis.update_geometry(self._point_list)
+ self.vis.poll_events()
+ self.vis.update_renderer()
+ time.sleep(0.005)
+
+ obs = {'frame': frame,
+ 'data': data}
+
+ return obs
+
+ def clean(self):
+ if self._sensor and self._sensor.is_alive:
+ self._sensor.stop()
+ self._sensor.destroy()
+ self._sensor = None
+ self._world = None
+
+ self._points_queue = None
+
+ @staticmethod
+ def _parse_points(weak_self, data):
+ self = weak_self()
+
+ # get 4D points data
+ point_cloud = np.frombuffer(data.raw_data, dtype=np.dtype([
+ ('x', np.float32), ('y', np.float32), ('z', np.float32),
+ ('CosAngle', np.float32), ('ObjIdx', np.uint32), ('ObjTag', np.uint32)]))
+
+ # Isolate the 3D points data
+ points = np.array([point_cloud['x'], point_cloud['y'], point_cloud['z']]).T
+
+ if self._render_o3d:
+ labels = np.array(point_cloud['ObjTag'])
+ int_color = LABEL_COLORS[labels]
+
+ self._point_list.points = o3d.utility.Vector3dVector(points)
+ self._point_list.colors = o3d.utility.Vector3dVector(int_color)
+
+ self._points_queue.put((data.frame, {"points_xyz": points,
+ # "CosAngel": np.array(point_cloud['CosAngle']),
+ # "ObjIdx": np.array(point_cloud['ObjIdx']),
+ "ObjTag": np.array(point_cloud['ObjTag'], dtype=np.uint8)}))
diff --git a/carla_gym/core/obs_manager/navigation/gnss.py b/carla_gym/core/obs_manager/navigation/gnss.py
new file mode 100644
index 0000000..5661b13
--- /dev/null
+++ b/carla_gym/core/obs_manager/navigation/gnss.py
@@ -0,0 +1,173 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import numpy as np
+import weakref
+import carla
+from gym import spaces
+from queue import Queue, Empty
+
+from carla_gym.core.obs_manager.obs_manager import ObsManagerBase
+from carla_gym.core.task_actor.common.navigation.map_utils import RoadOption
+
+from data.dataset_utils import preprocess_gps
+
+
+class ObsManager(ObsManagerBase):
+
+ def __init__(self, obs_configs):
+
+ self._gnss_sensor = None
+ self._imu_sensor = None
+ self._queue_timeout = 10.0
+ self._gnss_queue = None
+ self._imu_queue = None
+
+ super(ObsManager, self).__init__()
+
+ def _define_obs_space(self):
+ # accelerometer: m/s2
+        # gyroscope: rad/s
+ # compass: rad wrt. north
+ imu_low = np.array([-1e6, -1e6, -1e6, -1e6, -1e6, -1e6, 0], dtype=np.float32)
+ imu_high = np.array([1e6, 1e6, 1e6, 1e6, 1e6, 1e6, 2*np.pi], dtype=np.float32)
+
+ self.obs_space = spaces.Dict({
+ 'gnss': spaces.Box(low=-10, high=10, shape=(3,), dtype=np.float32),
+ 'imu': spaces.Box(low=imu_low, high=imu_high, dtype=np.float32),
+ 'target_gps': spaces.Box(low=-1e3, high=1e3, shape=(3,), dtype=np.float32),
+ 'command': spaces.Box(low=-1, high=6, shape=(1,), dtype=np.int8),
+ 'target_gps_next': spaces.Box(low=-1e3, high=1e3, shape=(3,), dtype=np.float32),
+ 'command_next': spaces.Box(low=-1, high=6, shape=(1,), dtype=np.int8),
+ })
+
+ def attach_ego_vehicle(self, parent_actor):
+ self._world = parent_actor.vehicle.get_world()
+ self._parent_actor = parent_actor
+ self._idx = -1
+ self._gnss_queue = Queue()
+ self._imu_queue = Queue()
+
+ # gnss sensor
+ bp = self._world.get_blueprint_library().find('sensor.other.gnss')
+ bp.set_attribute('noise_alt_stddev', str(0.000005))
+ bp.set_attribute('noise_lat_stddev', str(0.000005))
+ bp.set_attribute('noise_lon_stddev', str(0.000005))
+ bp.set_attribute('noise_alt_bias', str(0.0))
+ bp.set_attribute('noise_lat_bias', str(0.0))
+ bp.set_attribute('noise_lon_bias', str(0.0))
+ sensor_location = carla.Location()
+ sensor_rotation = carla.Rotation()
+ sensor_transform = carla.Transform(location=sensor_location, rotation=sensor_rotation)
+ self._gnss_sensor = self._world.spawn_actor(bp, sensor_transform, attach_to=parent_actor.vehicle)
+ weak_self = weakref.ref(self)
+ self._gnss_sensor.listen(lambda gnss_data: self._parse_gnss(weak_self, gnss_data))
+
+ # imu sensor
+ bp = self._world.get_blueprint_library().find('sensor.other.imu')
+ bp.set_attribute('noise_accel_stddev_x', str(0.001))
+ bp.set_attribute('noise_accel_stddev_y', str(0.001))
+ bp.set_attribute('noise_accel_stddev_z', str(0.015))
+ bp.set_attribute('noise_gyro_stddev_x', str(0.001))
+ bp.set_attribute('noise_gyro_stddev_y', str(0.001))
+ bp.set_attribute('noise_gyro_stddev_z', str(0.001))
+ sensor_location = carla.Location()
+ sensor_rotation = carla.Rotation()
+ sensor_transform = carla.Transform(location=sensor_location, rotation=sensor_rotation)
+ self._imu_sensor = self._world.spawn_actor(bp, sensor_transform, attach_to=parent_actor.vehicle)
+ weak_self = weakref.ref(self)
+ self._imu_sensor.listen(lambda imu_data: self._parse_imu(weak_self, imu_data))
+
+ def get_observation(self):
+ snap_shot = self._world.get_snapshot()
+ assert self._gnss_queue.qsize() <= 1
+ assert self._imu_queue.qsize() <= 1
+
+ # get gnss
+ try:
+ frame, gnss_data = self._gnss_queue.get(True, self._queue_timeout)
+ assert snap_shot.frame == frame
+ except Empty:
+ raise Exception('gnss sensor took too long!')
+
+ # get imu
+ try:
+ frame, imu_data = self._imu_queue.get(True, self._queue_timeout)
+ assert snap_shot.frame == frame
+ except Empty:
+ raise Exception('imu sensor took too long!')
+
+ # target gps
+ global_plan_gps = self._parent_actor.global_plan_gps
+
+ next_gps, _ = global_plan_gps[self._idx+1]
+
+ loc_in_ev = preprocess_gps(gnss_data, next_gps, imu_data)
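+        # advance to the next target waypoint once the current one is closer than 12 m
+        # and already behind the ego vehicle (negative x in the ego frame)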
+ if np.sqrt(loc_in_ev.x**2+loc_in_ev.y**2) < 12.0 and loc_in_ev.x < 0.0:
+ self._idx += 1
+
+ self._idx = min(self._idx, len(global_plan_gps)-2)
+
+ _, road_option_0 = global_plan_gps[max(0, self._idx)]
+ gps_point, road_option_1 = global_plan_gps[self._idx+1]
+ # Gps waypoint after the immediate next waypoint.
+ gps_point2, road_option_2 = global_plan_gps[min(len(global_plan_gps) - 1, self._idx + 2)]
+
+ if (road_option_0 in [RoadOption.CHANGELANELEFT, RoadOption.CHANGELANERIGHT]) \
+ and (road_option_1 not in [RoadOption.CHANGELANELEFT, RoadOption.CHANGELANERIGHT]):
+ road_option = road_option_1
+ else:
+ road_option = road_option_0
+
+ # Handle road option for next next waypoint
+ if (road_option_1 in [RoadOption.CHANGELANELEFT, RoadOption.CHANGELANERIGHT]) \
+ and (road_option_2 not in [RoadOption.CHANGELANELEFT, RoadOption.CHANGELANERIGHT]):
+ road_option_next = road_option_2
+ else:
+ road_option_next = road_option_1
+
+ obs = {'gnss': gnss_data,
+ 'imu': imu_data,
+ 'target_gps': np.array(gps_point, dtype=np.float32),
+ 'command': np.array([road_option.value], dtype=np.int8),
+ 'target_gps_next': np.array(gps_point2, dtype=np.float32),
+ 'command_next': np.array([road_option_next.value], dtype=np.int8),
+ }
+ return obs
+
+ def clean(self):
+ if self._imu_sensor and self._imu_sensor.is_alive:
+ self._imu_sensor.stop()
+ self._imu_sensor.destroy()
+ self._imu_sensor = None
+
+ if self._gnss_sensor and self._gnss_sensor.is_alive:
+ self._gnss_sensor.stop()
+ self._gnss_sensor.destroy()
+ self._gnss_sensor = None
+
+ self._world = None
+ self._parent_actor = None
+
+ self._gnss_queue = None
+ self._imu_queue = None
+
+ @staticmethod
+ def _parse_gnss(weak_self, gnss_data):
+ self = weak_self()
+ data = np.array([gnss_data.latitude,
+ gnss_data.longitude,
+ gnss_data.altitude], dtype=np.float32)
+ self._gnss_queue.put((gnss_data.frame, data))
+
+ @staticmethod
+ def _parse_imu(weak_self, imu_data):
+ self = weak_self()
+ data = np.array([imu_data.accelerometer.x,
+ imu_data.accelerometer.y,
+ imu_data.accelerometer.z,
+ imu_data.gyroscope.x,
+ imu_data.gyroscope.y,
+ imu_data.gyroscope.z,
+ imu_data.compass,
+ ], dtype=np.float32)
+ self._imu_queue.put((imu_data.frame, data))
diff --git a/carla_gym/core/obs_manager/navigation/waypoint_plan.py b/carla_gym/core/obs_manager/navigation/waypoint_plan.py
new file mode 100644
index 0000000..0a58dfa
--- /dev/null
+++ b/carla_gym/core/obs_manager/navigation/waypoint_plan.py
@@ -0,0 +1,82 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import numpy as np
+from gym import spaces
+
+from carla_gym.core.obs_manager.obs_manager import ObsManagerBase
+
+import carla_gym.utils.transforms as trans_utils
+
+
+class ObsManager(ObsManagerBase):
+ """
+ Template config
+ "obs_configs" = {
+ "module": "navigation.waypoint_plan",
+ "steps": 10
+ }
+ [command, loc_x, loc_y]
+ """
+
+ def __init__(self, obs_configs):
+ self._steps = obs_configs['steps']
+ self._parent_actor = None
+ super(ObsManager, self).__init__()
+
+ def _define_obs_space(self):
+ self.obs_space = spaces.Dict({
+ 'location': spaces.Box(low=-100, high=1000, shape=(self._steps, 2), dtype=np.float32),
+ 'command': spaces.Box(low=-1, high=6, shape=(self._steps,), dtype=np.uint8),
+ 'road_id': spaces.Box(low=0, high=6000, shape=(self._steps,), dtype=np.uint8),
+ 'lane_id': spaces.Box(low=-20, high=20, shape=(self._steps,), dtype=np.int8),
+ 'is_junction': spaces.MultiBinary(self._steps)})
+
+ def attach_ego_vehicle(self, parent_actor):
+ self._parent_actor = parent_actor
+ self._world = self._parent_actor.vehicle.get_world()
+
+ def get_observation(self):
+ ev_transform = self._parent_actor.vehicle.get_transform()
+
+ route_plan = self._parent_actor.route_plan
+
+ route_length = len(route_plan)
+ location_list = []
+ command_list = []
+ road_id = []
+ lane_id = []
+ is_junction = []
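+        # collect the next `steps` waypoints of the route; if the remaining route is shorter,
+        # repeat the final waypoint so the observation keeps a fixed size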
+ for i in range(self._steps):
+ if i < route_length:
+ waypoint, road_option = route_plan[i]
+ else:
+ waypoint, road_option = route_plan[-1]
+
+ wp_location_world_coord = waypoint.transform.location
+ wp_location_actor_coord = trans_utils.loc_global_to_ref(wp_location_world_coord, ev_transform)
+ location_list.append([wp_location_actor_coord.x, wp_location_actor_coord.y])
+ command_list.append(road_option.value)
+ road_id.append(waypoint.road_id)
+ lane_id.append(waypoint.lane_id)
+ is_junction.append(waypoint.is_junction)
+
+ obs_dict = {
+ 'location': np.array(location_list, dtype=np.float32),
+ 'command': np.array(command_list, dtype=np.int8),
+ 'road_id': np.array(road_id, dtype=np.int8),
+ 'lane_id': np.array(lane_id, dtype=np.int8),
+ 'is_junction': np.array(is_junction, dtype=np.int8)
+ }
+ return obs_dict
+
+ def clean(self):
+ self._parent_actor = None
+ self._world = None
+
+# VOID = 0
+# LEFT = 1
+# RIGHT = 2
+# STRAIGHT = 3
+# LANEFOLLOW = 4
+# CHANGELANELEFT = 5
+# CHANGELANERIGHT = 6
diff --git a/carla_gym/core/obs_manager/object_finder/ego.py b/carla_gym/core/obs_manager/object_finder/ego.py
new file mode 100644
index 0000000..cca4d7c
--- /dev/null
+++ b/carla_gym/core/obs_manager/object_finder/ego.py
@@ -0,0 +1,76 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import numpy as np
+import carla
+from gym import spaces
+from carla_gym.core.obs_manager.obs_manager import ObsManagerBase
+
+
+class ObsManager(ObsManagerBase):
+ def __init__(self, obs_configs):
+ self._parent_actor = None
+ self._map = None
+ super(ObsManager, self).__init__()
+
+ def _define_obs_space(self):
+ self.obs_space = spaces.Dict({
+ 'actor_location': spaces.Box(low=-1000, high=1000, shape=(3,), dtype=np.float32),
+ 'actor_rotation': spaces.Box(low=-180, high=180, shape=(3,), dtype=np.float32),
+ 'waypoint_location': spaces.Box(low=-1000, high=1000, shape=(3,), dtype=np.float32),
+ 'waypoint_rotation': spaces.Box(low=-180, high=180, shape=(3,), dtype=np.float32),
+ 'road_id': spaces.Discrete(5000),
+ 'section_id': spaces.Discrete(5000),
+ 'lane_id': spaces.Box(low=-20, high=20, shape=(1,), dtype=np.int8),
+ 'is_junction': spaces.Discrete(2),
+ 'lane_change': spaces.Discrete(4),
+ 'extent': spaces.Box(low=0, high=20, shape=(3,), dtype=np.float32),
+ 'speed_limit': spaces.Box(low=0, high=200, shape=(1,), dtype=np.float32)
+ })
+
+ def attach_ego_vehicle(self, parent_actor):
+ self._parent_actor = parent_actor
+ self._map = parent_actor.vehicle.get_world().get_map()
+
+ def get_observation(self):
+
+ actor_transform = self._parent_actor.vehicle.get_transform()
+
+ actor_location = [actor_transform.location.x,
+ actor_transform.location.y,
+ actor_transform.location.z]
+ actor_rotation = [actor_transform.rotation.roll,
+ actor_transform.rotation.pitch,
+ actor_transform.rotation.yaw]
+
+ actor_wp = self._map.get_waypoint(actor_transform.location, project_to_road=True,
+ lane_type=carla.LaneType.Driving)
+
+ waypoint_location = [actor_wp.transform.location.x,
+ actor_wp.transform.location.y,
+ actor_wp.transform.location.z]
+ waypoint_rotation = [actor_wp.transform.rotation.roll,
+ actor_wp.transform.rotation.pitch,
+ actor_wp.transform.rotation.yaw]
+
+ extent = self._parent_actor.vehicle.bounding_box.extent
+ speed_limit = self._parent_actor.vehicle.get_speed_limit()
+
+ obs_dict = {
+ 'actor_location': np.array(actor_location, dtype=np.float32),
+ 'actor_rotation': np.array(actor_rotation, dtype=np.float32),
+ 'waypoint_location': np.array(waypoint_location, dtype=np.float32),
+ 'waypoint_rotation': np.array(waypoint_rotation, dtype=np.float32),
+ 'road_id': int(actor_wp.road_id),
+ 'section_id': int(actor_wp.section_id),
+ 'lane_id': int(actor_wp.lane_id),
+ 'is_junction': int(actor_wp.is_junction),
+ 'lane_change': int(actor_wp.lane_change),
+ 'extent': np.array([extent.x, extent.y, extent.z], dtype=np.float32),
+ 'speed_limit': np.float32(speed_limit)
+ }
+
+ return obs_dict
+
+ def clean(self):
+ self._parent_actor = None
+ self._map = None
diff --git a/carla_gym/core/obs_manager/object_finder/pedestrian.py b/carla_gym/core/obs_manager/object_finder/pedestrian.py
new file mode 100644
index 0000000..b5fea02
--- /dev/null
+++ b/carla_gym/core/obs_manager/object_finder/pedestrian.py
@@ -0,0 +1,136 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import numpy as np
+import carla
+from gym import spaces
+from carla_gym.core.obs_manager.obs_manager import ObsManagerBase
+
+import carla_gym.utils.transforms as trans_utils
+
+
+class ObsManager(ObsManagerBase):
+ """
+ Template config
+ obs_configs = {
+ "module": "object_finder.pedestrian",
+ "distance_threshold": 50.0,
+ "max_detection_number": 5
+ }
+ """
+
+ def __init__(self, obs_configs):
+ self._max_detection_number = obs_configs['max_detection_number']
+ self._distance_threshold = obs_configs['distance_threshold']
+ self._parent_actor = None
+ self._world = None
+
+ super(ObsManager, self).__init__()
+
+ def _define_obs_space(self):
+ self.obs_space = spaces.Dict(
+ {'frame': spaces.Discrete(2**32-1),
+ 'binary_mask': spaces.MultiBinary(self._max_detection_number),
+ 'location': spaces.Box(
+ low=-self._distance_threshold, high=self._distance_threshold, shape=(self._max_detection_number, 3),
+ dtype=np.float32),
+ 'rotation': spaces.Box(
+ low=-180, high=180, shape=(self._max_detection_number, 3),
+ dtype=np.float32),
+ 'extent': spaces.Box(
+ low=0, high=5, shape=(self._max_detection_number, 3),
+ dtype=np.float32),
+ 'absolute_velocity': spaces.Box(
+ low=-5, high=5, shape=(self._max_detection_number, 3),
+ dtype=np.float32),
+ 'on_sidewalk': spaces.MultiBinary(self._max_detection_number),
+ 'road_id': spaces.Box(
+ low=0, high=5000, shape=(self._max_detection_number, 1),
+ dtype=np.int8),
+ 'lane_id': spaces.Box(
+ low=-20, high=20, shape=(self._max_detection_number, 1),
+ dtype=np.int8)})
+
+ def attach_ego_vehicle(self, parent_actor):
+ self._parent_actor = parent_actor
+ self._world = parent_actor.vehicle.get_world()
+ self._map = self._world.get_map()
+
+ def get_observation(self):
+ ev_transform = self._parent_actor.vehicle.get_transform()
+ ev_location = ev_transform.location
+ def dist_to_actor(w): return w.get_location().distance(ev_location)
+
+ surrounding_pedestrians = []
+ pedestrian_list = self._world.get_actors().filter("*walker.pedestrian*")
+ for pedestrian in pedestrian_list:
+ if dist_to_actor(pedestrian) <= self._distance_threshold:
+ surrounding_pedestrians.append(pedestrian)
+
+ sorted_surrounding_pedestrians = sorted(surrounding_pedestrians, key=dist_to_actor)
+
+ location, rotation, absolute_velocity = trans_utils.get_loc_rot_vel_in_ev(
+ sorted_surrounding_pedestrians, ev_transform)
+
+ binary_mask, extent, on_sidewalk, road_id, lane_id = [], [], [], [], []
+ for ped in sorted_surrounding_pedestrians[:self._max_detection_number]:
+ binary_mask.append(1)
+
+ bbox_extent = ped.bounding_box.extent
+ extent.append([bbox_extent.x, bbox_extent.y, bbox_extent.z])
+
+ loc = ped.get_location()
+ wp = self._map.get_waypoint(loc, project_to_road=False, lane_type=carla.LaneType.Driving)
+ if wp is None:
+ on_sidewalk.append(1)
+ else:
+ on_sidewalk.append(0)
+ wp = self._map.get_waypoint(loc)
+ road_id.append(wp.road_id)
+ lane_id.append(wp.lane_id)
+
+ for i in range(self._max_detection_number - len(binary_mask)):
+ binary_mask.append(0)
+ location.append([0, 0, 0])
+ rotation.append([0, 0, 0])
+ absolute_velocity.append([0, 0, 0])
+ extent.append([0, 0, 0])
+ on_sidewalk.append(0)
+ road_id.append(0)
+ lane_id.append(0)
+
+ obs_dict = {
+ 'frame': self._world.get_snapshot().frame,
+ 'binary_mask': np.array(binary_mask, dtype=np.int8),
+ 'location': np.array(location, dtype=np.float32),
+ 'rotation': np.array(rotation, dtype=np.float32),
+ 'absolute_velocity': np.array(absolute_velocity, dtype=np.float32),
+ 'extent': np.array(extent, dtype=np.float32),
+ 'on_sidewalk': np.array(on_sidewalk, dtype=np.int8),
+ 'road_id': np.array(road_id, dtype=np.int8),
+ 'lane_id': np.array(lane_id, dtype=np.int8)
+ }
+
+ return obs_dict
+
+ def clean(self):
+ self._parent_actor = None
+ self._world = None
+ self._map = None
+
+ # self._debug_draw(sorted_surrounding_pedestrians)
+ def _debug_draw(self, pedestrian_list):
+ # self._world.debug.draw_point(
+ # ev_location + carla.Location(z=2.0),
+ # color=carla.Color(g=255),
+ # life_time=0.1)
+ # extent = carla.Vector3D(x=5.0, y=5.0, z=0.0)
+ # box = carla.BoundingBox(extent=extent, location=ev_location+ carla.Location(z=1.0))
+ # box = self._parent_actor.vehicle.bounding_box
+ # box.location += ev_location
+ # self._world.debug.draw_box(box, rotation=self._parent_actor.vehicle.get_transform(
+ # ).rotation, color=carla.Color(g=255), life_time=0.05)
+ for ped in pedestrian_list:
+ self._world.debug.draw_point(
+ ped.get_location(),
+ color=carla.Color(b=255),
+ life_time=0.1)
diff --git a/carla_gym/core/obs_manager/object_finder/stop_sign.py b/carla_gym/core/obs_manager/object_finder/stop_sign.py
new file mode 100644
index 0000000..bece64c
--- /dev/null
+++ b/carla_gym/core/obs_manager/object_finder/stop_sign.py
@@ -0,0 +1,40 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+from gym import spaces
+import carla
+from carla_gym.core.obs_manager.obs_manager import ObsManagerBase
+
+
+class ObsManager(ObsManagerBase):
+
+ def __init__(self, obs_configs):
+ self._parent_actor = None
+ self._distance_threshold = obs_configs['distance_threshold']
+
+ super(ObsManager, self).__init__()
+
+ def _define_obs_space(self):
+ self.obs_space = spaces.Dict({
+ 'at_stop_sign': spaces.Discrete(2)
+ })
+
+ def attach_ego_vehicle(self, parent_actor):
+ self._parent_actor = parent_actor
+
+ def get_observation(self):
+ ev_loc = self._parent_actor.vehicle.get_location()
+ stop_sign = self._parent_actor.criteria_stop._target_stop_sign
+
+ at_stop_sign = 0
+ if (stop_sign is not None) and (not self._parent_actor.criteria_stop._stop_completed):
+ stop_t = stop_sign.get_transform()
+ stop_loc = stop_t.transform(stop_sign.trigger_volume.location)
+
+ if carla.Location(stop_loc).distance(ev_loc) < self._distance_threshold:
+ at_stop_sign = 1
+
+ obs = {'at_stop_sign': at_stop_sign}
+ return obs
+
+ def clean(self):
+ self._parent_actor = None
diff --git a/carla_gym/core/obs_manager/object_finder/traffic_light_new.py b/carla_gym/core/obs_manager/object_finder/traffic_light_new.py
new file mode 100644
index 0000000..941208c
--- /dev/null
+++ b/carla_gym/core/obs_manager/object_finder/traffic_light_new.py
@@ -0,0 +1,37 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import numpy as np
+import carla
+from gym import spaces
+from carla_gym.core.obs_manager.obs_manager import ObsManagerBase
+
+
+class ObsManager(ObsManagerBase):
+ # Template config
+ # obs_configs = {
+ # "module": "object_finder.traffic_light_new",
+ # }
+ def __init__(self, obs_configs):
+ self._parent_actor = None
+ super(ObsManager, self).__init__()
+
+ def _define_obs_space(self):
+ self.obs_space = spaces.Dict({
+ 'at_red_light': spaces.Discrete(2),
+ 'trigger_location': spaces.Box(low=-5000, high=5000, shape=(3,), dtype=np.float32),
+ 'trigger_square': spaces.Box(low=-5000, high=5000, shape=(5, 3), dtype=np.float32)
+ })
+
+ def attach_ego_vehicle(self, parent_actor):
+ self._parent_actor = parent_actor
+
+ def get_observation(self):
+ obs = {
+ 'at_red_light': int(self._parent_actor.vehicle.is_at_traffic_light()),
+ 'trigger_location': np.zeros((3,), dtype=np.float32),
+ 'trigger_square': np.zeros((5, 3), dtype=np.float32)
+ }
+ return obs
+
+ def clean(self):
+ self._parent_actor = None
\ No newline at end of file
diff --git a/carla_gym/core/obs_manager/object_finder/vehicle.py b/carla_gym/core/obs_manager/object_finder/vehicle.py
new file mode 100644
index 0000000..d56e62f
--- /dev/null
+++ b/carla_gym/core/obs_manager/object_finder/vehicle.py
@@ -0,0 +1,112 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import numpy as np
+import carla
+from gym import spaces
+from carla_gym.core.obs_manager.obs_manager import ObsManagerBase
+
+import carla_gym.utils.transforms as trans_utils
+
+
+class ObsManager(ObsManagerBase):
+ """
+ Template config
+ obs_configs = {
+ "module": "object_finder.vehicle",
+ "distance_threshold": 50.0,
+ "max_detection_number": 5
+ }
+ """
+
+ def __init__(self, obs_configs):
+ self._max_detection_number = obs_configs['max_detection_number']
+ self._distance_threshold = obs_configs['distance_threshold']
+
+ self._parent_actor = None
+ self._world = None
+ self._map = None
+ super(ObsManager, self).__init__()
+
+ def _define_obs_space(self):
+ self.obs_space = spaces.Dict(
+ {'frame': spaces.Discrete(2**32-1),
+ 'binary_mask': spaces.MultiBinary(self._max_detection_number),
+ 'location': spaces.Box(
+ low=-self._distance_threshold, high=self._distance_threshold, shape=(self._max_detection_number, 3),
+ dtype=np.float32),
+ 'rotation': spaces.Box(
+ low=-180, high=180, shape=(self._max_detection_number, 3),
+ dtype=np.float32),
+ 'extent': spaces.Box(
+ low=0, high=20, shape=(self._max_detection_number, 3),
+ dtype=np.float32),
+ 'absolute_velocity': spaces.Box(
+ low=-10, high=50, shape=(self._max_detection_number, 3),
+ dtype=np.float32),
+ 'road_id': spaces.Box(
+ low=0, high=5000, shape=(self._max_detection_number, 1),
+ dtype=np.int8),
+ 'lane_id': spaces.Box(
+ low=-20, high=20, shape=(self._max_detection_number, 1),
+ dtype=np.int8)})
+
+ def attach_ego_vehicle(self, parent_actor):
+ self._parent_actor = parent_actor
+ self._world = self._parent_actor.vehicle.get_world()
+ self._map = self._world.get_map()
+
+ def get_observation(self):
+ ev_transform = self._parent_actor.vehicle.get_transform()
+ ev_location = ev_transform.location
+ def dist_to_ev(w): return w.get_location().distance(ev_location)
+
+ surrounding_vehicles = []
+ vehicle_list = self._world.get_actors().filter("*vehicle*")
+ for vehicle in vehicle_list:
+ has_different_id = self._parent_actor.vehicle.id != vehicle.id
+ is_within_distance = dist_to_ev(vehicle) <= self._distance_threshold
+ if has_different_id and is_within_distance:
+ surrounding_vehicles.append(vehicle)
+
+ sorted_surrounding_vehicles = sorted(surrounding_vehicles, key=dist_to_ev)
+
+ location, rotation, absolute_velocity = trans_utils.get_loc_rot_vel_in_ev(
+ sorted_surrounding_vehicles, ev_transform)
+
+ binary_mask, extent, road_id, lane_id = [], [], [], []
+ for sv in sorted_surrounding_vehicles[:self._max_detection_number]:
+ binary_mask.append(1)
+
+ bbox_extent = sv.bounding_box.extent
+ extent.append([bbox_extent.x, bbox_extent.y, bbox_extent.z])
+
+ loc = sv.get_location()
+ wp = self._map.get_waypoint(loc)
+ road_id.append(wp.road_id)
+ lane_id.append(wp.lane_id)
+
+ for i in range(self._max_detection_number - len(binary_mask)):
+ binary_mask.append(0)
+ location.append([0, 0, 0])
+ rotation.append([0, 0, 0])
+ extent.append([0, 0, 0])
+ absolute_velocity.append([0, 0, 0])
+ road_id.append(0)
+ lane_id.append(0)
+
+ obs_dict = {
+ 'frame': self._world.get_snapshot().frame,
+ 'binary_mask': np.array(binary_mask, dtype=np.int8),
+ 'location': np.array(location, dtype=np.float32),
+ 'rotation': np.array(rotation, dtype=np.float32),
+ 'extent': np.array(extent, dtype=np.float32),
+ 'absolute_velocity': np.array(absolute_velocity, dtype=np.float32),
+ 'road_id': np.array(road_id, dtype=np.int8),
+ 'lane_id': np.array(lane_id, dtype=np.int8)
+ }
+ return obs_dict
+
+ def clean(self):
+ self._parent_actor = None
+ self._world = None
+ self._map = None
diff --git a/carla_gym/core/obs_manager/obs_manager.py b/carla_gym/core/obs_manager/obs_manager.py
new file mode 100644
index 0000000..c769ef4
--- /dev/null
+++ b/carla_gym/core/obs_manager/obs_manager.py
@@ -0,0 +1,21 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+# base class for observation managers
+
+
+class ObsManagerBase(object):
+
+ def __init__(self):
+ self._define_obs_space()
+
+ def _define_obs_space(self):
+ raise NotImplementedError
+
+ def attach_ego_vehicle(self, parent_actor):
+ raise NotImplementedError
+
+ def get_observation(self):
+ raise NotImplementedError
+
+ def clean(self):
+ raise NotImplementedError
diff --git a/carla_gym/core/obs_manager/obs_manager_handler.py b/carla_gym/core/obs_manager/obs_manager_handler.py
new file mode 100644
index 0000000..1d541e7
--- /dev/null
+++ b/carla_gym/core/obs_manager/obs_manager_handler.py
@@ -0,0 +1,50 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+from importlib import import_module
+from gym import spaces
+
+
+class ObsManagerHandler(object):
+
+ def __init__(self, obs_configs):
+ self._obs_managers = {}
+ self._obs_configs = obs_configs
+ self._init_obs_managers()
+
+ def get_observation(self, timestamp):
+ obs_dict = {}
+ for ev_id, om_dict in self._obs_managers.items():
+ obs_dict[ev_id] = {}
+ for obs_id, om in om_dict.items():
+ obs_dict[ev_id][obs_id] = om.get_observation()
+ return obs_dict
+
+ @property
+ def observation_space(self):
+ obs_spaces_dict = {}
+ for ev_id, om_dict in self._obs_managers.items():
+ ev_obs_spaces_dict = {}
+ for obs_id, om in om_dict.items():
+ ev_obs_spaces_dict[obs_id] = om.obs_space
+ obs_spaces_dict[ev_id] = spaces.Dict(ev_obs_spaces_dict)
+ return spaces.Dict(obs_spaces_dict)
+
+ def reset(self, ego_vehicles):
+ self._init_obs_managers()
+
+ for ev_id, ev_actor in ego_vehicles.items():
+ for obs_id, om in self._obs_managers[ev_id].items():
+ om.attach_ego_vehicle(ev_actor)
+
+ def clean(self):
+ for ev_id, om_dict in self._obs_managers.items():
+ for obs_id, om in om_dict.items():
+ om.clean()
+ self._obs_managers = {}
+
+ def _init_obs_managers(self):
+ for ev_id, ev_obs_configs in self._obs_configs.items():
+ self._obs_managers[ev_id] = {}
+ for obs_id, obs_config in ev_obs_configs.items():
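+                # obs_config["module"] names the module relative to carla_gym.core.obs_manager
+                # (e.g. "camera.rgb"); its ObsManager class is imported dynamically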
+ ObsManager = getattr(import_module('carla_gym.core.obs_manager.' + obs_config["module"]), 'ObsManager')
+ self._obs_managers[ev_id][obs_id] = ObsManager(obs_config)
diff --git a/carla_gym/core/task_actor/common/criteria/blocked.py b/carla_gym/core/task_actor/common/criteria/blocked.py
new file mode 100644
index 0000000..6210139
--- /dev/null
+++ b/carla_gym/core/task_actor/common/criteria/blocked.py
@@ -0,0 +1,32 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import numpy as np
+
+
+class Blocked():
+
+ def __init__(self, speed_threshold=0.1, below_threshold_max_time=90.0):
+ self._speed_threshold = speed_threshold
+ self._below_threshold_max_time = below_threshold_max_time
+ self._time_last_valid_state = None
+
+ def tick(self, vehicle, timestamp):
+ info = None
+ linear_speed = self._calculate_speed(vehicle.get_velocity())
+
+ if linear_speed < self._speed_threshold and self._time_last_valid_state:
+ if (timestamp['relative_simulation_time'] - self._time_last_valid_state) > self._below_threshold_max_time:
+ # The actor has been "blocked" for too long
+ ev_loc = vehicle.get_location()
+ info = {
+ 'step': timestamp['step'],
+ 'simulation_time': timestamp['relative_simulation_time'],
+ 'ev_loc': [ev_loc.x, ev_loc.y, ev_loc.z]
+ }
+ else:
+ self._time_last_valid_state = timestamp['relative_simulation_time']
+ return info
+
+ @staticmethod
+ def _calculate_speed(carla_velocity):
+ return np.linalg.norm([carla_velocity.x, carla_velocity.y])
diff --git a/carla_gym/core/task_actor/common/criteria/collision.py b/carla_gym/core/task_actor/common/criteria/collision.py
new file mode 100644
index 0000000..f732fc9
--- /dev/null
+++ b/carla_gym/core/task_actor/common/criteria/collision.py
@@ -0,0 +1,119 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import carla
+import weakref
+import numpy as np
+
+
+class Collision():
+ def __init__(self, vehicle, carla_world, intensity_threshold=0.0,
+ min_area_of_collision=3, max_area_of_collision=5, max_id_time=5):
+ blueprint = carla_world.get_blueprint_library().find('sensor.other.collision')
+ self._collision_sensor = carla_world.spawn_actor(blueprint, carla.Transform(), attach_to=vehicle)
+ self._collision_sensor.listen(lambda event: self._on_collision(weakref.ref(self), event))
+ self._collision_info = None
+
+ self.registered_collisions = []
+ self.last_id = None
+ self.collision_time = None
+
+ # If closer than this distance, the collision is ignored
+ self._min_area_of_collision = min_area_of_collision
+ # If further than this distance, the area is forgotten
+ self._max_area_of_collision = max_area_of_collision
+        # Amount of time the last collision id is remembered
+ self._max_id_time = max_id_time
+ # intensity_threshold, LBC uses 400, leaderboard does not use it (set to 0)
+ self._intensity_threshold = intensity_threshold
+
+ def tick(self, vehicle, timestamp):
+ ev_loc = vehicle.get_location()
+ new_registered_collisions = []
+ # Loops through all the previous registered collisions
+ for collision_location in self.registered_collisions:
+ distance = ev_loc.distance(collision_location)
+ # If far away from a previous collision, forget it
+ if distance <= self._max_area_of_collision:
+ new_registered_collisions.append(collision_location)
+
+ self.registered_collisions = new_registered_collisions
+
+ if self.last_id and timestamp['relative_simulation_time'] - self.collision_time > self._max_id_time:
+ self.last_id = None
+
+ info = self._collision_info
+ self._collision_info = None
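+        # Convert the absolute frame/simulation time recorded by the sensor callback into episode-relative values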
+ if info is not None:
+ info['step'] -= timestamp['start_frame']
+ info['simulation_time'] -= timestamp['start_simulation_time']
+ return info
+
+ @staticmethod
+ def _on_collision(weakself, event):
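+        # The listen() callback only holds a weak reference, so the sensor does not keep this object alive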
+ self = weakself()
+ if not self:
+ return
+        # Ignore the current one if it has the same id as before
+ if self.last_id == event.other_actor.id:
+ return
+ # Ignore if it's too close to a previous collision (avoid micro collisions)
+ ev_loc = event.actor.get_transform().location
+ for collision_location in self.registered_collisions:
+ if ev_loc.distance(collision_location) <= self._min_area_of_collision:
+ return
+ # Ignore if its intensity is smaller than self._intensity_threshold
+ impulse = event.normal_impulse
+ intensity = np.linalg.norm([impulse.x, impulse.y, impulse.z])
+ if intensity < self._intensity_threshold:
+ return
+
+ # collision_type
+ if ('static' in event.other_actor.type_id or 'traffic' in event.other_actor.type_id) \
+ and 'sidewalk' not in event.other_actor.type_id:
+ collision_type = 0 # TrafficEventType.COLLISION_STATIC
+ elif 'vehicle' in event.other_actor.type_id:
+ collision_type = 1 # TrafficEventType.COLLISION_VEHICLE
+ elif 'walker' in event.other_actor.type_id:
+ collision_type = 2 # TrafficEventType.COLLISION_PEDESTRIAN
+ else:
+ collision_type = -1
+
+        # write to info, all quantities in world coordinates
+ event_loc = event.transform.location
+ event_rot = event.transform.rotation
+ oa_loc = event.other_actor.get_transform().location
+ oa_rot = event.other_actor.get_transform().rotation
+ oa_vel = event.other_actor.get_velocity()
+ ev_rot = event.actor.get_transform().rotation
+ ev_vel = event.actor.get_velocity()
+
+ self._collision_info = {
+ 'step': event.frame,
+ 'simulation_time': event.timestamp,
+ 'collision_type': collision_type,
+ 'other_actor_id': event.other_actor.id,
+ 'other_actor_type_id': event.other_actor.type_id,
+ 'intensity': intensity,
+ 'normal_impulse': [impulse.x, impulse.y, impulse.z],
+ 'event_loc': [event_loc.x, event_loc.y, event_loc.z],
+ 'event_rot': [event_rot.roll, event_rot.pitch, event_rot.yaw],
+ 'ev_loc': [ev_loc.x, ev_loc.y, ev_loc.z],
+ 'ev_rot': [ev_rot.roll, ev_rot.pitch, ev_rot.yaw],
+ 'ev_vel': [ev_vel.x, ev_vel.y, ev_vel.z],
+ 'oa_loc': [oa_loc.x, oa_loc.y, oa_loc.z],
+ 'oa_rot': [oa_rot.roll, oa_rot.pitch, oa_rot.yaw],
+ 'oa_vel': [oa_vel.x, oa_vel.y, oa_vel.z]
+ }
+
+ self.collision_time = event.timestamp
+
+ self.registered_collisions.append(ev_loc)
+
+ # Number 0: static objects -> ignore it
+ if event.other_actor.id != 0:
+ self.last_id = event.other_actor.id
+
+ def clean(self):
+ self._collision_sensor.stop()
+ self._collision_sensor.destroy()
+ self._collision_sensor = None
diff --git a/carla_gym/core/task_actor/common/criteria/encounter_light.py b/carla_gym/core/task_actor/common/criteria/encounter_light.py
new file mode 100644
index 0000000..d0e3b57
--- /dev/null
+++ b/carla_gym/core/task_actor/common/criteria/encounter_light.py
@@ -0,0 +1,28 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+from carla_gym.utils.traffic_light import TrafficLightHandler
+
+
+class EncounterLight():
+
+ def __init__(self, dist_threshold=7.5):
+ self._last_light_id = None
+ self._dist_threshold = dist_threshold
+
+ def tick(self, vehicle, timestamp):
+ info = None
+
+ light_state, light_loc, light_id = TrafficLightHandler.get_light_state(
+ vehicle, dist_threshold=self._dist_threshold)
+
+ if light_id is not None:
+ if light_id != self._last_light_id:
+ self._last_light_id = light_id
+ info = {
+ 'step': timestamp['step'],
+ 'simulation_time': timestamp['relative_simulation_time'],
+ 'id': light_id,
+ 'tl_loc': light_loc.tolist()
+ }
+
+ return info
diff --git a/carla_gym/core/task_actor/common/criteria/outside_route_lane.py b/carla_gym/core/task_actor/common/criteria/outside_route_lane.py
new file mode 100644
index 0000000..9d26177
--- /dev/null
+++ b/carla_gym/core/task_actor/common/criteria/outside_route_lane.py
@@ -0,0 +1,101 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import carla
+from carla_gym.utils.transforms import cast_angle
+
+
+class OutsideRouteLane():
+
+ def __init__(self, carla_map, vehicle_loc,
+ allowed_out_distance=1.3, max_allowed_vehicle_angle=120.0, max_allowed_waypint_angle=150.0):
+ self._map = carla_map
+ self._pre_ego_waypoint = self._map.get_waypoint(vehicle_loc)
+
+ self._allowed_out_distance = allowed_out_distance
+ self._max_allowed_vehicle_angle = max_allowed_vehicle_angle
+ self._max_allowed_waypint_angle = max_allowed_waypint_angle
+
+ self._outside_lane_active = False
+ self._wrong_lane_active = False
+ self._last_road_id = None
+ self._last_lane_id = None
+
+ def tick(self, vehicle, timestamp, distance_traveled):
+ ev_loc = vehicle.get_location()
+ ev_yaw = vehicle.get_transform().rotation.yaw
+ self._is_outside_driving_lanes(ev_loc)
+ self._is_at_wrong_lane(ev_loc, ev_yaw)
+
+ info = None
+ if self._outside_lane_active or self._wrong_lane_active:
+ info = {
+ 'step': timestamp['step'],
+ 'simulation_time': timestamp['relative_simulation_time'],
+ 'ev_loc': [ev_loc.x, ev_loc.y, ev_loc.z],
+ 'distance_traveled': distance_traveled,
+ 'outside_lane': self._outside_lane_active,
+ 'wrong_lane': self._wrong_lane_active
+ }
+ return info
+
+ def _is_outside_driving_lanes(self, location):
+ """
+ Detects if the ego_vehicle is outside driving/parking lanes
+ """
+
+ current_driving_wp = self._map.get_waypoint(location, lane_type=carla.LaneType.Driving, project_to_road=True)
+ current_parking_wp = self._map.get_waypoint(location, lane_type=carla.LaneType.Parking, project_to_road=True)
+
+ driving_distance = location.distance(current_driving_wp.transform.location)
+ if current_parking_wp is not None: # Some towns have no parking
+ parking_distance = location.distance(current_parking_wp.transform.location)
+ else:
+ parking_distance = float('inf')
+
+ if driving_distance >= parking_distance:
+ distance = parking_distance
+ lane_width = current_parking_wp.lane_width
+ else:
+ distance = driving_distance
+ lane_width = current_driving_wp.lane_width
+
+ self._outside_lane_active = distance > (lane_width / 2 + self._allowed_out_distance)
+
+ def _is_at_wrong_lane(self, location, yaw):
+ """
+ Detects if the ego_vehicle has invaded a wrong driving lane
+ """
+
+ current_waypoint = self._map.get_waypoint(location, lane_type=carla.LaneType.Driving, project_to_road=True)
+ current_lane_id = current_waypoint.lane_id
+ current_road_id = current_waypoint.road_id
+
+ # Lanes and roads are too chaotic at junctions
+ if current_waypoint.is_junction:
+ self._wrong_lane_active = False
+ elif self._last_road_id != current_road_id or self._last_lane_id != current_lane_id:
+
+ # Route direction can be considered continuous, except after exiting a junction.
+ if self._pre_ego_waypoint.is_junction:
+ # cast angle to [-180, +180)
+ vehicle_lane_angle = cast_angle(
+ current_waypoint.transform.rotation.yaw - yaw)
+
+ self._wrong_lane_active = abs(vehicle_lane_angle) > self._max_allowed_vehicle_angle
+
+ else:
+ # Check for a big gap in waypoint directions.
+ waypoint_angle = cast_angle(
+ current_waypoint.transform.rotation.yaw - self._pre_ego_waypoint.transform.rotation.yaw)
+
+ if abs(waypoint_angle) >= self._max_allowed_waypint_angle:
+ # Is the ego vehicle going back to the lane, or going out? Take the opposite
+ self._wrong_lane_active = not bool(self._wrong_lane_active)
+ else:
+ # Changing to a lane with the same direction
+ self._wrong_lane_active = False
+
+ # Remember the last state
+ self._last_lane_id = current_lane_id
+ self._last_road_id = current_road_id
+ self._pre_ego_waypoint = current_waypoint
diff --git a/carla_gym/core/task_actor/common/criteria/route_deviation.py b/carla_gym/core/task_actor/common/criteria/route_deviation.py
new file mode 100644
index 0000000..baf4e1d
--- /dev/null
+++ b/carla_gym/core/task_actor/common/criteria/route_deviation.py
@@ -0,0 +1,36 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+class RouteDeviation():
+
+ def __init__(self, offroad_min=15, offroad_max=30, max_route_percentage=0.3):
+ self._offroad_min = offroad_min
+ self._offroad_max = offroad_max
+ self._max_route_percentage = max_route_percentage
+ self._out_route_distance = 0.0
+
+ def tick(self, vehicle, timestamp, ref_waypoint, distance_traveled, route_length):
+ ev_loc = vehicle.get_location()
+
+ distance = ev_loc.distance(ref_waypoint.transform.location)
+
+ # fail if off_route is True
+ off_route_max = distance > self._offroad_max
+
+ # fail if off_safe_route more than 30% of total route length
+ off_route_min = False
+ if distance > self._offroad_min:
+ self._out_route_distance += distance_traveled
+ out_route_percentage = self._out_route_distance / route_length
+ if out_route_percentage > self._max_route_percentage:
+ off_route_min = True
+
+ info = None
+ if off_route_max or off_route_min:
+ info = {
+ 'step': timestamp['step'],
+ 'simulation_time': timestamp['relative_simulation_time'],
+ 'ev_loc': [ev_loc.x, ev_loc.y, ev_loc.z],
+ 'off_route_max': off_route_max,
+ 'off_route_min': off_route_min
+ }
+ return info
diff --git a/carla_gym/core/task_actor/common/criteria/run_red_light.py b/carla_gym/core/task_actor/common/criteria/run_red_light.py
new file mode 100644
index 0000000..f967f20
--- /dev/null
+++ b/carla_gym/core/task_actor/common/criteria/run_red_light.py
@@ -0,0 +1,66 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import carla
+import shapely.geometry
+from carla_gym.utils.traffic_light import TrafficLightHandler
+
+
+class RunRedLight():
+
+ def __init__(self, carla_map, distance_light=30):
+ self._map = carla_map
+ self._last_red_light_id = None
+ self._distance_light = distance_light
+
+ def tick(self, vehicle, timestamp):
+ ev_tra = vehicle.get_transform()
+ ev_loc = ev_tra.location
+ ev_dir = ev_tra.get_forward_vector()
+ ev_extent = vehicle.bounding_box.extent.x
+
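+        # Two probe points at the rear of the vehicle; a red-light infraction is registered when the
+        # segment between them crosses the stop line of a red light governing the ego lane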
+ tail_close_pt = ev_tra.transform(carla.Location(x=-0.8 * ev_extent))
+ tail_far_pt = ev_tra.transform(carla.Location(x=-ev_extent - 1.0))
+ tail_wp = self._map.get_waypoint(tail_far_pt)
+
+ info = None
+ for idx_tl in range(TrafficLightHandler.num_tl):
+ traffic_light = TrafficLightHandler.list_tl_actor[idx_tl]
+ tl_tv_loc = TrafficLightHandler.list_tv_loc[idx_tl]
+ if tl_tv_loc.distance(ev_loc) > self._distance_light:
+ continue
+ if traffic_light.state != carla.TrafficLightState.Red:
+ continue
+ if self._last_red_light_id and self._last_red_light_id == traffic_light.id:
+ continue
+
+ for idx_wp in range(len(TrafficLightHandler.list_stopline_wps[idx_tl])):
+ wp = TrafficLightHandler.list_stopline_wps[idx_tl][idx_wp]
+ wp_dir = wp.transform.get_forward_vector()
+ dot_ve_wp = ev_dir.x * wp_dir.x + ev_dir.y * wp_dir.y + ev_dir.z * wp_dir.z
+
+ if tail_wp.road_id == wp.road_id and tail_wp.lane_id == wp.lane_id and dot_ve_wp > 0:
+ # This light is red and is affecting our lane
+ stop_left_loc, stop_right_loc = TrafficLightHandler.list_stopline_vtx[idx_tl][idx_wp]
+ # Is the vehicle traversing the stop line?
+ if self._is_vehicle_crossing_line((tail_close_pt, tail_far_pt), (stop_left_loc, stop_right_loc)):
+ tl_loc = traffic_light.get_location()
+ # loc_in_ev = trans_utils.loc_global_to_ref(tl_loc, ev_tra)
+ self._last_red_light_id = traffic_light.id
+ info = {
+ 'step': timestamp['step'],
+ 'simulation_time': timestamp['relative_simulation_time'],
+ 'id': traffic_light.id,
+ 'tl_loc': [tl_loc.x, tl_loc.y, tl_loc.z],
+ 'ev_loc': [ev_loc.x, ev_loc.y, ev_loc.z]
+ }
+ return info
+
+ @staticmethod
+ def _is_vehicle_crossing_line(seg1, seg2):
+ """
+        Check whether the two 2D line segments (vehicle tail, stop line) intersect
+ """
+ line1 = shapely.geometry.LineString([(seg1[0].x, seg1[0].y), (seg1[1].x, seg1[1].y)])
+ line2 = shapely.geometry.LineString([(seg2[0].x, seg2[0].y), (seg2[1].x, seg2[1].y)])
+ inter = line1.intersection(line2)
+ return not inter.is_empty
diff --git a/carla_gym/core/task_actor/common/criteria/run_stop_sign.py b/carla_gym/core/task_actor/common/criteria/run_stop_sign.py
new file mode 100644
index 0000000..96a4382
--- /dev/null
+++ b/carla_gym/core/task_actor/common/criteria/run_stop_sign.py
@@ -0,0 +1,159 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import carla
+import numpy as np
+
+
+class RunStopSign():
+
+ def __init__(self, carla_world, proximity_threshold=50.0, speed_threshold=0.1, waypoint_step=1.0):
+ self._map = carla_world.get_map()
+ self._proximity_threshold = proximity_threshold
+ self._speed_threshold = speed_threshold
+ self._waypoint_step = waypoint_step
+
+ all_actors = carla_world.get_actors()
+ self._list_stop_signs = []
+ for _actor in all_actors:
+ if 'traffic.stop' in _actor.type_id:
+ self._list_stop_signs.append(_actor)
+
+ self._target_stop_sign = None
+ self._stop_completed = False
+ self._affected_by_stop = False
+
+ def tick(self, vehicle, timestamp):
+ info = None
+ ev_loc = vehicle.get_location()
+ ev_f_vec = vehicle.get_transform().get_forward_vector()
+
+ if self._target_stop_sign is None:
+ self._target_stop_sign = self._scan_for_stop_sign(vehicle.get_transform())
+ if self._target_stop_sign is not None:
+ stop_loc = self._target_stop_sign.get_location()
+ info = {
+ 'event': 'encounter',
+ 'step': timestamp['step'],
+ 'simulation_time': timestamp['relative_simulation_time'],
+ 'id': self._target_stop_sign.id,
+ 'stop_loc': [stop_loc.x, stop_loc.y, stop_loc.z],
+ 'ev_loc': [ev_loc.x, ev_loc.y, ev_loc.z]
+ }
+ else:
+ # we were in the middle of dealing with a stop sign
+ if not self._stop_completed:
+ # did the ego-vehicle stop?
+ current_speed = self._calculate_speed(vehicle.get_velocity())
+ if current_speed < self._speed_threshold:
+ self._stop_completed = True
+
+ if not self._affected_by_stop:
+ stop_t = self._target_stop_sign.get_transform()
+ transformed_tv = stop_t.transform(self._target_stop_sign.trigger_volume.location)
+ stop_extent = self._target_stop_sign.trigger_volume.extent
+ if self.point_inside_boundingbox(ev_loc, transformed_tv, stop_extent):
+ self._affected_by_stop = True
+
+ if not self.is_affected_by_stop(ev_loc, self._target_stop_sign):
+ # is the vehicle out of the influence of this stop sign now?
+ if not self._stop_completed and self._affected_by_stop:
+ # did we stop?
+ stop_loc = self._target_stop_sign.get_transform().location
+ info = {
+ 'event': 'run',
+ 'step': timestamp['step'],
+ 'simulation_time': timestamp['relative_simulation_time'],
+ 'id': self._target_stop_sign.id,
+ 'stop_loc': [stop_loc.x, stop_loc.y, stop_loc.z],
+ 'ev_loc': [ev_loc.x, ev_loc.y, ev_loc.z]
+ }
+ # reset state
+ self._target_stop_sign = None
+ self._stop_completed = False
+ self._affected_by_stop = False
+
+ return info
+
+ def _scan_for_stop_sign(self, vehicle_transform):
+ target_stop_sign = None
+
+ ve_dir = vehicle_transform.get_forward_vector()
+
+ wp = self._map.get_waypoint(vehicle_transform.location)
+ wp_dir = wp.transform.get_forward_vector()
+
+ dot_ve_wp = ve_dir.x * wp_dir.x + ve_dir.y * wp_dir.y + ve_dir.z * wp_dir.z
+
+ if dot_ve_wp > 0: # Ignore all when going in a wrong lane
+ for stop_sign in self._list_stop_signs:
+ if self.is_affected_by_stop(vehicle_transform.location, stop_sign):
+ # this stop sign is affecting the vehicle
+ target_stop_sign = stop_sign
+ break
+
+ return target_stop_sign
+
+ def is_affected_by_stop(self, vehicle_loc, stop, multi_step=20):
+ """
+ Check if the given actor is affected by the stop
+ """
+ affected = False
+ # first we run a fast coarse test
+ stop_t = stop.get_transform()
+ stop_location = stop_t.location
+ if stop_location.distance(vehicle_loc) > self._proximity_threshold:
+ return affected
+
+ transformed_tv = stop_t.transform(stop.trigger_volume.location)
+
+        # slower but more accurate test based on the waypoint horizon and a geometric check
+ list_locations = [vehicle_loc]
+ waypoint = self._map.get_waypoint(vehicle_loc)
+ for _ in range(multi_step):
+ if waypoint:
+ next_wps = waypoint.next(self._waypoint_step)
+ if not next_wps:
+ break
+ waypoint = next_wps[0]
+ if not waypoint:
+ break
+ list_locations.append(waypoint.transform.location)
+
+ for actor_location in list_locations:
+ if self.point_inside_boundingbox(actor_location, transformed_tv, stop.trigger_volume.extent):
+ affected = True
+
+ return affected
+
+ @staticmethod
+ def _calculate_speed(carla_velocity):
+ return np.linalg.norm([carla_velocity.x, carla_velocity.y])
+
+ @staticmethod
+ def point_inside_boundingbox(point, bb_center, bb_extent):
+ """
+        Check whether a 2D point lies inside an axis-aligned bounding box.
+        :param point: location to test
+        :param bb_center: center of the bounding box (world coordinates)
+        :param bb_extent: half extents of the bounding box
+        :return: True if the point is inside the box
+ """
+ # bugfix slim bbox
+ bb_extent.x = max(bb_extent.x, bb_extent.y)
+ bb_extent.y = max(bb_extent.x, bb_extent.y)
+
+ # pylint: disable=invalid-name
+ A = carla.Vector2D(bb_center.x - bb_extent.x, bb_center.y - bb_extent.y)
+ B = carla.Vector2D(bb_center.x + bb_extent.x, bb_center.y - bb_extent.y)
+ D = carla.Vector2D(bb_center.x - bb_extent.x, bb_center.y + bb_extent.y)
+ M = carla.Vector2D(point.x, point.y)
+
+ AB = B - A
+ AD = D - A
+ AM = M - A
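+        # Project AM onto the box edges AB and AD; M is inside iff both projections
+        # lie strictly between 0 and the squared edge lengths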
+ am_ab = AM.x * AB.x + AM.y * AB.y
+ ab_ab = AB.x * AB.x + AB.y * AB.y
+ am_ad = AM.x * AD.x + AM.y * AD.y
+ ad_ad = AD.x * AD.x + AD.y * AD.y
+
+ return am_ab > 0 and am_ab < ab_ab and am_ad > 0 and am_ad < ad_ad
diff --git a/carla_gym/core/task_actor/common/navigation/global_route_planner.py b/carla_gym/core/task_actor/common/navigation/global_route_planner.py
new file mode 100644
index 0000000..14291c7
--- /dev/null
+++ b/carla_gym/core/task_actor/common/navigation/global_route_planner.py
@@ -0,0 +1,368 @@
+"""
+Modified from carla/PythonAPI/carla/agents/navigation/global_route_planner.py
+"""
+
+import numpy as np
+import networkx as nx
+
+import carla
+
+from .map_utils import get_sampled_topology, RoadOption, vector
+
+
+class GlobalRoutePlanner(object):
+
+ def __init__(self, carla_map, resolution):
+ """
+ Constructor
+ """
+ self._map = carla_map
+ self._resolution = resolution
+
+ self._topology = get_sampled_topology(self._map.get_topology(), self._resolution)
+
+ self._intersection_end_node = -1
+ self._previous_decision = RoadOption.VOID
+
+ self._graph, self._id_map, self._road_id_to_edge = self._build_graph()
+ self._find_loose_ends()
+ self._lane_change_link()
+
+ def _build_graph(self):
+ """
+ This function builds a networkx graph representation of topology.
+ The topology is read from self._topology.
+ graph node properties:
+ vertex - (x,y,z) position in world map
+ graph edge properties:
+ entry_vector - unit vector along tangent at entry point
+ exit_vector - unit vector along tangent at exit point
+ net_vector - unit vector of the chord from entry to exit
+ intersection - boolean indicating if the edge belongs to an
+ intersection
+ return : graph -> networkx graph representing the world map,
+ id_map-> mapping from (x,y,z) to node id
+ road_id_to_edge-> map from road id to edge in the graph
+ """
+ graph = nx.DiGraph()
+ id_map = dict() # Map with structure {(x,y,z): id, ... }
+ road_id_to_edge = dict() # Map with structure {road_id: {lane_id: edge, ... }, ... }
+
+ for segment in self._topology:
+
+ entry_xyz, exit_xyz = segment['entryxyz'], segment['exitxyz']
+ path = segment['path']
+ entry_wp, exit_wp = segment['entry'], segment['exit']
+ intersection = entry_wp.is_junction
+ road_id, section_id, lane_id = entry_wp.road_id, entry_wp.section_id, entry_wp.lane_id
+
+ for vertex in entry_xyz, exit_xyz:
+ # Adding unique nodes and populating id_map
+ if vertex not in id_map:
+ new_id = len(id_map)
+ id_map[vertex] = new_id
+ graph.add_node(new_id, vertex=vertex)
+ n1 = id_map[entry_xyz]
+ n2 = id_map[exit_xyz]
+ if road_id not in road_id_to_edge:
+ road_id_to_edge[road_id] = dict()
+ if section_id not in road_id_to_edge[road_id]:
+ road_id_to_edge[road_id][section_id] = dict()
+ road_id_to_edge[road_id][section_id][lane_id] = (n1, n2)
+
+ entry_carla_vector = entry_wp.transform.rotation.get_forward_vector()
+ exit_carla_vector = exit_wp.transform.rotation.get_forward_vector()
+
+ # Adding edge with attributes
+ graph.add_edge(
+ n1, n2,
+ length=len(path) + 1, path=path,
+ entry_waypoint=entry_wp, exit_waypoint=exit_wp,
+ entry_vector=np.array(
+ [entry_carla_vector.x, entry_carla_vector.y, entry_carla_vector.z]),
+ exit_vector=np.array(
+ [exit_carla_vector.x, exit_carla_vector.y, exit_carla_vector.z]),
+ net_vector=vector(entry_wp.transform.location, exit_wp.transform.location),
+ intersection=intersection, type=RoadOption.LANEFOLLOW)
+
+ return graph, id_map, road_id_to_edge
+
+ def _find_loose_ends(self):
+ """
+ This method finds road segments that have an unconnected end, and
+ adds them to the internal graph representation
+ """
+ count_loose_ends = 0
+ for segment in self._topology:
+ end_wp = segment['exit']
+ exit_xyz = segment['exitxyz']
+ road_id, section_id, lane_id = end_wp.road_id, end_wp.section_id, end_wp.lane_id
+ if road_id in self._road_id_to_edge and section_id in self._road_id_to_edge[road_id] and lane_id in self._road_id_to_edge[road_id][section_id]:
+ pass
+ else:
+ count_loose_ends += 1
+ if road_id not in self._road_id_to_edge:
+ self._road_id_to_edge[road_id] = dict()
+ if section_id not in self._road_id_to_edge[road_id]:
+ self._road_id_to_edge[road_id][section_id] = dict()
+ n1 = self._id_map[exit_xyz]
+ n2 = -1*count_loose_ends
+ self._road_id_to_edge[road_id][section_id][lane_id] = (n1, n2)
+ next_wp = end_wp.next(self._resolution)
+ path = []
+ while next_wp is not None and next_wp and next_wp[0].road_id == road_id and next_wp[0].section_id == section_id and next_wp[0].lane_id == lane_id:
+ path.append(next_wp[0])
+ next_wp = next_wp[0].next(self._resolution)
+ if path:
+ n2_xyz = (path[-1].transform.location.x,
+ path[-1].transform.location.y,
+ path[-1].transform.location.z)
+ self._graph.add_node(n2, vertex=n2_xyz)
+ self._graph.add_edge(
+ n1, n2,
+ length=len(path) + 1, path=path,
+ entry_waypoint=end_wp, exit_waypoint=path[-1],
+ entry_vector=None, exit_vector=None, net_vector=None,
+ intersection=end_wp.is_junction, type=RoadOption.LANEFOLLOW)
+
+ def _localize(self, location):
+ """
+        This function finds the road segment that the given location belongs to.
+        location : carla.Location to be localized in the graph
+        return : pair of node ids representing an edge in the graph
+ """
+ waypoint = self._map.get_waypoint(location)
+ edge = None
+ try:
+ edge = self._road_id_to_edge[waypoint.road_id][waypoint.section_id][waypoint.lane_id]
+ except KeyError:
+ print(
+ "Failed to localize! : ",
+ "Road id : ", waypoint.road_id,
+ "Section id : ", waypoint.section_id,
+ "Lane id : ", waypoint.lane_id,
+ "Location : ", waypoint.transform.location.x,
+ waypoint.transform.location.y)
+ return edge
+
+ def _lane_change_link(self):
+ """
+ This method places zero cost links in the topology graph
+ representing availability of lane changes.
+ """
+
+ for segment in self._topology:
+ left_found, right_found = False, False
+
+ for waypoint in segment['path']:
+ if not segment['entry'].is_junction:
+ next_waypoint, next_road_option, next_segment = None, None, None
+
+ if waypoint.right_lane_marking.lane_change & carla.LaneChange.Right and not right_found:
+ next_waypoint = waypoint.get_right_lane()
+ if next_waypoint is not None and next_waypoint.lane_type == carla.LaneType.Driving and waypoint.road_id == next_waypoint.road_id:
+ next_road_option = RoadOption.CHANGELANERIGHT
+ next_segment = self._localize(next_waypoint.transform.location)
+ if next_segment is not None:
+ self._graph.add_edge(
+ self._id_map[segment['entryxyz']], next_segment[0], entry_waypoint=waypoint,
+ exit_waypoint=next_waypoint, intersection=False, exit_vector=None,
+ path=[], length=0, type=next_road_option, change_waypoint=next_waypoint)
+ right_found = True
+ if waypoint.left_lane_marking.lane_change & carla.LaneChange.Left and not left_found:
+ next_waypoint = waypoint.get_left_lane()
+ if next_waypoint is not None and next_waypoint.lane_type == carla.LaneType.Driving and waypoint.road_id == next_waypoint.road_id:
+ next_road_option = RoadOption.CHANGELANELEFT
+ next_segment = self._localize(next_waypoint.transform.location)
+ if next_segment is not None:
+ self._graph.add_edge(
+ self._id_map[segment['entryxyz']], next_segment[0], entry_waypoint=waypoint,
+ exit_waypoint=next_waypoint, intersection=False, exit_vector=None,
+ path=[], length=0, type=next_road_option, change_waypoint=next_waypoint)
+ left_found = True
+ if left_found and right_found:
+ break
+
+ def _distance_heuristic(self, n1, n2):
+ """
+ Distance heuristic calculator for path searching
+ in self._graph
+ """
+ l1 = np.array(self._graph.nodes[n1]['vertex'])
+ l2 = np.array(self._graph.nodes[n2]['vertex'])
+ return np.linalg.norm(l1-l2)
+
+ def _path_search(self, origin, destination):
+ """
+ This function finds the shortest path connecting origin and destination
+ using A* search with distance heuristic.
+ origin : carla.Location object of start position
+        destination : carla.Location object of end position
+ return : path as list of node ids (as int) of the graph self._graph
+ connecting origin and destination
+ """
+
+ start, end = self._localize(origin), self._localize(destination)
+
+ route = nx.astar_path(
+ self._graph, source=start[0], target=end[0],
+ heuristic=self._distance_heuristic, weight='length')
+ route.append(end[1])
+ return route
+
+ def _successive_last_intersection_edge(self, index, route):
+ """
+ This method returns the last successive intersection edge
+ from a starting index on the route.
+ This helps moving past tiny intersection edges to calculate
+ proper turn decisions.
+ """
+
+ last_intersection_edge = None
+ last_node = None
+ for node1, node2 in [(route[i], route[i+1]) for i in range(index, len(route)-1)]:
+ candidate_edge = self._graph.edges[node1, node2]
+ if node1 == route[index]:
+ last_intersection_edge = candidate_edge
+ if candidate_edge['type'] == RoadOption.LANEFOLLOW and candidate_edge['intersection']:
+ last_intersection_edge = candidate_edge
+ last_node = node2
+ else:
+ break
+
+ return last_node, last_intersection_edge
+
+ def _turn_decision(self, index, route, threshold=np.deg2rad(35)):
+ """
+ This method returns the turn decision (RoadOption) for pair of edges
+ around current index of route list
+ """
+
+ decision = None
+ previous_node = route[index-1]
+ current_node = route[index]
+ next_node = route[index+1]
+ next_edge = self._graph.edges[current_node, next_node]
+ if index > 0:
+ if self._previous_decision != RoadOption.VOID and self._intersection_end_node > 0 and self._intersection_end_node != previous_node and next_edge['type'] == RoadOption.LANEFOLLOW and next_edge['intersection']:
+ decision = self._previous_decision
+ else:
+ self._intersection_end_node = -1
+ current_edge = self._graph.edges[previous_node, current_node]
+ calculate_turn = current_edge['type'] == RoadOption.LANEFOLLOW and not current_edge[
+ 'intersection'] and next_edge['type'] == RoadOption.LANEFOLLOW and next_edge['intersection']
+ if calculate_turn:
+ last_node, tail_edge = self._successive_last_intersection_edge(index, route)
+ self._intersection_end_node = last_node
+ if tail_edge is not None:
+ next_edge = tail_edge
+ cv, nv = current_edge['exit_vector'], next_edge['exit_vector']
+ if cv is None or nv is None:
+ return next_edge['type']
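+                    # Compare the turn direction (z-component of the cross product) of the chosen edge
+                    # against the other outgoing LANEFOLLOW edges to classify LEFT/RIGHT/STRAIGHT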
+ cross_list = []
+ for neighbor in self._graph.successors(current_node):
+ select_edge = self._graph.edges[current_node, neighbor]
+ if select_edge['type'] == RoadOption.LANEFOLLOW:
+ if neighbor != route[index+1]:
+ sv = select_edge['net_vector']
+ cross_list.append(np.cross(cv, sv)[2])
+ next_cross = np.cross(cv, nv)[2]
+ deviation = np.arccos(np.clip(
+ np.dot(cv, nv)/(np.linalg.norm(cv)*np.linalg.norm(nv)), -1.0, 1.0))
+ if not cross_list:
+ cross_list.append(0)
+ if deviation < threshold:
+ decision = RoadOption.STRAIGHT
+ elif cross_list and next_cross < min(cross_list):
+ decision = RoadOption.LEFT
+ elif cross_list and next_cross > max(cross_list):
+ decision = RoadOption.RIGHT
+ elif next_cross < 0:
+ decision = RoadOption.LEFT
+ elif next_cross > 0:
+ decision = RoadOption.RIGHT
+ else:
+ decision = next_edge['type']
+
+ else:
+ decision = next_edge['type']
+
+ self._previous_decision = decision
+ return decision
+
+ def abstract_route_plan(self, origin, destination):
+ """
+        Generates a turn-by-turn route plan from origin to destination.
+ origin : carla.Location object of the route's start position
+ destination : carla.Location object of the route's end position
+ return : list of turn by turn navigation decisions as
+ agents.navigation.local_planner.RoadOption elements
+        Possible values are STRAIGHT, LEFT, RIGHT, LANEFOLLOW, VOID,
+ CHANGELANELEFT, CHANGELANERIGHT
+ """
+
+ route = self._path_search(origin, destination)
+ plan = []
+
+ for i in range(len(route) - 1):
+ road_option = self._turn_decision(i, route)
+ plan.append(road_option)
+
+ return plan
+
+ def _find_closest_in_list(self, current_waypoint, waypoint_list):
+ min_distance = float('inf')
+ closest_index = -1
+ for i, waypoint in enumerate(waypoint_list):
+ distance = waypoint.transform.location.distance(current_waypoint.transform.location)
+ if distance < min_distance:
+ min_distance = distance
+ closest_index = i
+
+ return closest_index
+
+ def trace_route(self, origin, destination):
+ """
+ This method returns list of (carla.Waypoint, RoadOption)
+ from origin to destination
+ """
+
+ route_trace = []
+ route = self._path_search(origin, destination)
+ current_waypoint = self._map.get_waypoint(origin)
+ destination_waypoint = self._map.get_waypoint(destination)
+
+ for i in range(len(route) - 1):
+ road_option = self._turn_decision(i, route)
+ edge = self._graph.edges[route[i], route[i+1]]
+ path = []
+
+ if edge['type'] != RoadOption.LANEFOLLOW and edge['type'] != RoadOption.VOID:
+ route_trace.append((current_waypoint, road_option))
+ exit_wp = edge['exit_waypoint']
+ n1, n2 = self._road_id_to_edge[exit_wp.road_id][exit_wp.section_id][exit_wp.lane_id]
+ next_edge = self._graph.edges[n1, n2]
+ if next_edge['path']:
+ closest_index = self._find_closest_in_list(current_waypoint, next_edge['path'])
+ closest_index = min(len(next_edge['path'])-1, closest_index+5)
+ current_waypoint = next_edge['path'][closest_index]
+ else:
+ current_waypoint = next_edge['exit_waypoint']
+ route_trace.append((current_waypoint, road_option))
+
+ else:
+ path = path + [edge['entry_waypoint']] + edge['path'] + [edge['exit_waypoint']]
+ closest_index = self._find_closest_in_list(current_waypoint, path)
+ for waypoint in path[closest_index:]:
+ current_waypoint = waypoint
+ route_trace.append((current_waypoint, road_option))
+ if len(route)-i <= 2 and waypoint.transform.location.distance(destination) < 2*self._resolution:
+ break
+ elif len(route)-i <= 2 and current_waypoint.road_id == destination_waypoint.road_id \
+ and current_waypoint.section_id == destination_waypoint.section_id \
+ and current_waypoint.lane_id == destination_waypoint.lane_id:
+ destination_index = self._find_closest_in_list(destination_waypoint, path)
+ if closest_index > destination_index:
+ break
+
+ return route_trace
diff --git a/carla_gym/core/task_actor/common/navigation/map_utils.py b/carla_gym/core/task_actor/common/navigation/map_utils.py
new file mode 100644
index 0000000..442b5d5
--- /dev/null
+++ b/carla_gym/core/task_actor/common/navigation/map_utils.py
@@ -0,0 +1,71 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import numpy as np
+from enum import Enum
+
+
+class RoadOption(Enum):
+ """
+ RoadOption represents the possible topological configurations when moving from a segment of lane to other.
+ """
+ VOID = -1
+ # VOID = 0
+ LEFT = 1
+ RIGHT = 2
+ STRAIGHT = 3
+ LANEFOLLOW = 4
+ CHANGELANELEFT = 5
+ CHANGELANERIGHT = 6
+
+
+def vector(location_1, location_2):
+ """
+ Returns the unit vector from location_1 to location_2
+
+ :param location_1, location_2: carla.Location objects
+ """
+ x = location_2.x - location_1.x
+ y = location_2.y - location_1.y
+ z = location_2.z - location_1.z
+ norm = np.linalg.norm([x, y, z]) + np.finfo(float).eps
+
+ return [x / norm, y / norm, z / norm]
+
+
+def get_sampled_topology(map_topology, resolution):
+ """
+    Samples the map topology at the given resolution.
+    This function takes the topology as a list of road segments
+    (pairs of waypoint objects) and processes it into a list of
+    dictionary objects.
+
+ :return topology: list of dictionary objects with the following attributes
+ entry - waypoint of entry point of road segment
+ entryxyz- (x,y,z) of entry point of road segment
+ exit - waypoint of exit point of road segment
+ exitxyz - (x,y,z) of exit point of road segment
+        path    - list of waypoints, spaced `resolution` meters apart, from entry to exit
+ """
+ topology = []
+ # Retrieving waypoints to construct a detailed topology
+ for segment in map_topology:
+ wp1, wp2 = segment[0], segment[1]
+ l1, l2 = wp1.transform.location, wp2.transform.location
+ # Rounding off to avoid floating point imprecision
+ x1, y1, z1, x2, y2, z2 = np.round([l1.x, l1.y, l1.z, l2.x, l2.y, l2.z], 0)
+ wp1.transform.location, wp2.transform.location = l1, l2
+ seg_dict = dict()
+ seg_dict['entry'], seg_dict['exit'] = wp1, wp2
+ seg_dict['entryxyz'], seg_dict['exitxyz'] = (x1, y1, z1), (x2, y2, z2)
+ seg_dict['path'] = []
+ endloc = wp2.transform.location
+ if wp1.transform.location.distance(endloc) > resolution:
+ w = wp1.next(resolution)[0]
+ while w.transform.location.distance(endloc) > resolution:
+ seg_dict['path'].append(w)
+ w = w.next(resolution)[0]
+ else:
+ seg_dict['path'].append(wp1.next(resolution)[0])
+ topology.append(seg_dict)
+ return topology
diff --git a/carla_gym/core/task_actor/common/navigation/route_manipulation.py b/carla_gym/core/task_actor/common/navigation/route_manipulation.py
new file mode 100644
index 0000000..5a86223
--- /dev/null
+++ b/carla_gym/core/task_actor/common/navigation/route_manipulation.py
@@ -0,0 +1,157 @@
+#!/usr/bin/env python
+# Copyright (c) 2018-2019 Intel Labs.
+# authors: German Ros (german.ros@intel.com), Felipe Codevilla (felipe.alcm@gmail.com)
+#
+# This work is licensed under the terms of the MIT license.
+# For a copy, see <https://opensource.org/licenses/MIT>.
+
+"""
+Module to manipulate routes by making them more or less dense (up to a certain parameter).
+It also contains functions to convert CARLA world locations to GPS coordinates.
+"""
+
+import math
+import xml.etree.ElementTree as ET
+import carla
+import numpy as np
+
+from .map_utils import RoadOption
+
+EARTH_RADIUS_EQUA = 6378137.0
+
+
+def location_to_gps(location):
+
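+    # Inverse Mercator conversion, assuming the map geo-reference is (0, 0); gps_to_location below is its inverse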
+ lon = location.x * 180.0 / (math.pi * EARTH_RADIUS_EQUA)
+ lat = 360.0 * math.atan(math.exp(-location.y / EARTH_RADIUS_EQUA)) / math.pi - 90.0
+ z = location.z
+
+ return (lat, lon, z)
+
+
+def gps_to_location(gps):
+ lat, lon, z = gps
+ lat = float(lat)
+ lon = float(lon)
+ z = float(z)
+
+ location = carla.Location(z=z)
+
+ location.x = lon / 180.0 * (math.pi * EARTH_RADIUS_EQUA)
+
+ location.y = -1.0 * math.log(math.tan((lat + 90.0) * math.pi / 360.0)) * EARTH_RADIUS_EQUA
+
+ return location
+
+
+def _location_to_gps_leaderbaord(lat_ref, lon_ref, location):
+ """
+ Convert from world coordinates to GPS coordinates
+ :param lat_ref: latitude reference for the current map
+ :param lon_ref: longitude reference for the current map
+ :param location: location to translate
+ :return: dictionary with lat, lon and height
+ """
+
+ EARTH_RADIUS_EQUA = 6378137.0 # pylint: disable=invalid-name
+ scale = math.cos(lat_ref * math.pi / 180.0)
+ mx = scale * lon_ref * math.pi * EARTH_RADIUS_EQUA / 180.0
+ my = scale * EARTH_RADIUS_EQUA * math.log(math.tan((90.0 + lat_ref) * math.pi / 360.0))
+ mx += location.x
+ my -= location.y
+
+ lon = mx * 180.0 / (math.pi * EARTH_RADIUS_EQUA * scale)
+ lat = 360.0 * math.atan(math.exp(my / (EARTH_RADIUS_EQUA * scale))) / math.pi - 90.0
+ z = location.z
+
+ return {'lat': lat, 'lon': lon, 'z': z}
+
+
+def location_route_to_gps(route):
+ """
+    Convert each waypoint of the route into a GPS (lat, lon, z) representation.
+    :param route: list of (waypoint, connection) tuples
+    :return: list of (gps_point, connection) tuples
+ """
+ # lat_ref, lon_ref = _get_latlon_ref(world)
+
+ gps_route = []
+
+ for wp, connection in route:
+ gps_point = location_to_gps(wp.transform.location)
+ gps_route.append((gps_point, connection))
+
+ return gps_route
+
+
+def _get_latlon_ref(world):
+ """
+    Parse the latitude/longitude geo-reference from the map's OpenDRIVE header.
+    :return: tuple (lat_ref, lon_ref); defaults to (42.0, 2.0) if no geoReference is found
+ """
+ xodr = world.get_map().to_opendrive()
+ tree = ET.ElementTree(ET.fromstring(xodr))
+
+ # default reference
+ lat_ref = 42.0
+ lon_ref = 2.0
+
+ for opendrive in tree.iter("OpenDRIVE"):
+ for header in opendrive.iter("header"):
+ for georef in header.iter("geoReference"):
+ if georef.text:
+ str_list = georef.text.split(' ')
+ for item in str_list:
+ if '+lat_0' in item:
+ lat_ref = float(item.split('=')[1])
+ if '+lon_0' in item:
+ lon_ref = float(item.split('=')[1])
+ return lat_ref, lon_ref
+
+
+def downsample_route(route, sample_factor):
+ """
+ Downsample the route by some factor.
+    :param route: the trajectory; has to contain the waypoints and the road options
+    :param sample_factor: maximum distance between samples
+    :return: the indices of the route entries that make up the downsampled route
+ """
+
+ ids_to_sample = []
+ prev_option = None
+ dist = 0
+
+ for i, point in enumerate(route):
+ curr_option = point[1]
+
+ # Lane changing
+ if curr_option in (RoadOption.CHANGELANELEFT, RoadOption.CHANGELANERIGHT):
+ ids_to_sample.append(i)
+ dist = 0
+
+ # When road option changes
+ elif prev_option != curr_option and prev_option not in (RoadOption.CHANGELANELEFT, RoadOption.CHANGELANERIGHT):
+ ids_to_sample.append(i)
+ dist = 0
+
+ # After a certain max distance
+ elif dist > sample_factor:
+ ids_to_sample.append(i)
+ dist = 0
+
+ # At the end
+ elif i == len(route) - 1:
+ ids_to_sample.append(i)
+ dist = 0
+
+ # Compute the distance traveled
+ else:
+ curr_location = point[0].transform.location
+ prev_location = route[i-1][0].transform.location
+ dist += curr_location.distance(prev_location)
+
+ prev_option = curr_option
+
+ return ids_to_sample
diff --git a/carla_gym/core/task_actor/common/task_vehicle.py b/carla_gym/core/task_actor/common/task_vehicle.py
new file mode 100644
index 0000000..0ea3f61
--- /dev/null
+++ b/carla_gym/core/task_actor/common/task_vehicle.py
@@ -0,0 +1,229 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import carla
+import weakref
+from .navigation.global_route_planner import GlobalRoutePlanner
+from .navigation.route_manipulation import location_route_to_gps, downsample_route
+import numpy as np
+import logging
+import copy
+
+from .criteria import blocked, collision, outside_route_lane, route_deviation, run_stop_sign
+from .criteria import encounter_light, run_red_light
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+
+class TaskVehicle(object):
+
+ def __init__(self, vehicle, target_transforms, spawn_transforms, endless):
+ """
+ vehicle: carla.Vehicle
+ target_transforms: list of carla.Transform
+ """
+ self.vehicle = vehicle
+ world = self.vehicle.get_world()
+ self._map = world.get_map()
+ self._world = world
+
+ self.criteria_blocked = blocked.Blocked()
+ self.criteria_collision = collision.Collision(self.vehicle, world)
+ self.criteria_light = run_red_light.RunRedLight(self._map)
+ self.criteria_encounter_light = encounter_light.EncounterLight()
+ self.criteria_stop = run_stop_sign.RunStopSign(world)
+ self.criteria_outside_route_lane = outside_route_lane.OutsideRouteLane(self._map, self.vehicle.get_location())
+ self.criteria_route_deviation = route_deviation.RouteDeviation()
+
+ # navigation
+ self._route_completed = 0.0
+ self._route_length = 0.0
+
+ self._target_transforms = target_transforms # transforms
+
+ self._planner = GlobalRoutePlanner(self._map, resolution=1.0)
+
+ self._global_route = []
+ self._global_plan_gps = []
+ self._global_plan_world_coord = []
+
+ self._trace_route_to_global_target()
+
+ self._spawn_transforms = spawn_transforms
+
+ self._endless = endless
+ if len(self._target_transforms) == 0:
+ while self._route_length < 1000.0:
+ self._add_random_target()
+
+ self._last_route_location = self.vehicle.get_location()
+ self.collision_px = False
+
+ def _update_leaderboard_plan(self, route_trace):
+ plan_gps = location_route_to_gps(route_trace)
+ ds_ids = downsample_route(route_trace, 50)
+
+ self._global_plan_gps += [plan_gps[x] for x in ds_ids]
+ self._global_plan_world_coord += [(route_trace[x][0].transform.location, route_trace[x][1]) for x in ds_ids]
+
+ def _add_random_target(self):
+ if len(self._target_transforms) == 0:
+ last_target_loc = self.vehicle.get_location()
+ ev_wp = self._map.get_waypoint(last_target_loc)
+ next_wp = ev_wp.next(6)[0]
+ new_target_transform = next_wp.transform
+ else:
+ last_target_loc = self._target_transforms[-1].location
+ last_road_id = self._map.get_waypoint(last_target_loc).road_id
+ new_target_transform = np.random.choice([x[1] for x in self._spawn_transforms if x[0] != last_road_id])
+
+ route_trace = self._planner.trace_route(last_target_loc, new_target_transform.location)
+ self._global_route += route_trace
+ self._target_transforms.append(new_target_transform)
+ self._route_length += self._compute_route_length(route_trace)
+ self._update_leaderboard_plan(route_trace)
+
+ def _trace_route_to_global_target(self):
+ current_location = self.vehicle.get_location()
+ for tt in self._target_transforms:
+ next_target_location = tt.location
+ route_trace = self._planner.trace_route(current_location, next_target_location)
+ self._global_route += route_trace
+ self._route_length += self._compute_route_length(route_trace)
+ current_location = next_target_location
+
+ self._update_leaderboard_plan(self._global_route)
+
+ @staticmethod
+ def _compute_route_length(route):
+ length_in_m = 0.0
+ for i in range(len(route)-1):
+ d = route[i][0].transform.location.distance(route[i+1][0].transform.location)
+ length_in_m += d
+ return length_in_m
+
+ def _truncate_global_route_till_local_target(self, windows_size=5):
+ ev_location = self.vehicle.get_location()
+ closest_idx = 0
+
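+        # Walk along the route (at most windows_size segments) and treat a waypoint as passed when the
+        # vehicle lies ahead of it along the segment direction (positive dot product)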
+ for i in range(len(self._global_route)-1):
+ if i > windows_size:
+ break
+
+ loc0 = self._global_route[i][0].transform.location
+ loc1 = self._global_route[i+1][0].transform.location
+
+ wp_dir = loc1 - loc0
+ wp_veh = ev_location - loc0
+ dot_ve_wp = wp_veh.x * wp_dir.x + wp_veh.y * wp_dir.y + wp_veh.z * wp_dir.z
+
+ if dot_ve_wp > 0:
+ closest_idx = i+1
+
+ distance_traveled = self._compute_route_length(self._global_route[:closest_idx+1])
+ self._route_completed += distance_traveled
+
+ if closest_idx > 0:
+ self._last_route_location = carla.Location(self._global_route[0][0].transform.location)
+
+ self._global_route = self._global_route[closest_idx:]
+ return distance_traveled
+
+ def _is_route_completed(self, percentage_threshold=0.99, distance_threshold=10.0):
+ # distance_threshold=10.0
+ ev_loc = self.vehicle.get_location()
+
+ percentage_route_completed = self._route_completed / self._route_length
+ is_completed = percentage_route_completed > percentage_threshold
+ is_within_dist = ev_loc.distance(self._target_transforms[-1].location) < distance_threshold
+
+ return is_completed and is_within_dist
+
+ def tick(self, timestamp):
+ distance_traveled = self._truncate_global_route_till_local_target()
+ route_completed = self._is_route_completed()
+ if self._endless and (len(self._global_route) < 10 or route_completed):
+ self._add_random_target()
+ route_completed = False
+
+ info_blocked = self.criteria_blocked.tick(self.vehicle, timestamp)
+ info_collision = self.criteria_collision.tick(self.vehicle, timestamp)
+ info_light = self.criteria_light.tick(self.vehicle, timestamp)
+ info_encounter_light = self.criteria_encounter_light.tick(self.vehicle, timestamp)
+ info_stop = self.criteria_stop.tick(self.vehicle, timestamp)
+ info_outside_route_lane = self.criteria_outside_route_lane.tick(self.vehicle, timestamp, distance_traveled)
+ info_route_deviation = self.criteria_route_deviation.tick(
+ self.vehicle, timestamp, self._global_route[0][0], distance_traveled, self._route_length)
+
+ info_route_completion = {
+ 'step': timestamp['step'],
+ 'simulation_time': timestamp['relative_simulation_time'],
+ 'route_completed_in_m': self._route_completed,
+ 'route_length_in_m': self._route_length,
+ 'is_route_completed': route_completed
+ }
+
+ self._info_criteria = {
+ 'route_completion': info_route_completion,
+ 'outside_route_lane': info_outside_route_lane,
+ 'route_deviation': info_route_deviation,
+ 'blocked': info_blocked,
+ 'collision': info_collision,
+ 'run_red_light': info_light,
+ 'encounter_light': info_encounter_light,
+ 'run_stop_sign': info_stop
+ }
+
+        # Turn on position and low-beam lights when the sun is below the horizon
+ weather = self._world.get_weather()
+ if weather.sun_altitude_angle < 0.0:
+ vehicle_lights = carla.VehicleLightState.Position | carla.VehicleLightState.LowBeam
+ else:
+ vehicle_lights = carla.VehicleLightState.NONE
+ self.vehicle.set_light_state(carla.VehicleLightState(vehicle_lights))
+
+ return self._info_criteria
+
+ def clean(self):
+ self.criteria_collision.clean()
+ self.vehicle.destroy()
+
+ @property
+ def info_criteria(self):
+ return self._info_criteria
+
+ @property
+ def dest_transform(self):
+ return self._target_transforms[-1]
+
+ @property
+ def route_plan(self):
+ return self._global_route
+
+ @property
+ def global_plan_gps(self):
+ return self._global_plan_gps
+
+ @property
+ def global_plan_world_coord(self):
+ return self._global_plan_world_coord
+
+ @property
+ def route_length(self):
+ return self._route_length
+
+ @property
+ def route_completed(self):
+ return self._route_completed
+
+ def get_route_transform(self):
+ loc0 = self._last_route_location
+ loc1 = self._global_route[0][0].transform.location
+
+ if loc1.distance(loc0) < 0.1:
+ yaw = self._global_route[0][0].transform.rotation.yaw
+ else:
+ f_vec = loc1 - loc0
+ yaw = np.rad2deg(np.arctan2(f_vec.y, f_vec.x))
+ rot = carla.Rotation(yaw=yaw)
+ return carla.Transform(location=loc0, rotation=rot)
diff --git a/carla_gym/core/task_actor/ego_vehicle/ego_vehicle_handler.py b/carla_gym/core/task_actor/ego_vehicle/ego_vehicle_handler.py
new file mode 100644
index 0000000..83910db
--- /dev/null
+++ b/carla_gym/core/task_actor/ego_vehicle/ego_vehicle_handler.py
@@ -0,0 +1,254 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+from carla_gym.core.task_actor.common.task_vehicle import TaskVehicle
+import numpy as np
+from importlib import import_module
+
+PENALTY_COLLISION_PEDESTRIAN = 0.50
+PENALTY_COLLISION_VEHICLE = 0.60
+PENALTY_COLLISION_STATIC = 0.65
+PENALTY_TRAFFIC_LIGHT = 0.70
+PENALTY_STOP = 0.80
+
+
+class EgoVehicleHandler(object):
+ def __init__(self, client, reward_configs, terminal_configs):
+ self.ego_vehicles = {}
+ self.info_buffers = {}
+ self.reward_buffers = {}
+ self.reward_handlers = {}
+ self.terminal_handlers = {}
+
+ self._reward_configs = reward_configs
+ self._terminal_configs = terminal_configs
+
+ self._world = client.get_world()
+ self._map = self._world.get_map()
+ self._spawn_transforms = self._get_spawn_points(self._map)
+
+ def reset(self, task_config):
+ actor_config = task_config['actors']
+ route_config = task_config['routes']
+ endless_config = task_config.get('endless')
+
+ ev_spawn_locations = []
+ for ev_id in actor_config:
+ bp_filter = actor_config[ev_id]['model']
+ blueprint = np.random.choice(self._world.get_blueprint_library().filter(bp_filter))
+ blueprint.set_attribute('role_name', ev_id)
+
+ if len(route_config[ev_id]) == 0:
+ spawn_transform = np.random.choice([x[1] for x in self._spawn_transforms])
+ else:
+ spawn_transform = route_config[ev_id][0]
+
+ wp = self._map.get_waypoint(spawn_transform.location)
+ spawn_transform.location.z = wp.transform.location.z + 1.321
+
+ carla_vehicle = self._world.try_spawn_actor(blueprint, spawn_transform)
+ self._world.tick()
+
+ if endless_config is None:
+ endless = False
+ else:
+ endless = endless_config[ev_id]
+ target_transforms = route_config[ev_id][1:]
+ self.ego_vehicles[ev_id] = TaskVehicle(carla_vehicle, target_transforms, self._spawn_transforms, endless)
+
+ self.reward_handlers[ev_id] = self._build_instance(
+ self._reward_configs[ev_id], self.ego_vehicles[ev_id])
+ self.terminal_handlers[ev_id] = self._build_instance(
+ self._terminal_configs[ev_id], self.ego_vehicles[ev_id])
+
+ self.reward_buffers[ev_id] = []
+ self.info_buffers[ev_id] = {
+ 'collisions_layout': [],
+ 'collisions_vehicle': [],
+ 'collisions_pedestrian': [],
+ 'collisions_others': [],
+ 'red_light': [],
+ 'encounter_light': [],
+ 'stop_infraction': [],
+ 'encounter_stop': [],
+ 'route_dev': [],
+ 'vehicle_blocked': [],
+ 'outside_lane': [],
+ 'wrong_lane': []
+ }
+
+ ev_spawn_locations.append(carla_vehicle.get_location())
+ return ev_spawn_locations
+
+ @staticmethod
+ def _build_instance(config, ego_vehicle):
+ module_str, class_str = config['entry_point'].split(':')
+ _Class = getattr(import_module('carla_gym.core.task_actor.ego_vehicle.'+module_str), class_str)
+ return _Class(ego_vehicle, **config.get('kwargs', {}))
+
+ def apply_control(self, control_dict):
+ for ev_id, control in control_dict.items():
+ self.ego_vehicles[ev_id].vehicle.apply_control(control)
+
+ def tick(self, timestamp):
+ reward_dict, done_dict, info_dict = {}, {}, {}
+
+ for ev_id, ev in self.ego_vehicles.items():
+ info_criteria = ev.tick(timestamp)
+ info = info_criteria.copy()
+ done, timeout, terminal_reward, terminal_debug = self.terminal_handlers[ev_id].get(timestamp)
+ reward, reward_debug = self.reward_handlers[ev_id].get(terminal_reward)
+
+ reward_dict[ev_id] = reward
+ done_dict[ev_id] = done
+ info_dict[ev_id] = info
+ info_dict[ev_id]['timeout'] = timeout
+ info_dict[ev_id]['reward_debug'] = reward_debug
+ info_dict[ev_id]['terminal_debug'] = terminal_debug
+
+ # accumulate into buffers
+ self.reward_buffers[ev_id].append(reward)
+
+ if info['collision']:
+ if info['collision']['collision_type'] == 0:
+ self.info_buffers[ev_id]['collisions_layout'].append(info['collision'])
+ elif info['collision']['collision_type'] == 1:
+ self.info_buffers[ev_id]['collisions_vehicle'].append(info['collision'])
+ elif info['collision']['collision_type'] == 2:
+ self.info_buffers[ev_id]['collisions_pedestrian'].append(info['collision'])
+ else:
+ self.info_buffers[ev_id]['collisions_others'].append(info['collision'])
+ if info['run_red_light']:
+ self.info_buffers[ev_id]['red_light'].append(info['run_red_light'])
+ if info['encounter_light']:
+ self.info_buffers[ev_id]['encounter_light'].append(info['encounter_light'])
+ if info['run_stop_sign']:
+ if info['run_stop_sign']['event'] == 'encounter':
+ self.info_buffers[ev_id]['encounter_stop'].append(info['run_stop_sign'])
+ elif info['run_stop_sign']['event'] == 'run':
+ self.info_buffers[ev_id]['stop_infraction'].append(info['run_stop_sign'])
+ if info['route_deviation']:
+ self.info_buffers[ev_id]['route_dev'].append(info['route_deviation'])
+ if info['blocked']:
+ self.info_buffers[ev_id]['vehicle_blocked'].append(info['blocked'])
+ if info['outside_route_lane']:
+ if info['outside_route_lane']['outside_lane']:
+ self.info_buffers[ev_id]['outside_lane'].append(info['outside_route_lane'])
+ if info['outside_route_lane']['wrong_lane']:
+ self.info_buffers[ev_id]['wrong_lane'].append(info['outside_route_lane'])
+ # save episode summary
+ if done:
+ info_dict[ev_id]['episode_event'] = self.info_buffers[ev_id]
+ info_dict[ev_id]['episode_event']['timeout'] = info['timeout']
+ info_dict[ev_id]['episode_event']['route_completion'] = info['route_completion']
+
+ total_length = float(info['route_completion']['route_length_in_m']) / 1000
+ completed_length = float(info['route_completion']['route_completed_in_m']) / 1000
+ total_length = max(total_length, 0.001)
+ completed_length = max(completed_length, 0.001)
+
+ outside_lane_length = np.sum([x['distance_traveled']
+ for x in self.info_buffers[ev_id]['outside_lane']]) / 1000
+ wrong_lane_length = np.sum([x['distance_traveled']
+ for x in self.info_buffers[ev_id]['wrong_lane']]) / 1000
+
+ if ev._endless:
+ score_route = completed_length
+ else:
+ if info['route_completion']['is_route_completed']:
+ score_route = 1.0
+ else:
+ score_route = completed_length / total_length
+
+ n_collisions_layout = int(len(self.info_buffers[ev_id]['collisions_layout']))
+ n_collisions_vehicle = int(len(self.info_buffers[ev_id]['collisions_vehicle']))
+ n_collisions_pedestrian = int(len(self.info_buffers[ev_id]['collisions_pedestrian']))
+ n_collisions_others = int(len(self.info_buffers[ev_id]['collisions_others']))
+ n_red_light = int(len(self.info_buffers[ev_id]['red_light']))
+ n_encounter_light = int(len(self.info_buffers[ev_id]['encounter_light']))
+ n_stop_infraction = int(len(self.info_buffers[ev_id]['stop_infraction']))
+ n_encounter_stop = int(len(self.info_buffers[ev_id]['encounter_stop']))
+ n_collisions = n_collisions_layout + n_collisions_vehicle + n_collisions_pedestrian + n_collisions_others
+
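+                # Multiplicative per-infraction penalties (leaderboard-style), scaled by the share of the
+                # route driven inside the correct lane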
+ score_penalty = 1.0 * (1 - (outside_lane_length+wrong_lane_length)/completed_length) \
+ * (PENALTY_COLLISION_STATIC ** n_collisions_layout) \
+ * (PENALTY_COLLISION_VEHICLE ** n_collisions_vehicle) \
+ * (PENALTY_COLLISION_PEDESTRIAN ** n_collisions_pedestrian) \
+ * (PENALTY_TRAFFIC_LIGHT ** n_red_light) \
+                    * (PENALTY_STOP ** n_stop_infraction)
+
+ if info['route_completion']['is_route_completed'] and n_collisions == 0:
+ is_route_completed_nocrash = 1.0
+ else:
+ is_route_completed_nocrash = 0.0
+
+ info_dict[ev_id]['episode_stat'] = {
+ 'score_route': score_route,
+ 'score_penalty': score_penalty,
+ 'score_composed': max(score_route*score_penalty, 0.0),
+ 'length': len(self.reward_buffers[ev_id]),
+ 'reward': np.sum(self.reward_buffers[ev_id]),
+ 'timeout': float(info['timeout']),
+ 'is_route_completed': float(info['route_completion']['is_route_completed']),
+ 'is_route_completed_nocrash': is_route_completed_nocrash,
+ 'route_completed_in_km': completed_length,
+ 'route_length_in_km': total_length,
+ 'percentage_outside_lane': outside_lane_length / completed_length,
+ 'percentage_wrong_lane': wrong_lane_length / completed_length,
+ 'collisions_layout': n_collisions_layout / completed_length,
+ 'collisions_vehicle': n_collisions_vehicle / completed_length,
+ 'collisions_pedestrian': n_collisions_pedestrian / completed_length,
+ 'collisions_others': n_collisions_others / completed_length,
+ 'red_light': n_red_light / completed_length,
+ 'light_passed': n_encounter_light-n_red_light,
+ 'encounter_light': n_encounter_light,
+ 'stop_infraction': n_stop_infraction / completed_length,
+ 'stop_passed': n_encounter_stop-n_stop_infraction,
+ 'encounter_stop': n_encounter_stop,
+ 'route_dev': len(self.info_buffers[ev_id]['route_dev']) / completed_length,
+ 'vehicle_blocked': len(self.info_buffers[ev_id]['vehicle_blocked']) / completed_length
+ }
+
+        done_dict['__all__'] = all(done_dict.values())
+ return reward_dict, done_dict, info_dict
+
+ def clean(self):
+ for ev_id, ev in self.ego_vehicles.items():
+ ev.clean()
+ self.ego_vehicles = {}
+ self.reward_handlers = {}
+ self.terminal_handlers = {}
+ self.info_buffers = {}
+ self.reward_buffers = {}
+
+ @staticmethod
+ def _get_spawn_points(c_map):
+ all_spawn_points = c_map.get_spawn_points()
+
+ spawn_transforms = []
+ for trans in all_spawn_points:
+ wp = c_map.get_waypoint(trans.location)
+
+ if wp.is_junction:
+ wp_prev = wp
+ # wp_next = wp
+ while wp_prev.is_junction:
+ wp_prev = wp_prev.previous(1.0)[0]
+ spawn_transforms.append([wp_prev.road_id, wp_prev.transform])
+ if c_map.name == 'Town03' and (wp_prev.road_id == 44):
+ for _ in range(100):
+ spawn_transforms.append([wp_prev.road_id, wp_prev.transform])
+ # while wp_next.is_junction:
+ # wp_next = wp_next.next(1.0)[0]
+
+ # spawn_transforms.append([wp_next.road_id, wp_next.transform])
+ # if c_map.name == 'Town03' and (wp_next.road_id == 44 or wp_next.road_id == 58):
+ # for _ in range(100):
+ # spawn_transforms.append([wp_next.road_id, wp_next.transform])
+
+ else:
+ spawn_transforms.append([wp.road_id, wp.transform])
+ if c_map.name == 'Town03' and (wp.road_id == 44):
+ for _ in range(100):
+ spawn_transforms.append([wp.road_id, wp.transform])
+
+ return spawn_transforms
diff --git a/carla_gym/core/task_actor/ego_vehicle/reward/valeo_action.py b/carla_gym/core/task_actor/ego_vehicle/reward/valeo_action.py
new file mode 100644
index 0000000..adb7c8e
--- /dev/null
+++ b/carla_gym/core/task_actor/ego_vehicle/reward/valeo_action.py
@@ -0,0 +1,138 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import numpy as np
+import carla
+
+import carla_gym.utils.transforms as trans_utils
+from carla_gym.core.obs_manager.object_finder.vehicle import ObsManager as OmVehicle
+from carla_gym.core.obs_manager.object_finder.pedestrian import ObsManager as OmPedestrian
+
+from carla_gym.utils.traffic_light import TrafficLightHandler
+from carla_gym.utils.hazard_actor import lbc_hazard_vehicle, lbc_hazard_walker
+
+
+class ValeoAction(object):
+
+ def __init__(self, ego_vehicle):
+ self._ego_vehicle = ego_vehicle
+
+ self.om_vehicle = OmVehicle({'max_detection_number': 10, 'distance_threshold': 15})
+ self.om_pedestrian = OmPedestrian({'max_detection_number': 10, 'distance_threshold': 15})
+ self.om_vehicle.attach_ego_vehicle(self._ego_vehicle)
+ self.om_pedestrian.attach_ego_vehicle(self._ego_vehicle)
+
+ self._maxium_speed = 6.0
+ self._last_steer = 0.0
+ self._tl_offset = -0.8 * self._ego_vehicle.vehicle.bounding_box.extent.x
+
+ def get(self, terminal_reward):
+ ev_transform = self._ego_vehicle.vehicle.get_transform()
+ ev_control = self._ego_vehicle.vehicle.get_control()
+ ev_vel = self._ego_vehicle.vehicle.get_velocity()
+ ev_speed = np.linalg.norm(np.array([ev_vel.x, ev_vel.y]))
+
+ # action
+ if abs(ev_control.steer - self._last_steer) > 0.01:
+ r_action = -0.1
+ else:
+ r_action = 0.0
+ self._last_steer = ev_control.steer
+
+ # desired_speed
+ obs_vehicle = self.om_vehicle.get_observation()
+ obs_pedestrian = self.om_pedestrian.get_observation()
+
+ # all locations in ego_vehicle coordinate
+ hazard_vehicle_loc = lbc_hazard_vehicle(obs_vehicle, proximity_threshold=9.5)
+ hazard_ped_loc = lbc_hazard_walker(obs_pedestrian, proximity_threshold=9.5)
+ light_state, light_loc, _ = TrafficLightHandler.get_light_state(self._ego_vehicle.vehicle,
+ offset=self._tl_offset, dist_threshold=18.0)
+
+ desired_spd_veh = desired_spd_ped = desired_spd_rl = desired_spd_stop = self._maxium_speed
+
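+        # for each hazard the desired speed ramps linearly from 0 m/s inside the hazard's
+        # comfort margin up to the 6 m/s cap roughly 5 m beyond that margin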
+ if hazard_vehicle_loc is not None:
+ dist_veh = max(0.0, np.linalg.norm(hazard_vehicle_loc[0:2])-8.0)
+ desired_spd_veh = self._maxium_speed * np.clip(dist_veh, 0.0, 5.0)/5.0
+
+ if hazard_ped_loc is not None:
+ dist_ped = max(0.0, np.linalg.norm(hazard_ped_loc[0:2])-6.0)
+ desired_spd_ped = self._maxium_speed * np.clip(dist_ped, 0.0, 5.0)/5.0
+
+ if (light_state == carla.TrafficLightState.Red or light_state == carla.TrafficLightState.Yellow):
+ dist_rl = max(0.0, np.linalg.norm(light_loc[0:2])-5.0)
+ desired_spd_rl = self._maxium_speed * np.clip(dist_rl, 0.0, 5.0)/5.0
+
+ # stop sign
+ stop_sign = self._ego_vehicle.criteria_stop._target_stop_sign
+ stop_loc = None
+ if (stop_sign is not None) and (not self._ego_vehicle.criteria_stop._stop_completed):
+ trans = stop_sign.get_transform()
+ tv_loc = stop_sign.trigger_volume.location
+ loc_in_world = trans.transform(tv_loc)
+ loc_in_ev = trans_utils.loc_global_to_ref(loc_in_world, ev_transform)
+ stop_loc = np.array([loc_in_ev.x, loc_in_ev.y, loc_in_ev.z], dtype=np.float32)
+ dist_stop = max(0.0, np.linalg.norm(stop_loc[0:2])-5.0)
+ desired_spd_stop = self._maxium_speed * np.clip(dist_stop, 0.0, 5.0)/5.0
+
+ desired_speed = min(self._maxium_speed, desired_spd_veh, desired_spd_ped, desired_spd_rl, desired_spd_stop)
+
+ # r_speed
+ if ev_speed > self._maxium_speed:
+ # r_speed = 0.0
+ r_speed = 1.0 - np.abs(ev_speed-desired_speed) / self._maxium_speed
+ else:
+ r_speed = 1.0 - np.abs(ev_speed-desired_speed) / self._maxium_speed
+
+ # r_position
+ wp_transform = self._ego_vehicle.get_route_transform()
+
+ d_vec = ev_transform.location - wp_transform.location
+ np_d_vec = np.array([d_vec.x, d_vec.y], dtype=np.float32)
+ wp_unit_forward = wp_transform.rotation.get_forward_vector()
+ np_wp_unit_right = np.array([-wp_unit_forward.y, wp_unit_forward.x], dtype=np.float32)
+
+ lateral_distance = np.abs(np.dot(np_wp_unit_right, np_d_vec))
+ r_position = -1.0 * (lateral_distance / 2.0)
+
+ # r_rotation
+ angle_difference = np.deg2rad(np.abs(trans_utils.cast_angle(
+ ev_transform.rotation.yaw - wp_transform.rotation.yaw)))
+ # r_rotation = -1.0 * (angle_difference / np.pi)
+ r_rotation = -1.0 * angle_difference
+
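+        # total step reward: speed tracking + lane keeping + heading alignment
+        # + terminal reward + steering-oscillation penalty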
+ reward = r_speed + r_position + r_rotation + terminal_reward + r_action
+
+ if hazard_vehicle_loc is None:
+ txt_hazard_veh = '[]'
+ else:
+ txt_hazard_veh = np.array2string(hazard_vehicle_loc[0:2], precision=1, separator=',', suppress_small=True)
+ if hazard_ped_loc is None:
+ txt_hazard_ped = '[]'
+ else:
+ txt_hazard_ped = np.array2string(hazard_ped_loc[0:2], precision=1, separator=',', suppress_small=True)
+ if light_loc is None:
+ txt_light = '[]'
+ else:
+ txt_light = np.array2string(light_loc[0:2], precision=1, separator=',', suppress_small=True)
+ if stop_loc is None:
+ txt_stop = '[]'
+ else:
+ txt_stop = np.array2string(stop_loc[0:2], precision=1, separator=',', suppress_small=True)
+
+ debug_texts = [
+ f'Desired speed: {desired_speed:5.2f}m/s',
+ f'Vehicles desired speed:{desired_spd_veh:5.2f}m/s {txt_hazard_veh}',
+ f'Pedestrians desired speed:{desired_spd_ped:5.2f}m/s {txt_hazard_ped}',
+ f'Traffic light desired speed:{desired_spd_rl:5.2f}m/s, light state: {light_state} {txt_light}',
+ f'Stop sign desired speed:{desired_spd_stop:5.2f}m/s {txt_stop}',
+ f'Reward_terminal:{terminal_reward:5.2f}'
+ ]
+ reward_debug = {
+ 'debug_texts': debug_texts,
+ 'reward': reward,
+ 'reward_speed': r_speed,
+ 'reward_position': r_position,
+ 'reward_angle': r_rotation,
+ 'reward_oscillation': r_action,
+ }
+ return reward, reward_debug
diff --git a/carla_gym/core/task_actor/ego_vehicle/terminal/leaderboard.py b/carla_gym/core/task_actor/ego_vehicle/terminal/leaderboard.py
new file mode 100644
index 0000000..fa382d0
--- /dev/null
+++ b/carla_gym/core/task_actor/ego_vehicle/terminal/leaderboard.py
@@ -0,0 +1,40 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+class Leaderboard(object):
+
+ def __init__(self, ego_vehicle, max_time=None):
+ self._ego_vehicle = ego_vehicle
+ self._max_time = max_time # in sec
+
+ def get(self, timestamp):
+
+ info_criteria = self._ego_vehicle.info_criteria
+ # Done condition 1: route completed
+ c_route = info_criteria['route_completion']['is_route_completed']
+
+ # Done condition 2: blocked
+ c_blocked = info_criteria['blocked'] is not None
+
+ # Done condition 3: route_deviation
+ c_route_deviation = info_criteria['route_deviation'] is not None
+
+ # Done condition 4: timeout
+ if self._max_time is not None:
+ timeout = timestamp['relative_simulation_time'] > self._max_time
+ else:
+ timeout = False
+
+ done = c_route or c_blocked or c_route_deviation or timeout
+
+ debug_texts = [
+ f'cpl:{int(c_route)} dev:{int(c_route_deviation)} blo:{int(c_blocked)} t_out:{int(timeout)}'
+ ]
+
+ terminal_debug = {
+ 'blocked': c_blocked,
+ 'route_deviation': c_route_deviation,
+ 'debug_texts': debug_texts
+ }
+
+ terminal_reward = 0.0
+ return done, timeout, terminal_reward, terminal_debug
diff --git a/carla_gym/core/task_actor/ego_vehicle/terminal/leaderboard_dagger.py b/carla_gym/core/task_actor/ego_vehicle/terminal/leaderboard_dagger.py
new file mode 100644
index 0000000..152780b
--- /dev/null
+++ b/carla_gym/core/task_actor/ego_vehicle/terminal/leaderboard_dagger.py
@@ -0,0 +1,54 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+class LeaderboardDagger(object):
+
+ def __init__(self, ego_vehicle, no_collision=True, no_run_rl=True, no_run_stop=True, max_time=300):
+ self._ego_vehicle = ego_vehicle
+
+ self._no_collision = no_collision
+ self._no_run_rl = no_run_rl
+ self._no_run_stop = no_run_stop
+ self._max_time = max_time # in sec
+
+ def get(self, timestamp):
+
+ info_criteria = self._ego_vehicle.info_criteria
+
+ # Done condition 1: blocked
+ c_blocked = info_criteria['blocked'] is not None
+
+ # Done condition 2: route_deviation
+ c_route_deviation = info_criteria['route_deviation'] is not None
+
+ # Done condition 3: collision
+ c_collision = (info_criteria['collision'] is not None) and self._no_collision
+
+ # Done condition 4: running red light
+ c_run_rl = (info_criteria['run_red_light'] is not None) and self._no_run_rl
+
+ # Done condition 5: run stop sign
+ if info_criteria['run_stop_sign'] is not None and info_criteria['run_stop_sign']['event'] == 'run':
+ c_run_stop = True
+ else:
+ c_run_stop = False
+ c_run_stop = c_run_stop and self._no_run_stop
+
+ # Done condition 6: timeout
+ timeout = timestamp['relative_simulation_time'] > self._max_time
+
+ done = c_blocked or c_route_deviation or c_collision or c_run_rl or c_run_stop or timeout
+
+ debug_texts = [
+ f'dev:{int(c_route_deviation)} blo:{int(c_blocked)} t_out:{int(timeout)}',
+ f'col:{int(c_collision)} redl:{int(c_run_rl)} stop:{int(c_run_stop)}'
+ ]
+
+ terminal_debug = {
+ 'traffic_rule_violated': c_collision or c_run_rl or c_run_stop,
+ 'blocked': c_blocked,
+ 'route_deviation': c_route_deviation,
+ 'debug_texts': debug_texts
+ }
+
+ terminal_reward = 0.0
+ return done, timeout, terminal_reward, terminal_debug
diff --git a/carla_gym/core/task_actor/ego_vehicle/terminal/valeo.py b/carla_gym/core/task_actor/ego_vehicle/terminal/valeo.py
new file mode 100644
index 0000000..b0c994b
--- /dev/null
+++ b/carla_gym/core/task_actor/ego_vehicle/terminal/valeo.py
@@ -0,0 +1,147 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import numpy as np
+from collections import deque
+import carla
+
+from carla_gym.core.obs_manager.object_finder.vehicle import ObsManager as OmVehicle
+from carla_gym.core.obs_manager.object_finder.pedestrian import ObsManager as OmPedestrian
+from carla_gym.utils.hazard_actor import lbc_hazard_vehicle, lbc_hazard_walker
+from carla_gym.utils.traffic_light import TrafficLightHandler
+
+
+class Valeo(object):
+ '''
+    Follow the Valeo paper as closely as possible.
+ '''
+
+ def __init__(self, ego_vehicle, exploration_suggest=True, eval_mode=False):
+ self._ego_vehicle = ego_vehicle
+ self._exploration_suggest = exploration_suggest
+
+ self.om_vehicle = OmVehicle({'max_detection_number': 10, 'distance_threshold': 15})
+ self.om_pedestrian = OmPedestrian({'max_detection_number': 10, 'distance_threshold': 15})
+ self.om_vehicle.attach_ego_vehicle(self._ego_vehicle)
+ self.om_pedestrian.attach_ego_vehicle(self._ego_vehicle)
+
+ self._vehicle_stuck_step = 100
+ self._vehicle_stuck_counter = 0
+ self._speed_queue = deque(maxlen=10)
+ self._tl_offset = -0.8 * self._ego_vehicle.vehicle.bounding_box.extent.x
+ self._last_lat_dist = 0.0
+ self._min_thresh_lat_dist = 3.5
+
+ self._eval_mode = eval_mode
+ self._eval_time = 1200
+
+ def get(self, timestamp):
+ # Done condition 1: vehicle stuck
+ ev_vel = self._ego_vehicle.vehicle.get_velocity()
+ ev_speed = np.linalg.norm(np.array([ev_vel.x, ev_vel.y]))
+ self._speed_queue.append(ev_speed)
+ obs_vehicle = self.om_vehicle.get_observation()
+ obs_pedestrian = self.om_pedestrian.get_observation()
+ hazard_vehicle_loc = lbc_hazard_vehicle(obs_vehicle, proximity_threshold=9.5)
+ hazard_ped_loc = lbc_hazard_walker(obs_pedestrian, proximity_threshold=9.5)
+
+ light_state, light_loc, _ = TrafficLightHandler.get_light_state(self._ego_vehicle.vehicle,
+ offset=self._tl_offset, dist_threshold=18.0)
+
+ is_free_road = (hazard_vehicle_loc is None) and (hazard_ped_loc is None) \
+ and (light_state is None or light_state == carla.TrafficLightState.Green)
+
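+        # the ego only counts as stuck when the road ahead is free (no hazard, no red light)
+        # yet its mean speed over the last 10 ticks stays below 1 m/s for 100 consecutive steps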
+ if is_free_road and np.mean(self._speed_queue) < 1.0:
+ self._vehicle_stuck_counter += 1
+ if np.mean(self._speed_queue) >= 1.0:
+ self._vehicle_stuck_counter = 0
+
+ c_vehicle_stuck = self._vehicle_stuck_counter >= self._vehicle_stuck_step
+
+ # Done condition 2: lateral distance too large
+ ev_loc = self._ego_vehicle.vehicle.get_location()
+ wp_transform = self._ego_vehicle.get_route_transform()
+ d_vec = ev_loc - wp_transform.location
+ np_d_vec = np.array([d_vec.x, d_vec.y], dtype=np.float32)
+ wp_unit_forward = wp_transform.rotation.get_forward_vector()
+ np_wp_unit_right = np.array([-wp_unit_forward.y, wp_unit_forward.x], dtype=np.float32)
+ lat_dist = np.abs(np.dot(np_wp_unit_right, np_d_vec))
+
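+        # tolerate sudden jumps of the reference waypoint (e.g. at lane changes): if the lateral
+        # distance grows by more than 0.8 m in one step, temporarily raise the threshold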
+ if lat_dist - self._last_lat_dist > 0.8:
+ thresh_lat_dist = lat_dist + 0.5
+ else:
+ thresh_lat_dist = max(self._min_thresh_lat_dist, self._last_lat_dist)
+ c_lat_dist = lat_dist > thresh_lat_dist + 1e-2
+ self._last_lat_dist = lat_dist
+
+ # Done condition 3: running red light
+ c_run_rl = self._ego_vehicle.info_criteria['run_red_light'] is not None
+ # Done condition 4: collision
+ c_collision = self._ego_vehicle.info_criteria['collision'] is not None
+ # Done condition 5: run stop sign
+ if self._ego_vehicle.info_criteria['run_stop_sign'] is not None \
+ and self._ego_vehicle.info_criteria['run_stop_sign']['event'] == 'run':
+ c_run_stop = True
+ else:
+ c_run_stop = False
+
+ # Done condition 6: vehicle blocked
+ c_blocked = self._ego_vehicle.info_criteria['blocked'] is not None
+
+ # endless env: timeout means succeed
+ if self._eval_mode:
+ timeout = timestamp['relative_simulation_time'] > self._eval_time
+ else:
+ timeout = False
+
+ done = c_vehicle_stuck or c_lat_dist or c_run_rl or c_collision or c_run_stop or c_blocked or timeout
+
+ # terminal reward
+ terminal_reward = 0.0
+ if done:
+ terminal_reward = -1.0
+ if c_run_rl or c_collision or c_run_stop:
+ terminal_reward -= ev_speed
+
+ # terminal guide
+ exploration_suggest = {
+ 'n_steps': 0,
+ 'suggest': ('', '')
+ }
+ if self._exploration_suggest:
+ if c_vehicle_stuck or c_blocked:
+ exploration_suggest['n_steps'] = 100
+ exploration_suggest['suggest'] = ('go', '')
+ if c_lat_dist:
+ exploration_suggest['n_steps'] = 100
+ exploration_suggest['suggest'] = ('', 'turn')
+ if c_run_rl or c_collision or c_run_stop:
+ exploration_suggest['n_steps'] = 100
+ exploration_suggest['suggest'] = ('stop', '')
+
+ # debug info
+ if hazard_vehicle_loc is None:
+ txt_hazard_veh = '[]'
+ else:
+ txt_hazard_veh = np.array2string(hazard_vehicle_loc[0:2], precision=1, separator=',', suppress_small=True)
+ if hazard_ped_loc is None:
+ txt_hazard_ped = '[]'
+ else:
+ txt_hazard_ped = np.array2string(hazard_ped_loc[0:2], precision=1, separator=',', suppress_small=True)
+ if light_loc is None:
+ txt_hazard_rl = '[]'
+ else:
+ txt_hazard_rl = np.array2string(light_loc[0:2], precision=1, separator=',', suppress_small=True)
+
+ debug_texts = [
+ f'{self._vehicle_stuck_counter:3}/{self._vehicle_stuck_step}'
+ f' fre:{int(is_free_road)} stu:{int(c_vehicle_stuck)} blo:{int(c_blocked)}',
+ f'v:{txt_hazard_veh} p:{txt_hazard_ped} {light_state}{txt_hazard_rl}',
+ f'ev: {int(self._eval_mode)} col:{int(c_collision)} red:{int(c_run_rl)} st:{int(c_run_stop)}',
+ f'latd:{int(c_lat_dist)}, {lat_dist:.2f}/{thresh_lat_dist:.2f}',
+ f"[{exploration_suggest['n_steps']} {exploration_suggest['suggest']}]"
+ ]
+ terminal_debug = {
+ 'exploration_suggest': exploration_suggest,
+ 'debug_texts': debug_texts
+ }
+ return done, timeout, terminal_reward, terminal_debug
diff --git a/carla_gym/core/task_actor/ego_vehicle/terminal/valeo_no_det_px.py b/carla_gym/core/task_actor/ego_vehicle/terminal/valeo_no_det_px.py
new file mode 100644
index 0000000..890901a
--- /dev/null
+++ b/carla_gym/core/task_actor/ego_vehicle/terminal/valeo_no_det_px.py
@@ -0,0 +1,103 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import numpy as np
+
+
+class ValeoNoDetPx(object):
+ '''
+    Follow the Valeo paper as closely as possible.
+ '''
+
+ def __init__(self, ego_vehicle, exploration_suggest=True, eval_mode=False):
+ self._ego_vehicle = ego_vehicle
+ self._exploration_suggest = exploration_suggest
+
+ self._last_lat_dist = 0.0
+ self._min_thresh_lat_dist = 3.5
+
+ self._eval_mode = eval_mode
+ self._eval_time = 1200
+
+ def get(self, timestamp):
+ # Done condition 1: vehicle blocked
+ c_blocked = self._ego_vehicle.info_criteria['blocked'] is not None
+
+ # Done condition 2: lateral distance too large
+ ev_loc = self._ego_vehicle.vehicle.get_location()
+ wp_transform = self._ego_vehicle.get_route_transform()
+ d_vec = ev_loc - wp_transform.location
+ np_d_vec = np.array([d_vec.x, d_vec.y], dtype=np.float32)
+ wp_unit_forward = wp_transform.rotation.get_forward_vector()
+ np_wp_unit_right = np.array([-wp_unit_forward.y, wp_unit_forward.x], dtype=np.float32)
+ lat_dist = np.abs(np.dot(np_wp_unit_right, np_d_vec))
+
+ if lat_dist - self._last_lat_dist > 0.8:
+ thresh_lat_dist = lat_dist + 0.5
+ else:
+ thresh_lat_dist = max(self._min_thresh_lat_dist, self._last_lat_dist)
+ c_lat_dist = lat_dist > thresh_lat_dist + 1e-2
+ self._last_lat_dist = lat_dist
+
+ # Done condition 3: running red light
+ c_run_rl = self._ego_vehicle.info_criteria['run_red_light'] is not None
+ # Done condition 4: collision
+ c_collision = self._ego_vehicle.info_criteria['collision'] is not None
+ # Done condition 5: run stop sign
+ if self._ego_vehicle.info_criteria['run_stop_sign'] is not None \
+ and self._ego_vehicle.info_criteria['run_stop_sign']['event'] == 'run':
+ c_run_stop = True
+ else:
+ c_run_stop = False
+
+ # Done condition 6: collision_px
+ if self._eval_mode:
+ c_collision_px = False
+ else:
+ c_collision_px = self._ego_vehicle.collision_px
+
+ # endless env: timeout means succeed
+ if self._eval_mode:
+ timeout = timestamp['relative_simulation_time'] > self._eval_time
+ else:
+ timeout = False
+
+ done = c_blocked or c_lat_dist or c_run_rl or c_collision or c_run_stop or c_collision_px or timeout
+
+ # terminal reward
+ terminal_reward = 0.0
+ if done:
+ terminal_reward = -1.0
+ if c_run_rl or c_collision or c_run_stop or c_collision_px:
+ ev_vel = self._ego_vehicle.vehicle.get_velocity()
+ ev_speed = np.linalg.norm(np.array([ev_vel.x, ev_vel.y]))
+ terminal_reward -= ev_speed
+
+ # terminal guide
+ exploration_suggest = {
+ 'n_steps': 0,
+ 'suggest': ('', '')
+ }
+ if self._exploration_suggest:
+ if c_blocked:
+ exploration_suggest['n_steps'] = 100
+ exploration_suggest['suggest'] = ('go', '')
+ if c_lat_dist:
+ exploration_suggest['n_steps'] = 100
+ exploration_suggest['suggest'] = ('go', 'turn')
+ if c_run_rl or c_collision or c_run_stop or c_collision_px:
+ exploration_suggest['n_steps'] = 100
+ exploration_suggest['suggest'] = ('stop', '')
+
+ # debug info
+
+ debug_texts = [
+ f'ev: {int(self._eval_mode)} blo:{int(c_blocked)} to:{int(timeout)}',
+ f'c_px:{int(c_collision_px)} col:{int(c_collision)} red:{int(c_run_rl)} st:{int(c_run_stop)}',
+ f"latd:{int(c_lat_dist)}, {lat_dist:.2f}/{thresh_lat_dist:.2f}, "
+ f"[{exploration_suggest['n_steps']} {exploration_suggest['suggest']}]"
+ ]
+ terminal_debug = {
+ 'exploration_suggest': exploration_suggest,
+ 'debug_texts': debug_texts
+ }
+ return done, timeout, terminal_reward, terminal_debug
diff --git a/carla_gym/core/task_actor/scenario_actor/agents/basic_agent.py b/carla_gym/core/task_actor/scenario_actor/agents/basic_agent.py
new file mode 100644
index 0000000..92ebf39
--- /dev/null
+++ b/carla_gym/core/task_actor/scenario_actor/agents/basic_agent.py
@@ -0,0 +1,88 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+import carla
+import numpy as np
+from .utils.local_planner import LocalPlanner
+from .utils.misc import is_within_distance_ahead, compute_yaw_difference
+
+
+class BasicAgent(object):
+ def __init__(self, scenario_vehicle, hero_vehicles, target_speed=0.0, max_skip=20, success_dist=5.0):
+ self._scenario_vehicle = scenario_vehicle
+ self._world = self._scenario_vehicle.vehicle.get_world()
+ self._map = self._world.get_map()
+
+ self._dest_transform = scenario_vehicle.dest_transform
+ self._success_dist = success_dist
+ self._proximity_threshold = 9.5
+
+ self._local_planner = LocalPlanner(target_speed=target_speed)
+
+ def get_action(self):
+ transform = self._scenario_vehicle.vehicle.get_transform()
+
+ actor_list = self._world.get_actors()
+ vehicle_list = actor_list.filter('*vehicle*')
+ walkers_list = actor_list.filter('*walker*')
+ vehicle_hazard = self._is_vehicle_hazard(transform, self._scenario_vehicle.vehicle.id, vehicle_list)
+ pedestrian_ahead = self._is_walker_hazard(transform, walkers_list)
+
+ # check red light
+ redlight_ahead = self._scenario_vehicle.vehicle.is_at_traffic_light()
+ # target_reached
+ target_reached = transform.location.distance(self._dest_transform.location) < self._success_dist
+
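+        # brake if another vehicle or walker is detected within the proximity threshold ahead,
+        # the ego is held at a traffic light, or the destination has been reached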
+ if vehicle_hazard or pedestrian_ahead or redlight_ahead or target_reached:
+ throttle, steer, brake = 0.0, 0.0, 1.0
+ else:
+ route_plan = self._scenario_vehicle.route_plan
+ # ego_vehicle_speed
+ velocity = self._scenario_vehicle.vehicle.get_velocity()
+ forward_vec = transform.get_forward_vector()
+ vel = np.array([velocity.x, velocity.y, velocity.z])
+ f_vec = np.array([forward_vec.x, forward_vec.y, forward_vec.z])
+ forward_speed = np.dot(vel, f_vec)
+ speed = np.linalg.norm(vel)
+
+ throttle, steer, brake = self._local_planner.run_step(route_plan, transform, forward_speed)
+
+ return np.array([throttle, steer, brake], dtype=np.float64)
+
+ def _is_vehicle_hazard(self, ev_transform, ev_id, vehicle_list):
+ ego_vehicle_location = ev_transform.location
+ ego_vehicle_orientation = ev_transform.rotation.yaw
+
+ for target_vehicle in vehicle_list:
+ if target_vehicle.id == ev_id:
+ continue
+
+ loc = target_vehicle.get_location()
+ ori = target_vehicle.get_transform().rotation.yaw
+
+ if compute_yaw_difference(ego_vehicle_orientation, ori) <= 150 and \
+ is_within_distance_ahead(loc, ego_vehicle_location, ego_vehicle_orientation,
+ self._proximity_threshold, degree=45):
+ return True
+
+ return False
+
+ def _is_walker_hazard(self, ev_transform, walkers_list):
+ ego_vehicle_location = ev_transform.location
+
+ for walker in walkers_list:
+ loc = walker.get_location()
+ dist = loc.distance(ego_vehicle_location)
+ degree = 162 / (np.clip(dist, 1.5, 10.5)+0.3)
+ if self._is_point_on_sidewalk(loc):
+ continue
+
+ if is_within_distance_ahead(loc, ego_vehicle_location, ev_transform.rotation.yaw,
+ self._proximity_threshold, degree=degree):
+ return True
+ return False
+
+ def _is_point_on_sidewalk(self, loc):
+ wp = self._map.get_waypoint(loc, project_to_road=False, lane_type=carla.LaneType.Sidewalk)
+ if wp is None:
+ return False
+ else:
+ return True
diff --git a/carla_gym/core/task_actor/scenario_actor/agents/constant_speed_agent.py b/carla_gym/core/task_actor/scenario_actor/agents/constant_speed_agent.py
new file mode 100644
index 0000000..9ef43c8
--- /dev/null
+++ b/carla_gym/core/task_actor/scenario_actor/agents/constant_speed_agent.py
@@ -0,0 +1,31 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import numpy as np
+from .utils.local_planner import LocalPlanner
+
+
+class ConstantSpeedAgent(object):
+ def __init__(self, scenario_vehicle, hero_vehicles, target_speed=0.0, max_skip=20, success_dist=5.0):
+ self._scenario_vehicle = scenario_vehicle
+ self._dest_transform = scenario_vehicle.dest_transform
+ self._success_dist = success_dist
+
+ self._local_planner = LocalPlanner(target_speed=target_speed)
+
+ def get_action(self):
+ transform = self._scenario_vehicle.vehicle.get_transform()
+
+ if transform.location.distance(self._dest_transform.location) < self._success_dist:
+ throttle, steer, brake = 0.0, 0.0, 1.0
+ else:
+ route_plan = self._scenario_vehicle.route_plan
+ velocity = self._scenario_vehicle.vehicle.get_velocity()
+ # ego_vehicle_speed
+ forward_vec = transform.get_forward_vector()
+ vel = np.array([velocity.x, velocity.y, velocity.z])
+ f_vec = np.array([forward_vec.x, forward_vec.y, forward_vec.z])
+ forward_speed = np.dot(vel, f_vec)
+ speed = np.linalg.norm(vel)
+ throttle, steer, brake = self._local_planner.run_step(route_plan, transform, forward_speed)
+
+ return np.array([throttle, steer, brake], dtype=np.float64)
diff --git a/carla_gym/core/task_actor/scenario_actor/agents/utils/controller.py b/carla_gym/core/task_actor/scenario_actor/agents/utils/controller.py
new file mode 100644
index 0000000..38e8b72
--- /dev/null
+++ b/carla_gym/core/task_actor/scenario_actor/agents/utils/controller.py
@@ -0,0 +1,31 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+from collections import deque
+
+
+class PIDController(object):
+ def __init__(self, pid_list, n=30):
+ self._K_P, self._K_I, self._K_D = pid_list
+
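+        # a fixed 10 Hz control rate is assumed; the integral term sums the errors in the
+        # sliding window of length n, the derivative uses the last two errors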
+ self._dt = 1.0 / 10.0
+ self._window = deque(maxlen=n)
+
+ def reset(self):
+ self._window.clear()
+
+ def step(self, error):
+ self._window.append(error)
+
+ if len(self._window) >= 2:
+ integral = sum(self._window) * self._dt
+ derivative = (self._window[-1] - self._window[-2]) / self._dt
+ else:
+ integral = 0.0
+ derivative = 0.0
+
+ control = 0.0
+ control += self._K_P * error
+ control += self._K_I * integral
+ control += self._K_D * derivative
+
+ return control
diff --git a/carla_gym/core/task_actor/scenario_actor/agents/utils/local_planner.py b/carla_gym/core/task_actor/scenario_actor/agents/utils/local_planner.py
new file mode 100644
index 0000000..d3a1559
--- /dev/null
+++ b/carla_gym/core/task_actor/scenario_actor/agents/utils/local_planner.py
@@ -0,0 +1,80 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+from enum import Enum
+import numpy as np
+
+from .controller import PIDController
+import carla_gym.utils.transforms as trans_utils
+
+
+class RoadOption(Enum):
+ """
+ RoadOption represents the possible topological configurations
+    when moving from one lane segment to another.
+ """
+ VOID = -1
+ LEFT = 1
+ RIGHT = 2
+ STRAIGHT = 3
+ LANEFOLLOW = 4
+ CHANGELANELEFT = 5
+ CHANGELANERIGHT = 6
+
+
+class LocalPlanner(object):
+
+ def __init__(self, target_speed=0.0,
+ longitudinal_pid_params=[0.5, 0.025, 0.1],
+ lateral_pid_params=[0.75, 0.05, 0.0],
+ threshold_before=7.5,
+ threshold_after=5.0):
+
+ self._target_speed = target_speed
+ self._speed_pid = PIDController(longitudinal_pid_params)
+ self._turn_pid = PIDController(lateral_pid_params)
+ self._threshold_before = threshold_before
+ self._threshold_after = threshold_after
+ self._max_skip = 20
+
+ self._last_command = 4
+
+ def run_step(self, route_plan, actor_transform, actor_speed):
+ target_index = -1
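+        # walk the first _max_skip waypoints and keep the last one already within the switching
+        # threshold; the waypoint right after it becomes the steering target (a larger threshold
+        # applies just before leaving a lane-follow segment)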
+ for i, (waypoint, road_option) in enumerate(route_plan[0:self._max_skip]):
+ if self._last_command == 4 and road_option.value != 4:
+ threshold = self._threshold_before
+ else:
+ threshold = self._threshold_after
+
+ distance = waypoint.transform.location.distance(actor_transform.location)
+ if distance < threshold:
+ self._last_command = road_option.value
+ target_index = i
+
+ if target_index < len(route_plan)-1:
+ target_index += 1
+ target_command = route_plan[target_index][1]
+ target_location_world_coord = route_plan[target_index][0].transform.location
+ target_location_actor_coord = trans_utils.loc_global_to_ref(target_location_world_coord, actor_transform)
+
+ # steer
+ x = target_location_actor_coord.x
+ y = target_location_actor_coord.y
+ theta = np.arctan2(y, x)
+ steer = self._turn_pid.step(theta)
+
+ # throttle
+ target_speed = self._target_speed
+ if target_command not in [3, 4]:
+ target_speed *= 0.75
+ delta = target_speed - actor_speed
+ throttle = self._speed_pid.step(delta)
+
+ # brake
+ brake = 0.0
+
+ # clip
+ steer = np.clip(steer, -1.0, 1.0)
+ throttle = np.clip(throttle, 0.0, 1.0)
+
+ return throttle, steer, brake
diff --git a/carla_gym/core/task_actor/scenario_actor/agents/utils/misc.py b/carla_gym/core/task_actor/scenario_actor/agents/utils/misc.py
new file mode 100644
index 0000000..357ccf6
--- /dev/null
+++ b/carla_gym/core/task_actor/scenario_actor/agents/utils/misc.py
@@ -0,0 +1,124 @@
+#!/usr/bin/env python
+
+# Copyright (c) 2018 Intel Labs.
+# authors: German Ros (german.ros@intel.com)
+#
+# This work is licensed under the terms of the MIT license.
+# For a copy, see <https://opensource.org/licenses/MIT>.
+
+""" Module with auxiliary functions. """
+
+import math
+
+import numpy as np
+
+import carla
+
+
+def draw_waypoints(world, waypoints, z=0.5):
+ """
+ Draw a list of waypoints at a certain height given in z.
+
+ :param world: carla.world object
+ :param waypoints: list or iterable container with the waypoints to draw
+ :param z: height in meters
+ :return:
+ """
+ for w in waypoints:
+ t = w.transform
+ begin = t.location + carla.Location(z=z)
+ angle = math.radians(t.rotation.yaw)
+ end = begin + carla.Location(x=math.cos(angle), y=math.sin(angle))
+ world.debug.draw_arrow(begin, end, arrow_size=0.3, life_time=1.0)
+
+
+def get_speed(vehicle):
+ """
+    Compute the speed of a vehicle in km/h
+    :param vehicle: the vehicle for which speed is calculated
+    :return: speed as a float in km/h
+ """
+ vel = vehicle.get_velocity()
+ return 3.6 * math.sqrt(vel.x ** 2 + vel.y ** 2 + vel.z ** 2)
+
+
+def compute_yaw_difference(yaw1, yaw2):
+ u = np.array([
+ math.cos(math.radians(yaw1)),
+ math.sin(math.radians(yaw1)),
+ ])
+
+ v = np.array([
+ math.cos(math.radians(yaw2)),
+ math.sin(math.radians(yaw2)),
+ ])
+
+
+ angle = math.degrees(math.acos(np.clip(np.dot(u, v), -1, 1)))
+
+ return angle
+
+
+def is_within_distance_ahead(target_location, current_location, orientation, max_distance, degree=60):
+ """
+ Check if a target object is within a certain distance in front of a reference object.
+
+ :param target_location: location of the target object
+ :param current_location: location of the reference object
+ :param orientation: orientation of the reference object
+ :param max_distance: maximum allowed distance
+ :return: True if target object is within max_distance ahead of the reference object
+ """
+ u = np.array([
+ target_location.x - current_location.x,
+ target_location.y - current_location.y])
+ distance = np.linalg.norm(u)
+
+ if distance > max_distance:
+ return False
+
+ v = np.array([
+ math.cos(math.radians(orientation)),
+ math.sin(math.radians(orientation))])
+
+    # guard against a zero-length vector and floating-point values just outside acos' domain
+    if distance < 1e-6:
+        return True
+    angle = math.degrees(math.acos(np.clip(np.dot(u, v) / distance, -1.0, 1.0)))
+
+ return angle < degree
+
+
+def compute_magnitude_angle(target_location, current_location, orientation):
+ """
+ Compute relative angle and distance between a target_location and a current_location
+
+ :param target_location: location of the target object
+ :param current_location: location of the reference object
+ :param orientation: orientation of the reference object
+ :return: a tuple composed by the distance to the object and the angle between both objects
+ """
+ target_vector = np.array([target_location.x - current_location.x, target_location.y - current_location.y])
+ norm_target = np.linalg.norm(target_vector)
+
+ forward_vector = np.array([math.cos(math.radians(orientation)), math.sin(math.radians(orientation))])
+    d_angle = math.degrees(math.acos(np.clip(np.dot(forward_vector, target_vector) / norm_target, -1.0, 1.0)))
+
+ return (norm_target, d_angle)
+
+
+def distance_vehicle(waypoint, vehicle_transform):
+ loc = vehicle_transform.location
+ dx = waypoint.transform.location.x - loc.x
+ dy = waypoint.transform.location.y - loc.y
+
+ return math.sqrt(dx * dx + dy * dy)
+
+def vector(location_1, location_2):
+ """
+ Returns the unit vector from location_1 to location_2
+ location_1, location_2 : carla.Location objects
+ """
+ x = location_2.x - location_1.x
+ y = location_2.y - location_1.y
+ z = location_2.z - location_1.z
+ norm = np.linalg.norm([x, y, z])
+
+ return [x/norm, y/norm, z/norm]
diff --git a/carla_gym/core/task_actor/scenario_actor/scenario_actor_handler.py b/carla_gym/core/task_actor/scenario_actor/scenario_actor_handler.py
new file mode 100644
index 0000000..0a55105
--- /dev/null
+++ b/carla_gym/core/task_actor/scenario_actor/scenario_actor_handler.py
@@ -0,0 +1,52 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+from carla_gym.core.task_actor.common.task_vehicle import TaskVehicle
+import numpy as np
+from importlib import import_module
+
+
+class ScenarioActorHandler(object):
+ def __init__(self, client):
+ self.scenario_actors = {}
+ self.scenario_agents = {}
+ self.hero_vehicles = {}
+
+ self._client = client
+ self._world = client.get_world()
+
+ def reset(self, task_config, hero_vehicles):
+ self.hero_vehicles = hero_vehicles
+
+ actor_config = task_config.get('actors', {})
+ route_config = task_config.get('routes', {})
+
+ for sa_id in actor_config:
+ # spawn actors
+ bp_filter = actor_config[sa_id]['model']
+ blueprint = np.random.choice(self._world.get_blueprint_library().filter(bp_filter))
+ blueprint.set_attribute('role_name', sa_id)
+ spawn_transform = route_config[sa_id][0]
+ carla_vehicle = self._world.try_spawn_actor(blueprint, spawn_transform)
+ self._world.tick()
+ target_transforms = route_config[sa_id][1:]
+ self.scenario_actors[sa_id] = TaskVehicle(carla_vehicle, target_transforms)
+ # make agents
+ module_str, class_str = actor_config[sa_id]['agent_entry_point'].split(':')
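+            # 'agent_entry_point' is expected as '<module>:<ClassName>',
+            # e.g. 'constant_speed_agent:ConstantSpeedAgent'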
+ AgentClass = getattr(
+ import_module('carla_gym.core.task_actor.scenario_actor.agents.' + module_str),
+ class_str)
+ self.scenario_agents[sa_id] = AgentClass(self.scenario_actors[sa_id], self.hero_vehicles,
+ **actor_config[sa_id].get('agent_kwargs', {}))
+
+ def tick(self):
+ for sa_id in self.scenario_actors:
+ action = self.scenario_agents[sa_id].get_action()
+ self.scenario_actors[sa_id].apply_control(action)
+ self.scenario_actors[sa_id].tick()
+
+ def clean(self):
+ for sa_id in self.scenario_actors:
+ self.scenario_actors[sa_id].clean()
+ self.scenario_actors = {}
+ self.scenario_agents = {}
+ self.hero_vehicles = {}
diff --git a/carla_gym/core/zombie_vehicle/zombie_vehicle.py b/carla_gym/core/zombie_vehicle/zombie_vehicle.py
new file mode 100644
index 0000000..845cf5b
--- /dev/null
+++ b/carla_gym/core/zombie_vehicle/zombie_vehicle.py
@@ -0,0 +1,16 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import carla
+
+class ZombieVehicle(object):
+ def __init__(self, actor_id, world):
+ self._vehicle = world.get_actor(actor_id)
+
+ def teleport_to(self, transform):
+ self._vehicle.set_transform(transform)
+ self._vehicle.set_velocity(carla.Vector3D())
+
+ def clean(self):
+ # self._vehicle.set_autopilot(False)
+ self._vehicle.destroy()
diff --git a/carla_gym/core/zombie_vehicle/zombie_vehicle_handler.py b/carla_gym/core/zombie_vehicle/zombie_vehicle_handler.py
new file mode 100644
index 0000000..976c632
--- /dev/null
+++ b/carla_gym/core/zombie_vehicle/zombie_vehicle_handler.py
@@ -0,0 +1,85 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import numpy as np
+import carla
+import logging
+from .zombie_vehicle import ZombieVehicle
+
+
+class ZombieVehicleHandler(object):
+
+ def __init__(self, client, tm_port=8000, spawn_distance_to_ev=10.0):
+ self._logger = logging.getLogger(__name__)
+ self.zombie_vehicles = {}
+ self._client = client
+ self._world = client.get_world()
+ self._spawn_distance_to_ev = spawn_distance_to_ev
+ self._tm_port = tm_port
+
+ def reset(self, num_zombie_vehicles, ev_spawn_locations):
+ if type(num_zombie_vehicles) is list:
+ n_spawn = np.random.randint(num_zombie_vehicles[0], num_zombie_vehicles[1])
+ else:
+ n_spawn = num_zombie_vehicles
+ filtered_spawn_points = self._filter_spawn_points(ev_spawn_locations)
+ np.random.shuffle(filtered_spawn_points)
+
+ self._spawn_vehicles(filtered_spawn_points[0:n_spawn])
+
+ def _filter_spawn_points(self, ev_spawn_locations):
+ all_spawn_points = self._world.get_map().get_spawn_points()
+
+ def proximity_to_ev(transform): return any([ev_loc.distance(transform.location) < self._spawn_distance_to_ev
+ for ev_loc in ev_spawn_locations])
+
+ filtered_spawn_points = [transform for transform in all_spawn_points if not proximity_to_ev(transform)]
+
+ return filtered_spawn_points
+
+ def _spawn_vehicles(self, spawn_transforms):
+ zombie_vehicle_ids = []
+ blueprints = self._world.get_blueprint_library().filter("vehicle.*")
+ SpawnActor = carla.command.SpawnActor
+ SetAutopilot = carla.command.SetAutopilot
+ FutureActor = carla.command.FutureActor
+
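+        # spawn the whole batch in one synchronous call and hand each vehicle over to the
+        # traffic manager (autopilot) on the configured tm_port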
+ batch = []
+ for transform in spawn_transforms:
+ blueprint = np.random.choice(blueprints)
+ if blueprint.has_attribute('color'):
+ color = np.random.choice(blueprint.get_attribute('color').recommended_values)
+ blueprint.set_attribute('color', color)
+ if blueprint.has_attribute('driver_id'):
+ driver_id = np.random.choice(blueprint.get_attribute('driver_id').recommended_values)
+ blueprint.set_attribute('driver_id', driver_id)
+ blueprint.set_attribute('role_name', 'zombie_vehicle')
+
+ batch.append(SpawnActor(blueprint, transform).then(SetAutopilot(FutureActor, True, self._tm_port)))
+
+ for response in self._client.apply_batch_sync(batch, do_tick=True):
+ if not response.error:
+ zombie_vehicle_ids.append(response.actor_id)
+
+ for zv_id in zombie_vehicle_ids:
+ self.zombie_vehicles[zv_id] = ZombieVehicle(zv_id, self._world)
+
+ self._logger.debug(f'Spawned {len(zombie_vehicle_ids)} zombie vehicles. '
+ f'Should spawn {len(spawn_transforms)}')
+
+ def tick(self):
+ pass
+
+ def clean(self):
+ live_vehicle_list = [vehicle.id for vehicle in self._world.get_actors().filter("*vehicle*")]
+ # batch1 = []
+ # batch2 = []
+ # SetAutopilot = carla.command.SetAutopilot
+ # DestroyActor = carla.command.DestroyActor
+ # batch1.append(SetAutopilot(zv_id, False))
+ # batch1.append(DestroyActor(zv_id))
+ # self._client.apply_batch_sync(batch1, do_tick=True)
+ # self._client.apply_batch_sync(batch2, do_tick=True)
+ for zv_id, zv in self.zombie_vehicles.items():
+ if zv_id in live_vehicle_list:
+ zv.clean()
+ self.zombie_vehicles = {}
diff --git a/carla_gym/core/zombie_walker/zombie_walker.py b/carla_gym/core/zombie_walker/zombie_walker.py
new file mode 100644
index 0000000..61be444
--- /dev/null
+++ b/carla_gym/core/zombie_walker/zombie_walker.py
@@ -0,0 +1,20 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import numpy as np
+
+
+class ZombieWalker(object):
+ def __init__(self, walker_id, controller_id, world):
+
+ self._walker = world.get_actor(walker_id)
+ self._controller = world.get_actor(controller_id)
+
+ self._controller.start()
+ self._controller.go_to_location(world.get_random_location_from_navigation())
+ self._controller.set_max_speed(1 + np.random.random())
+
+
+ def clean(self):
+ self._controller.stop()
+ self._controller.destroy()
+ self._walker.destroy()
diff --git a/carla_gym/core/zombie_walker/zombie_walker_handler.py b/carla_gym/core/zombie_walker/zombie_walker_handler.py
new file mode 100644
index 0000000..ab8ae46
--- /dev/null
+++ b/carla_gym/core/zombie_walker/zombie_walker_handler.py
@@ -0,0 +1,100 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import carla
+import numpy as np
+import logging
+from .zombie_walker import ZombieWalker
+
+
+class ZombieWalkerHandler(object):
+
+ def __init__(self, client, spawn_distance_to_ev=10.0):
+ self._logger = logging.getLogger(__name__)
+ self.zombie_walkers = {}
+ self._client = client
+ self._world = client.get_world()
+ self._spawn_distance_to_ev = spawn_distance_to_ev
+
+ def reset(self, num_zombie_walkers, ev_spawn_locations):
+ if type(num_zombie_walkers) is list:
+ n_spawn = np.random.randint(num_zombie_walkers[0], num_zombie_walkers[1])
+ else:
+ n_spawn = num_zombie_walkers
+ self._spawn(n_spawn, ev_spawn_locations)
+ self._logger.debug(f'Spawned {len(self.zombie_walkers)} zombie walkers. '
+ f'Should Spawn {num_zombie_walkers}')
+
+ def _spawn(self, num_zombie_walkers, ev_spawn_locations, max_trial=10, tick=True):
+ SpawnActor = carla.command.SpawnActor
+ walker_bp_library = self._world.get_blueprint_library().filter('walker.pedestrian.*')
+ walker_controller_bp = self._world.get_blueprint_library().find('controller.ai.walker')
+
+ def proximity_to_ev(location): return any([ev_loc.distance(location) < self._spawn_distance_to_ev
+ for ev_loc in ev_spawn_locations])
+
+ controller_ids = []
+ walker_ids = []
+ num_spawned = 0
+ n_trial = 0
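+        # two-stage spawning: first the pedestrian actors at random navmesh locations away
+        # from the ego, then one controller.ai.walker attached to each spawned pedestrian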
+ while num_spawned < num_zombie_walkers:
+ spawn_points = []
+ _walkers = []
+ _controllers = []
+
+ for i in range(num_zombie_walkers - num_spawned):
+ is_proximity_to_ev = True
+ spawn_loc = None
+ while is_proximity_to_ev:
+ spawn_loc = self._world.get_random_location_from_navigation()
+ if spawn_loc is not None:
+ is_proximity_to_ev = proximity_to_ev(spawn_loc)
+ spawn_points.append(carla.Transform(location=spawn_loc))
+
+ batch = []
+ for spawn_point in spawn_points:
+ walker_bp = np.random.choice(walker_bp_library)
+ if walker_bp.has_attribute('is_invincible'):
+ walker_bp.set_attribute('is_invincible', 'false')
+ batch.append(SpawnActor(walker_bp, spawn_point))
+
+ for result in self._client.apply_batch_sync(batch, tick):
+ if not result.error:
+ num_spawned += 1
+ _walkers.append(result.actor_id)
+
+ batch = [SpawnActor(walker_controller_bp, carla.Transform(), walker) for walker in _walkers]
+ for result in self._client.apply_batch_sync(batch, tick):
+ if result.error:
+ self._logger.error(result.error)
+ else:
+ _controllers.append(result.actor_id)
+
+ controller_ids.extend(_controllers)
+ walker_ids.extend(_walkers)
+
+ n_trial += 1
+ if n_trial == max_trial and (num_zombie_walkers - num_spawned)>0:
+ self._logger.warning(f'{self._world.get_map().name}: '
+ f'Spawning zombie walkers max trial {n_trial} reached! '
+ f'spawned/to_spawn: {num_spawned}/{num_zombie_walkers}')
+ break
+
+ # wait for a tick to ensure client receives the last transform of the walkers we have just created
+ # self._world.tick()
+
+ for w_id, c_id in zip(walker_ids, controller_ids):
+ self.zombie_walkers[w_id] = ZombieWalker(w_id, c_id, self._world)
+
+ return self.zombie_walkers
+
+ def tick(self):
+ pass
+
+ def clean(self):
+ live_walkers_list = [walker.id for walker in self._world.get_actors().filter("*walker.pedestrian*")]
+
+ for zw_id, zw in self.zombie_walkers.items():
+ if zw_id in live_walkers_list:
+ zw.clean()
+
+ self.zombie_walkers = {}
diff --git a/carla_gym/envs/__init__.py b/carla_gym/envs/__init__.py
new file mode 100644
index 0000000..8efccdd
--- /dev/null
+++ b/carla_gym/envs/__init__.py
@@ -0,0 +1,9 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+from carla_gym.envs.suites.endless_env import EndlessEnv
+from carla_gym.envs.suites.leaderboard_env import LeaderboardEnv
+
+__all__ = [
+ 'EndlessEnv',
+ 'LeaderboardEnv',
+]
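+
+# Minimal usage sketch (illustrative values only; the actual obs/reward/terminal configs
+# come from the project's config files):
+#   env = EndlessEnv(carla_map='Town01', host='localhost', port=2000, seed=0,
+#                    no_rendering=True, obs_configs=..., reward_configs=...,
+#                    terminal_configs=..., num_zombie_vehicles=[0, 150],
+#                    num_zombie_walkers=[0, 300], weather_group='train')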
diff --git a/carla_gym/envs/scenario_descriptions/LeaderBoard/Town01/actors.json b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town01/actors.json
new file mode 100644
index 0000000..78e5152
--- /dev/null
+++ b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town01/actors.json
@@ -0,0 +1,7 @@
+{
+ "ego_vehicles":{
+ "hero": {
+ "model": "vehicle.lincoln.mkz_2017"
+ }
+ }
+}
diff --git a/carla_gym/envs/scenario_descriptions/LeaderBoard/Town01/routes.xml b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town01/routes.xml
new file mode 100644
index 0000000..7188845
--- /dev/null
+++ b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town01/routes.xml
@@ -0,0 +1,195 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/carla_gym/envs/scenario_descriptions/LeaderBoard/Town02/actors.json b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town02/actors.json
new file mode 100644
index 0000000..78e5152
--- /dev/null
+++ b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town02/actors.json
@@ -0,0 +1,7 @@
+{
+ "ego_vehicles":{
+ "hero": {
+ "model": "vehicle.lincoln.mkz_2017"
+ }
+ }
+}
diff --git a/carla_gym/envs/scenario_descriptions/LeaderBoard/Town02/routes.xml b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town02/routes.xml
new file mode 100644
index 0000000..f741b84
--- /dev/null
+++ b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town02/routes.xml
@@ -0,0 +1,214 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/carla_gym/envs/scenario_descriptions/LeaderBoard/Town03/actors.json b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town03/actors.json
new file mode 100644
index 0000000..78e5152
--- /dev/null
+++ b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town03/actors.json
@@ -0,0 +1,7 @@
+{
+ "ego_vehicles":{
+ "hero": {
+ "model": "vehicle.lincoln.mkz_2017"
+ }
+ }
+}
diff --git a/carla_gym/envs/scenario_descriptions/LeaderBoard/Town03/routes.xml b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town03/routes.xml
new file mode 100644
index 0000000..65f9da3
--- /dev/null
+++ b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town03/routes.xml
@@ -0,0 +1,774 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/carla_gym/envs/scenario_descriptions/LeaderBoard/Town04/actors.json b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town04/actors.json
new file mode 100644
index 0000000..78e5152
--- /dev/null
+++ b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town04/actors.json
@@ -0,0 +1,7 @@
+{
+ "ego_vehicles":{
+ "hero": {
+ "model": "vehicle.lincoln.mkz_2017"
+ }
+ }
+}
diff --git a/carla_gym/envs/scenario_descriptions/LeaderBoard/Town04/routes.xml b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town04/routes.xml
new file mode 100644
index 0000000..0db8b4a
--- /dev/null
+++ b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town04/routes.xml
@@ -0,0 +1,1042 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/carla_gym/envs/scenario_descriptions/LeaderBoard/Town04_test/actors.json b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town04_test/actors.json
new file mode 100644
index 0000000..78e5152
--- /dev/null
+++ b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town04_test/actors.json
@@ -0,0 +1,7 @@
+{
+ "ego_vehicles":{
+ "hero": {
+ "model": "vehicle.lincoln.mkz_2017"
+ }
+ }
+}
diff --git a/carla_gym/envs/scenario_descriptions/LeaderBoard/Town04_test/routes.xml b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town04_test/routes.xml
new file mode 100644
index 0000000..18ca664
--- /dev/null
+++ b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town04_test/routes.xml
@@ -0,0 +1,534 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/carla_gym/envs/scenario_descriptions/LeaderBoard/Town04_train/actors.json b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town04_train/actors.json
new file mode 100644
index 0000000..78e5152
--- /dev/null
+++ b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town04_train/actors.json
@@ -0,0 +1,7 @@
+{
+ "ego_vehicles":{
+ "hero": {
+ "model": "vehicle.lincoln.mkz_2017"
+ }
+ }
+}
diff --git a/carla_gym/envs/scenario_descriptions/LeaderBoard/Town04_train/routes.xml b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town04_train/routes.xml
new file mode 100644
index 0000000..38874cc
--- /dev/null
+++ b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town04_train/routes.xml
@@ -0,0 +1,511 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/carla_gym/envs/scenario_descriptions/LeaderBoard/Town05/actors.json b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town05/actors.json
new file mode 100644
index 0000000..78e5152
--- /dev/null
+++ b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town05/actors.json
@@ -0,0 +1,7 @@
+{
+ "ego_vehicles":{
+ "hero": {
+ "model": "vehicle.lincoln.mkz_2017"
+ }
+ }
+}
diff --git a/carla_gym/envs/scenario_descriptions/LeaderBoard/Town05/routes.xml b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town05/routes.xml
new file mode 100644
index 0000000..a5f4904
--- /dev/null
+++ b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town05/routes.xml
@@ -0,0 +1,373 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/carla_gym/envs/scenario_descriptions/LeaderBoard/Town06/actors.json b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town06/actors.json
new file mode 100644
index 0000000..78e5152
--- /dev/null
+++ b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town06/actors.json
@@ -0,0 +1,7 @@
+{
+ "ego_vehicles":{
+ "hero": {
+ "model": "vehicle.lincoln.mkz_2017"
+ }
+ }
+}
diff --git a/carla_gym/envs/scenario_descriptions/LeaderBoard/Town06/routes.xml b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town06/routes.xml
new file mode 100644
index 0000000..70aa152
--- /dev/null
+++ b/carla_gym/envs/scenario_descriptions/LeaderBoard/Town06/routes.xml
@@ -0,0 +1,391 @@
+<!-- Town06 LeaderBoard route definitions (waypoint coordinates) omitted -->
\ No newline at end of file
diff --git a/carla_gym/envs/suites/endless_env.py b/carla_gym/envs/suites/endless_env.py
new file mode 100644
index 0000000..65aba79
--- /dev/null
+++ b/carla_gym/envs/suites/endless_env.py
@@ -0,0 +1,58 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+from carla_gym.carla_multi_agent_env import CarlaMultiAgentEnv
+
+
+class EndlessEnv(CarlaMultiAgentEnv):
+ def __init__(self, carla_map, host, port, seed, no_rendering, obs_configs, reward_configs, terminal_configs,
+ num_zombie_vehicles, num_zombie_walkers, weather_group):
+ all_tasks = self.build_all_tasks(num_zombie_vehicles, num_zombie_walkers, weather_group)
+ super().__init__(carla_map, host, port, seed, no_rendering,
+ obs_configs, reward_configs, terminal_configs, all_tasks)
+
+ @staticmethod
+ def build_all_tasks(num_zombie_vehicles, num_zombie_walkers, weather_group):
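+        # One endless task (empty route, 'endless' flag set) is created per weather preset.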
+ if weather_group == 'new':
+ weathers = ['SoftRainSunset', 'WetSunset']
+ elif weather_group == 'train':
+ weathers = ['ClearNoon', 'WetNoon', 'HardRainNoon', 'ClearSunset']
+ elif weather_group == 'all':
+ weathers = ['Default', 'ClearNoon', 'CloudyNoon', 'WetNoon', 'WetCloudyNoon', 'SoftRainNoon',
+ 'MidRainyNoon', 'HardRainNoon', 'ClearSunset', 'CloudySunset', 'WetSunset', 'WetCloudySunset',
+ 'SoftRainSunset', 'MidRainSunset', 'HardRainSunset']
+ else:
+ weathers = [weather_group]
+
+ actor_configs_dict = {
+ 'ego_vehicles': {
+ 'hero': {'model': 'vehicle.lincoln.mkz_2017'}
+ }
+ }
+ route_descriptions_dict = {
+ 'ego_vehicles': {
+ 'hero': []
+ }
+ }
+ endless_dict = {
+ 'ego_vehicles': {
+ 'hero': True
+ }
+ }
+ all_tasks = []
+ for weather in weathers:
+ task = {
+ 'weather': weather,
+ 'description_folder': 'None',
+ 'route_id': 0,
+ 'num_zombie_vehicles': num_zombie_vehicles,
+ 'num_zombie_walkers': num_zombie_walkers,
+ 'ego_vehicles': {
+ 'routes': route_descriptions_dict['ego_vehicles'],
+ 'actors': actor_configs_dict['ego_vehicles'],
+ 'endless': endless_dict['ego_vehicles']
+ },
+ 'scenario_actors': {},
+ }
+ all_tasks.append(task)
+
+ return all_tasks
diff --git a/carla_gym/envs/suites/leaderboard_env.py b/carla_gym/envs/suites/leaderboard_env.py
new file mode 100644
index 0000000..c06ad5d
--- /dev/null
+++ b/carla_gym/envs/suites/leaderboard_env.py
@@ -0,0 +1,85 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+from carla_gym import CARLA_GYM_ROOT_DIR
+from carla_gym.carla_multi_agent_env import CarlaMultiAgentEnv
+from carla_gym.utils import config_utils
+import json
+
+
+class LeaderboardEnv(CarlaMultiAgentEnv):
+ def __init__(self, carla_map, host, port, seed, no_rendering, obs_configs, reward_configs, terminal_configs,
+ weather_group, routes_group):
+
+ all_tasks = self.build_all_tasks(carla_map, weather_group, routes_group)
+ super().__init__(carla_map, host, port, seed, no_rendering,
+ obs_configs, reward_configs, terminal_configs, all_tasks)
+
+ @staticmethod
+ def build_all_tasks(carla_map, weather_group, routes_group):
+ assert carla_map in ['Town01', 'Town02', 'Town03', 'Town04', 'Town05', 'Town06']
+ num_zombie_vehicles = {
+ 'Town01': 120,
+ 'Town02': 70,
+ 'Town03': 70,
+ 'Town04': 150,
+ 'Town05': 120,
+ 'Town06': 120
+ }
+ num_zombie_walkers = {
+ 'Town01': 120,
+ 'Town02': 70,
+ 'Town03': 70,
+ 'Town04': 80,
+ 'Town05': 120,
+ 'Town06': 80
+ }
+
+ # weather
+ if weather_group == 'new':
+ weathers = ['SoftRainSunset', 'WetSunset', 'CloudyNoon', 'MidRainSunset']
+ elif weather_group == 'many_weathers':
+ weathers = ['SoftRainSunset', 'WetSunset', 'ClearNoon', 'WetNoon', 'HardRainNoon', 'ClearSunset']
+ elif weather_group == 'train':
+ weathers = ['ClearNoon', 'WetNoon', 'HardRainNoon', 'ClearSunset']
+ elif weather_group == 'simple':
+ weathers = ['ClearNoon']
+ elif weather_group == 'train_eval':
+ weathers = ['WetNoon', 'ClearSunset']
+ elif weather_group == 'all':
+ weathers = ['ClearNoon', 'CloudyNoon', 'WetNoon', 'WetCloudyNoon', 'SoftRainNoon', 'MidRainyNoon',
+ 'HardRainNoon', 'ClearSunset', 'CloudySunset', 'WetSunset', 'WetCloudySunset',
+ 'SoftRainSunset', 'MidRainSunset', 'HardRainSunset']
+ else:
+ weathers = [weather_group]
+
+ # task_type setup
+ if carla_map == 'Town04' and routes_group is not None:
+ description_folder = CARLA_GYM_ROOT_DIR / 'envs/scenario_descriptions/LeaderBoard' \
+ / f'Town04_{routes_group}'
+ else:
+ description_folder = CARLA_GYM_ROOT_DIR / 'envs/scenario_descriptions/LeaderBoard' / carla_map
+
+ actor_configs_dict = json.load(open(description_folder / 'actors.json'))
+ route_descriptions_dict = config_utils.parse_routes_file(description_folder / 'routes.xml')
+
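+        # One task is created per (weather, route) combination.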
+ all_tasks = []
+ for weather in weathers:
+ for route_id, route_description in route_descriptions_dict.items():
+ task = {
+ 'weather': weather,
+ 'description_folder': description_folder,
+ 'route_id': route_id,
+ 'num_zombie_vehicles': num_zombie_vehicles[carla_map],
+ 'num_zombie_walkers': num_zombie_walkers[carla_map],
+ 'ego_vehicles': {
+ 'routes': route_description['ego_vehicles'],
+ 'actors': actor_configs_dict['ego_vehicles'],
+ },
+ 'scenario_actors': {
+ 'routes': route_description['scenario_actors'],
+ 'actors': actor_configs_dict['scenario_actors']
+ } if 'scenario_actors' in actor_configs_dict else {}
+ }
+ all_tasks.append(task)
+
+ return all_tasks
diff --git a/carla_gym/utils/birdview_map.py b/carla_gym/utils/birdview_map.py
new file mode 100644
index 0000000..02bf2b9
--- /dev/null
+++ b/carla_gym/utils/birdview_map.py
@@ -0,0 +1,266 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import carla
+import pygame
+import numpy as np
+import h5py
+from pathlib import Path
+import os
+import argparse
+import time
+import subprocess
+from omegaconf import OmegaConf
+
+from carla_gym.utils.traffic_light import TrafficLightHandler
+# from utils.server_utils import CarlaServerManager
+
+COLOR_WHITE = (255, 255, 255)
+
+
+class MapImage(object):
+
+ @staticmethod
+ def draw_map_image(carla_map, pixels_per_meter, precision=0.05):
+
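+        # Rasterise the CARLA map into per-class binary masks (road, shoulder, parking,
+        # sidewalk, lane markings, stop lines) at the given pixels_per_meter resolution.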
+ waypoints = carla_map.generate_waypoints(2)
+ margin = 100
+ max_x = max(waypoints, key=lambda x: x.transform.location.x).transform.location.x + margin
+ max_y = max(waypoints, key=lambda x: x.transform.location.y).transform.location.y + margin
+ min_x = min(waypoints, key=lambda x: x.transform.location.x).transform.location.x - margin
+ min_y = min(waypoints, key=lambda x: x.transform.location.y).transform.location.y - margin
+
+ world_offset = np.array([min_x, min_y], dtype=np.float32)
+ width_in_meters = max(max_x - min_x, max_y - min_y)
+ width_in_pixels = round(pixels_per_meter * width_in_meters)
+
+ road_surface = pygame.Surface((width_in_pixels, width_in_pixels))
+ shoulder_surface = pygame.Surface((width_in_pixels, width_in_pixels))
+ parking_surface = pygame.Surface((width_in_pixels, width_in_pixels))
+ sidewalk_surface = pygame.Surface((width_in_pixels, width_in_pixels))
+ lane_marking_yellow_broken_surface = pygame.Surface((width_in_pixels, width_in_pixels))
+ lane_marking_yellow_solid_surface = pygame.Surface((width_in_pixels, width_in_pixels))
+ lane_marking_white_broken_surface = pygame.Surface((width_in_pixels, width_in_pixels))
+ lane_marking_white_solid_surface = pygame.Surface((width_in_pixels, width_in_pixels))
+ lane_marking_all_surface = pygame.Surface((width_in_pixels, width_in_pixels))
+
+ topology = [x[0] for x in carla_map.get_topology()]
+ topology = sorted(topology, key=lambda w: w.transform.location.z)
+
+ for waypoint in topology:
+ waypoints = [waypoint]
+ # Generate waypoints of a road id. Stop when road id differs
+ nxt = waypoint.next(precision)
+ if len(nxt) > 0:
+ nxt = nxt[0]
+ while nxt.road_id == waypoint.road_id:
+ waypoints.append(nxt)
+ nxt = nxt.next(precision)
+ if len(nxt) > 0:
+ nxt = nxt[0]
+ else:
+ break
+ # Draw Shoulders, Parkings and Sidewalks
+ shoulder = [[], []]
+ parking = [[], []]
+ sidewalk = [[], []]
+
+ for w in waypoints:
+ # Classify lane types until there are no waypoints by going left
+ l = w.get_left_lane()
+ while l and l.lane_type != carla.LaneType.Driving:
+ if l.lane_type == carla.LaneType.Shoulder:
+ shoulder[0].append(l)
+ if l.lane_type == carla.LaneType.Parking:
+ parking[0].append(l)
+ if l.lane_type == carla.LaneType.Sidewalk:
+ sidewalk[0].append(l)
+ l = l.get_left_lane()
+ # Classify lane types until there are no waypoints by going right
+ r = w.get_right_lane()
+ while r and r.lane_type != carla.LaneType.Driving:
+ if r.lane_type == carla.LaneType.Shoulder:
+ shoulder[1].append(r)
+ if r.lane_type == carla.LaneType.Parking:
+ parking[1].append(r)
+ if r.lane_type == carla.LaneType.Sidewalk:
+ sidewalk[1].append(r)
+ r = r.get_right_lane()
+
+ MapImage.draw_lane(road_surface, waypoints, COLOR_WHITE, pixels_per_meter, world_offset)
+
+ MapImage.draw_lane(sidewalk_surface, sidewalk[0], COLOR_WHITE, pixels_per_meter, world_offset)
+ MapImage.draw_lane(sidewalk_surface, sidewalk[1], COLOR_WHITE, pixels_per_meter, world_offset)
+ MapImage.draw_lane(shoulder_surface, shoulder[0], COLOR_WHITE, pixels_per_meter, world_offset)
+ MapImage.draw_lane(shoulder_surface, shoulder[1], COLOR_WHITE, pixels_per_meter, world_offset)
+ MapImage.draw_lane(parking_surface, parking[0], COLOR_WHITE, pixels_per_meter, world_offset)
+ MapImage.draw_lane(parking_surface, parking[1], COLOR_WHITE, pixels_per_meter, world_offset)
+
+ if not waypoint.is_junction:
+ MapImage.draw_lane_marking_single_side(
+ lane_marking_yellow_broken_surface,
+ lane_marking_yellow_solid_surface,
+ lane_marking_white_broken_surface,
+ lane_marking_white_solid_surface,
+ lane_marking_all_surface,
+ waypoints, -1, pixels_per_meter, world_offset)
+ MapImage.draw_lane_marking_single_side(
+ lane_marking_yellow_broken_surface,
+ lane_marking_yellow_solid_surface,
+ lane_marking_white_broken_surface,
+ lane_marking_white_solid_surface,
+ lane_marking_all_surface,
+ waypoints, 1, pixels_per_meter, world_offset)
+
+ # stoplines
+ stopline_surface = pygame.Surface((width_in_pixels, width_in_pixels))
+
+ for stopline_vertices in TrafficLightHandler.list_stopline_vtx:
+ for loc_left, loc_right in stopline_vertices:
+ stopline_points = [
+ MapImage.world_to_pixel(loc_left, pixels_per_meter, world_offset),
+ MapImage.world_to_pixel(loc_right, pixels_per_meter, world_offset)
+ ]
+ MapImage.draw_line(stopline_surface, stopline_points, 2)
+
+ # np.uint8 mask
+ def _make_mask(x):
+ return pygame.surfarray.array3d(x)[..., 0].astype(np.uint8)
+ # make a dict
+ dict_masks = {
+ 'road': _make_mask(road_surface),
+ 'shoulder': _make_mask(shoulder_surface),
+ 'parking': _make_mask(parking_surface),
+ 'sidewalk': _make_mask(sidewalk_surface),
+ 'lane_marking_yellow_broken': _make_mask(lane_marking_yellow_broken_surface),
+ 'lane_marking_yellow_solid': _make_mask(lane_marking_yellow_solid_surface),
+ 'lane_marking_white_broken': _make_mask(lane_marking_white_broken_surface),
+ 'lane_marking_white_solid': _make_mask(lane_marking_white_solid_surface),
+ 'lane_marking_all': _make_mask(lane_marking_all_surface),
+ 'stopline': _make_mask(stopline_surface),
+ 'world_offset': world_offset,
+ 'pixels_per_meter': pixels_per_meter,
+ 'width_in_meters': width_in_meters,
+ 'width_in_pixels': width_in_pixels
+ }
+ return dict_masks
+
+ @staticmethod
+ def draw_lane_marking_single_side(lane_marking_yellow_broken_surface,
+ lane_marking_yellow_solid_surface,
+ lane_marking_white_broken_surface,
+ lane_marking_white_solid_surface,
+ lane_marking_all_surface,
+ waypoints, sign, pixels_per_meter, world_offset):
+ """Draws the lane marking given a set of waypoints and decides whether drawing the right or left side of
+ the waypoint based on the sign parameter"""
+ lane_marking = None
+
+ previous_marking_type = carla.LaneMarkingType.NONE
+ previous_marking_color = carla.LaneMarkingColor.Other
+ current_lane_marking = carla.LaneMarkingType.NONE
+
+ markings_list = []
+ temp_waypoints = []
+ for sample in waypoints:
+ lane_marking = sample.left_lane_marking if sign < 0 else sample.right_lane_marking
+
+ if lane_marking is None:
+ continue
+
+ if current_lane_marking != lane_marking.type:
+ # Get the list of lane markings to draw
+ markings = MapImage.get_lane_markings(
+ previous_marking_type, previous_marking_color, temp_waypoints, sign, pixels_per_meter, world_offset)
+ current_lane_marking = lane_marking.type
+
+ # Append each lane marking in the list
+ for marking in markings:
+ markings_list.append(marking)
+
+ temp_waypoints = temp_waypoints[-1:]
+
+ else:
+ temp_waypoints.append((sample))
+ previous_marking_type = lane_marking.type
+ previous_marking_color = lane_marking.color
+
+ # Add last marking
+ last_markings = MapImage.get_lane_markings(
+ previous_marking_type, previous_marking_color, temp_waypoints, sign, pixels_per_meter, world_offset)
+ for marking in last_markings:
+ markings_list.append(marking)
+
+ # Once the lane markings have been simplified to Solid or Broken lines, we draw them
+ for markings in markings_list:
+ if markings[1] == carla.LaneMarkingColor.White and markings[0] == carla.LaneMarkingType.Solid:
+ MapImage.draw_line(lane_marking_white_solid_surface, markings[2], 1)
+ elif markings[1] == carla.LaneMarkingColor.Yellow and markings[0] == carla.LaneMarkingType.Solid:
+ MapImage.draw_line(lane_marking_yellow_solid_surface, markings[2], 1)
+ elif markings[1] == carla.LaneMarkingColor.White and markings[0] == carla.LaneMarkingType.Broken:
+ MapImage.draw_line(lane_marking_white_broken_surface, markings[2], 1)
+ elif markings[1] == carla.LaneMarkingColor.Yellow and markings[0] == carla.LaneMarkingType.Broken:
+ MapImage.draw_line(lane_marking_yellow_broken_surface, markings[2], 1)
+
+ MapImage.draw_line(lane_marking_all_surface, markings[2], 1)
+
+ @staticmethod
+ def get_lane_markings(lane_marking_type, lane_marking_color, waypoints, sign, pixels_per_meter, world_offset):
+ """For multiple lane marking types (SolidSolid, BrokenSolid, SolidBroken and BrokenBroken), it converts them
+ as a combination of Broken and Solid lines"""
+ margin = 0.25
+ marking_1 = [MapImage.world_to_pixel(
+ MapImage.lateral_shift(w.transform, sign * w.lane_width * 0.5),
+ pixels_per_meter, world_offset) for w in waypoints]
+
+ if lane_marking_type == carla.LaneMarkingType.Broken or (lane_marking_type == carla.LaneMarkingType.Solid):
+ return [(lane_marking_type, lane_marking_color, marking_1)]
+ else:
+ marking_2 = [
+ MapImage.world_to_pixel(
+ MapImage.lateral_shift(w.transform, sign * (w.lane_width * 0.5 + margin * 2)),
+ pixels_per_meter, world_offset) for w in waypoints]
+ if lane_marking_type == carla.LaneMarkingType.SolidBroken:
+ return [(carla.LaneMarkingType.Broken, lane_marking_color, marking_1),
+ (carla.LaneMarkingType.Solid, lane_marking_color, marking_2)]
+ elif lane_marking_type == carla.LaneMarkingType.BrokenSolid:
+ return [(carla.LaneMarkingType.Solid, lane_marking_color, marking_1),
+ (carla.LaneMarkingType.Broken, lane_marking_color, marking_2)]
+ elif lane_marking_type == carla.LaneMarkingType.BrokenBroken:
+ return [(carla.LaneMarkingType.Broken, lane_marking_color, marking_1),
+ (carla.LaneMarkingType.Broken, lane_marking_color, marking_2)]
+ elif lane_marking_type == carla.LaneMarkingType.SolidSolid:
+ return [(carla.LaneMarkingType.Solid, lane_marking_color, marking_1),
+ (carla.LaneMarkingType.Solid, lane_marking_color, marking_2)]
+ return [(carla.LaneMarkingType.NONE, lane_marking_color, marking_1)]
+
+ @staticmethod
+ def draw_line(surface, points, width):
+ """Draws solid lines in a surface given a set of points, width and color"""
+ if len(points) >= 2:
+ pygame.draw.lines(surface, COLOR_WHITE, False, points, width)
+
+ @staticmethod
+ def draw_lane(surface, wp_list, color, pixels_per_meter, world_offset):
+ """Renders a single lane in a surface and with a specified color"""
+ lane_left_side = [MapImage.lateral_shift(w.transform, -w.lane_width * 0.5) for w in wp_list]
+ lane_right_side = [MapImage.lateral_shift(w.transform, w.lane_width * 0.5) for w in wp_list]
+
+ polygon = lane_left_side + [x for x in reversed(lane_right_side)]
+ polygon = [MapImage.world_to_pixel(x, pixels_per_meter, world_offset) for x in polygon]
+
+ if len(polygon) > 2:
+ pygame.draw.polygon(surface, color, polygon, 5)
+ pygame.draw.polygon(surface, color, polygon)
+
+ @staticmethod
+ def lateral_shift(transform, shift):
+ """Makes a lateral shift of the forward vector of a transform"""
+ transform.rotation.yaw += 90
+ return transform.location + shift * transform.get_forward_vector()
+
+ @staticmethod
+ def world_to_pixel(location, pixels_per_meter, world_offset):
+ """Converts the world coordinates to pixel coordinates"""
+ x = pixels_per_meter * (location.x - world_offset[0])
+ y = pixels_per_meter * (location.y - world_offset[1])
+ return [round(y), round(x)]
diff --git a/carla_gym/utils/config_utils.py b/carla_gym/utils/config_utils.py
new file mode 100644
index 0000000..f37dc24
--- /dev/null
+++ b/carla_gym/utils/config_utils.py
@@ -0,0 +1,150 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+from importlib import import_module
+import json
+from pathlib import Path
+import socket
+import xml.etree.ElementTree as ET
+import h5py
+import carla
+import numpy as np
+import hydra
+
+
+def check_h5_maps(env_configs, obs_configs, carla_sh_path):
+ pixels_per_meter = None
+ for agent_id, obs_cfg in obs_configs.items():
+ for k,v in obs_cfg.items():
+ if 'birdview' in v['module']:
+ pixels_per_meter = float(v['pixels_per_meter'])
+
+ if pixels_per_meter is None:
+ # agent does not require birdview map as observation
+ return
+
+ save_dir = Path(hydra.utils.get_original_cwd()) / 'carla_gym/core/obs_manager/birdview/maps'
+    txt_command = f'Please run the map generation script from the project root directory. \n' \
+ f'\033[93m' \
+ f'python -m carla_gym.utils.birdview_map ' \
+ f'--save_dir {save_dir} --pixels_per_meter {pixels_per_meter:.2f} ' \
+ f'--carla_sh_path {carla_sh_path}' \
+ f'\033[0m'
+
+ # check if pixels_per_meter match
+ for env_cfg in env_configs:
+ carla_map = env_cfg['env_configs']['carla_map']
+ hf_file_path = save_dir / (carla_map+'.h5')
+
+ file_exists = hf_file_path.exists()
+ if file_exists:
+ map_hf = h5py.File(hf_file_path, 'r')
+ hf_pixels_per_meter = float(map_hf.attrs['pixels_per_meter'])
+ map_hf.close()
+ pixels_per_meter_match = np.isclose(hf_pixels_per_meter, pixels_per_meter)
+            txt_assert = f'pixels_per_meter mismatch between h5 file ({hf_pixels_per_meter}) '\
+ f'and obs_config ({pixels_per_meter}). '
+ else:
+            txt_assert = f'{hf_file_path} does not exist. '
+ pixels_per_meter_match = False
+
+ assert file_exists and pixels_per_meter_match, txt_assert + txt_command
+
+
+def load_entry_point(name):
+ mod_name, attr_name = name.split(":")
+ mod = import_module(mod_name)
+ fn = getattr(mod, attr_name)
+ return fn
+
+
+def load_obs_configs(agent_configs_dict):
+ obs_configs = {}
+ for actor_id, cfg in agent_configs_dict.items():
+ obs_configs[actor_id] = json.load(open(cfg['path_to_conf_file'], 'r'))['obs_configs']
+ return obs_configs
+
+
+def init_agents(agent_configs_dict, **kwargs):
+ agents_dict = {}
+ for actor_id, cfg in agent_configs_dict.items():
+ AgentClass = load_entry_point(cfg['entry_point'])
+ agents_dict[actor_id] = AgentClass(cfg['path_to_conf_file'], **kwargs)
+ return agents_dict
+
+
+def parse_routes_file(routes_xml_filename):
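+    # Expected XML layout, one <route id="..."> element per route:
+    #   <route id="0">
+    #     <ego_vehicle id="hero">
+    #       <waypoint x="..." y="..." z="..." pitch="..." roll="..." yaw="..."/>
+    #     </ego_vehicle>
+    #     <scenario_actor id="..."> ... </scenario_actor>
+    #   </route>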
+ route_descriptions_dict = {}
+ tree = ET.parse(routes_xml_filename)
+
+ for route in tree.iter("route"):
+
+ route_id = int(route.attrib['id'])
+
+ route_descriptions_dict[route_id] = {}
+
+ for actor_type in ['ego_vehicle', 'scenario_actor']:
+ route_descriptions_dict[route_id][actor_type+'s'] = {}
+ for actor in route.iter(actor_type):
+ actor_id = actor.attrib['id']
+
+ waypoint_list = [] # the list of waypoints that can be found on this route for this actor
+ for waypoint in actor.iter('waypoint'):
+ location = carla.Location(
+ x=float(waypoint.attrib['x']),
+ y=float(waypoint.attrib['y']),
+ z=float(waypoint.attrib['z']))
+ rotation = carla.Rotation(
+ roll=float(waypoint.attrib['roll']),
+ pitch=float(waypoint.attrib['pitch']),
+ yaw=float(waypoint.attrib['yaw']))
+ waypoint_list.append(carla.Transform(location, rotation))
+
+ route_descriptions_dict[route_id][actor_type+'s'][actor_id] = waypoint_list
+
+ return route_descriptions_dict
+
+
+def get_single_route(routes_xml_filename, route_id):
+ tree = ET.parse(routes_xml_filename)
+ route = tree.find(f'.//route[@id="{route_id}"]')
+
+ route_dict = {}
+ for actor_type in ['ego_vehicle', 'scenario_actor']:
+ route_dict[actor_type+'s'] = {}
+ for actor in route.iter(actor_type):
+ actor_id = actor.attrib['id']
+
+ waypoint_list = [] # the list of waypoints that can be found on this route for this actor
+ for waypoint in actor.iter('waypoint'):
+ location = carla.Location(
+ x=float(waypoint.attrib['x']),
+ y=float(waypoint.attrib['y']),
+ z=float(waypoint.attrib['z']))
+ rotation = carla.Rotation(
+ roll=float(waypoint.attrib['roll']),
+ pitch=float(waypoint.attrib['pitch']),
+ yaw=float(waypoint.attrib['yaw']))
+ waypoint_list.append(carla.Transform(location, rotation))
+
+ route_dict[actor_type+'s'][actor_id] = waypoint_list
+ return route_dict
+
+
+def to_camel_case(snake_str, init_capital=False):
+ # agent_class_str = to_camel_case(agent_module_str.split('.')[-1], init_capital=True)
+ components = snake_str.split('_')
+ if init_capital:
+ init_letter = components[0].title()
+ else:
+ init_letter = components[0]
+ return init_letter + ''.join(x.title() for x in components[1:])
+
+
+def get_free_tcp_port():
+ s = socket.socket()
+ s.bind(("", 0)) # Request the sys to provide a free port dynamically
+ server_port = s.getsockname()[1]
+ s.close()
+    # NOTE: the dynamically found port is discarded; the default CARLA port 2000 is returned for now
+ server_port = 2000
+ return server_port
diff --git a/carla_gym/utils/dynamic_weather.py b/carla_gym/utils/dynamic_weather.py
new file mode 100644
index 0000000..9103d70
--- /dev/null
+++ b/carla_gym/utils/dynamic_weather.py
@@ -0,0 +1,127 @@
+# modified from https://github.com/carla-simulator/carla/blob/master/PythonAPI/examples/dynamic_weather.py
+
+import carla
+import numpy as np
+from constants import CARLA_FPS
+
+WEATHERS = [
+ carla.WeatherParameters.Default,
+
+ carla.WeatherParameters.ClearNoon,
+ carla.WeatherParameters.ClearSunset,
+
+ carla.WeatherParameters.CloudyNoon,
+ carla.WeatherParameters.CloudySunset,
+
+ carla.WeatherParameters.WetNoon,
+ carla.WeatherParameters.WetSunset,
+
+ carla.WeatherParameters.MidRainyNoon,
+ carla.WeatherParameters.MidRainSunset,
+
+ carla.WeatherParameters.WetCloudyNoon,
+ carla.WeatherParameters.WetCloudySunset,
+
+ carla.WeatherParameters.HardRainNoon,
+ carla.WeatherParameters.HardRainSunset,
+
+ carla.WeatherParameters.SoftRainNoon,
+ carla.WeatherParameters.SoftRainSunset,
+]
+
+
+def clamp(value, minimum=0.0, maximum=100.0):
+ return max(minimum, min(value, maximum))
+
+
+class Sun(object):
+ def __init__(self, azimuth, altitude):
+ self.azimuth = azimuth
+ self.altitude = altitude
+ self._t = np.random.uniform(0.0, 2.0*np.pi)
+
+ def tick(self, delta_seconds):
+ self._t += 0.008 * delta_seconds
+ self._t %= 2.0 * np.pi
+ self.azimuth += 0.25 * delta_seconds
+ self.azimuth %= 360.0
+ self.altitude = (55 * np.sin(self._t)) + 35
+
+ def __str__(self):
+ return 'Sun(alt: %.2f, azm: %.2f)' % (self.altitude, self.azimuth)
+
+
+class Storm(object):
+ def __init__(self, precipitation):
+ self._t = precipitation if precipitation > 0.0 else -50.0
+ self._increasing = True
+ self.clouds = 0.0
+ self.rain = 0.0
+ self.wetness = 0.0
+ self.puddles = 0.0
+ self.wind = 0.0
+ self.fog = 0.0
+
+ def tick(self, delta_seconds):
+ delta = (1.3 if self._increasing else -1.3) * delta_seconds
+ self._t = clamp(delta + self._t, -250.0, 100.0)
+ self.clouds = clamp(self._t + 40.0, 0.0, 90.0)
+ self.rain = clamp(self._t, 0.0, 80.0)
+ delay = -10.0 if self._increasing else 90.0
+ self.puddles = clamp(self._t + delay, 0.0, 85.0)
+ self.wetness = clamp(self._t * 5, 0.0, 100.0)
+ self.wind = 5.0 if self.clouds <= 20 else 90 if self.clouds >= 70 else 40
+ self.fog = clamp(self._t - 10, 0.0, 30.0)
+ if self._t == -250.0:
+ self._increasing = True
+ if self._t == 100.0:
+ self._increasing = False
+
+ def __str__(self):
+ return 'Storm(clouds=%d%%, rain=%d%%, wind=%d%%)' % (self.clouds, self.rain, self.wind)
+
+
+class WeatherHandler(object):
+ def __init__(self, world):
+ self._world = world
+ self._dynamic = False
+
+ def reset(self, cfg_weather):
+ if hasattr(carla.WeatherParameters, cfg_weather):
+ self._world.set_weather(getattr(carla.WeatherParameters, cfg_weather))
+ self._dynamic = False
+ elif 'dynamic' in cfg_weather:
+ self._weather = np.random.choice(WEATHERS)
+ self._sun = Sun(self._weather.sun_azimuth_angle, self._weather.sun_altitude_angle)
+ self._storm = Storm(self._weather.precipitation)
+ self._dynamic = True
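+            # 'dynamic' runs at speed factor 1.0; 'dynamic_<x>' (e.g. 'dynamic_2.0') scales the weather evolution by <x>.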
+ l = cfg_weather.split('_')
+ if len(l) == 2:
+ self._speed_factor = float(l[1])
+ else:
+ self._speed_factor = 1.0
+ self.tick(1.0 / CARLA_FPS)
+ else:
+            self._world.set_weather(carla.WeatherParameters.ClearNoon)
+ self._dynamic = False
+
+ def tick(self, delta_seconds):
+ if self._dynamic:
+ self._sun.tick(delta_seconds * self._speed_factor)
+ self._storm.tick(delta_seconds * self._speed_factor)
+ self._weather.cloudiness = self._storm.clouds
+ self._weather.precipitation = self._storm.rain
+ self._weather.precipitation_deposits = self._storm.puddles
+ self._weather.wind_intensity = self._storm.wind
+ self._weather.fog_density = self._storm.fog
+ self._weather.wetness = self._storm.wetness
+ self._weather.sun_azimuth_angle = self._sun.azimuth
+ self._weather.sun_altitude_angle = self._sun.altitude
+ self._world.set_weather(self._weather)
+
+ def clean(self):
+ if self._dynamic:
+ self._weather = None
+ self._sun = None
+ self._storm = None
+ self._dynamic = False
diff --git a/carla_gym/utils/gps_utils.py b/carla_gym/utils/gps_utils.py
new file mode 100644
index 0000000..e1bb95e
--- /dev/null
+++ b/carla_gym/utils/gps_utils.py
@@ -0,0 +1,30 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import math
+
+EARTH_RADIUS_EQUA = 6378137.0
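+# Mercator projection between local (x, y) map coordinates and GPS (lat, lon);
+# the defaults (49.0, 8.0) match CARLA's standard map geo-reference.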
+
+
+def gps2xyz(lat, lon, z, lat_ref=49.0, lon_ref=8.0):
+ # pylint: disable=invalid-name
+ scale = math.cos(lat_ref * math.pi / 180.0)
+
+ mx = lon / 180.0 * (math.pi * EARTH_RADIUS_EQUA * scale)
+ my = math.log(math.tan((lat+90.0)*math.pi/360.0))*(EARTH_RADIUS_EQUA * scale)
+
+ x = mx - scale * lon_ref * math.pi * EARTH_RADIUS_EQUA / 180.0
+ y = scale * EARTH_RADIUS_EQUA * math.log(math.tan((90.0 + lat_ref) * math.pi / 360.0)) - my
+
+ return x, y, z
+
+
+def xyz2gps(x, y, z, lat_ref=49.0, lon_ref=8.0):
+ scale = math.cos(lat_ref * math.pi / 180.0)
+ mx = scale * lon_ref * math.pi * EARTH_RADIUS_EQUA / 180.0
+ my = scale * EARTH_RADIUS_EQUA * math.log(math.tan((90.0 + lat_ref) * math.pi / 360.0))
+ mx += x
+ my -= y
+
+ lon = mx * 180.0 / (math.pi * EARTH_RADIUS_EQUA * scale)
+ lat = 360.0 * math.atan(math.exp(my / (EARTH_RADIUS_EQUA * scale))) / math.pi - 90.0
+ return lat, lon, z
diff --git a/carla_gym/utils/hazard_actor.py b/carla_gym/utils/hazard_actor.py
new file mode 100644
index 0000000..23f8561
--- /dev/null
+++ b/carla_gym/utils/hazard_actor.py
@@ -0,0 +1,188 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import numpy as np
+
+
+def is_within_distance_ahead(target_location, max_distance, up_angle_th=60):
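+    # target_location is expressed in the ego frame; returns True if the target is within
+    # max_distance and inside a +/- up_angle_th degree cone ahead of the ego.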
+ distance = np.linalg.norm(target_location[0:2])
+ if distance < 0.001:
+ return True
+ if distance > max_distance:
+ return False
+ x = target_location[0]
+ y = target_location[1]
+ angle = np.rad2deg(np.arctan2(y, x))
+ return abs(angle) < up_angle_th
+
+
+def lbc_hazard_vehicle(obs_surrounding_vehicles, ev_speed=None, proximity_threshold=9.5):
+ for i, is_valid in enumerate(obs_surrounding_vehicles['binary_mask']):
+ if not is_valid:
+ continue
+
+ sv_yaw = obs_surrounding_vehicles['rotation'][i][2]
+ same_heading = abs(sv_yaw) <= 150
+
+ sv_loc = obs_surrounding_vehicles['location'][i]
+ with_distance_ahead = is_within_distance_ahead(sv_loc, proximity_threshold, up_angle_th=45)
+ if same_heading and with_distance_ahead:
+ return sv_loc
+ return None
+
+
+def lbc_hazard_walker(obs_surrounding_pedestrians, ev_speed=None, proximity_threshold=9.5):
+ for i, is_valid in enumerate(obs_surrounding_pedestrians['binary_mask']):
+ if not is_valid:
+ continue
+ if int(obs_surrounding_pedestrians['on_sidewalk'][i]) == 1:
+ continue
+
+ ped_loc = obs_surrounding_pedestrians['location'][i]
+
+ dist = np.linalg.norm(ped_loc)
+ degree = 162 / (np.clip(dist, 1.5, 10.5)+0.3)
+
+ if is_within_distance_ahead(ped_loc, proximity_threshold, up_angle_th=degree):
+ return ped_loc
+ return None
+
+
+def get_collision(p1, v1, p2, v2):
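+    # Solve p1 + s * v1 == p2 + t * v2 for (s, t); a collision is predicted if both s and t lie in [0, 1].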
+ A = np.stack([v1, -v2], 1)
+ b = p2 - p1
+
+ if abs(np.linalg.det(A)) < 1e-3:
+ return False, None
+
+ x = np.linalg.solve(A, b)
+ collides = all(x >= 0) and all(x <= 1)
+
+ return collides, p1 + x[0] * v1
+
+
+def challenge_hazard_walker(obs_surrounding_pedestrians, ev_speed=None):
+ p1 = np.float32([0, 0])
+ v1 = np.float32([10, 0])
+
+ for i, is_valid in enumerate(obs_surrounding_pedestrians['binary_mask']):
+ if not is_valid:
+ continue
+
+ ped_loc = obs_surrounding_pedestrians['location'][i]
+ ped_yaw = obs_surrounding_pedestrians['rotation'][i][2]
+ ped_vel = obs_surrounding_pedestrians['absolute_velocity'][i]
+
+ v2_hat = np.float32([np.cos(np.radians(ped_yaw)), np.sin(np.radians(ped_yaw))])
+ s2 = np.linalg.norm(ped_vel)
+
+ if s2 < 0.05:
+ v2_hat *= s2
+
+ p2 = -3.0 * v2_hat + ped_loc[0:2]
+ v2 = 8.0 * v2_hat
+
+ collides, collision_point = get_collision(p1, v1, p2, v2)
+
+ if collides:
+ return ped_loc
+ return None
+
+
+def challenge_hazard_vehicle(obs_surrounding_vehicles, ev_speed):
+ # np.linalg.norm(_numpy(self._vehicle.get_velocity())
+ o1 = np.float32([1, 0])
+ p1 = np.float32([0, 0])
+ s1 = max(9.5, 2.0 * ev_speed)
+ v1_hat = o1
+ v1 = s1 * v1_hat
+
+ for i, is_valid in enumerate(obs_surrounding_vehicles['binary_mask']):
+ if not is_valid:
+ continue
+
+ sv_loc = obs_surrounding_vehicles['location'][i]
+ sv_yaw = obs_surrounding_vehicles['rotation'][i][2]
+ sv_vel = obs_surrounding_vehicles['absolute_velocity'][i]
+
+ o2 = np.float32([np.cos(np.radians(sv_yaw)), np.sin(np.radians(sv_yaw))])
+ p2 = sv_loc[0:2]
+ s2 = max(5.0, 2.0 * np.linalg.norm(sv_vel[0:2]))
+ v2_hat = o2
+ v2 = s2 * v2_hat
+
+ p2_p1 = p2 - p1
+ distance = np.linalg.norm(p2_p1)
+ p2_p1_hat = p2_p1 / (distance + 1e-4)
+
+ angle_to_car = np.degrees(np.arccos(v1_hat.dot(p2_p1_hat)))
+ angle_between_heading = np.degrees(np.arccos(o1.dot(o2)))
+
+ if angle_between_heading > 60.0 and not (angle_to_car < 15 and distance < s1):
+ continue
+ elif angle_to_car > 30.0:
+ continue
+ elif distance > s1:
+ continue
+
+ return sv_loc
+
+ return None
+
+
+def behavior_hazard_vehicle(ego_vehicle, actors, route_plan, proximity_th, up_angle_th, lane_offset=0, at_junction=False):
+ '''
+ ego_vehicle: input_data['ego_vehicle']
+ actors: input_data['surrounding_vehicles']
+ route_plan: input_data['route_plan']
+ '''
+ # Get the right offset
+ if ego_vehicle['lane_id'] < 0 and lane_offset != 0:
+ lane_offset *= -1
+
+ for i, is_valid in enumerate(actors['binary_mask']):
+ if not is_valid:
+ continue
+
+ if not at_junction and (actors['road_id'][i] != ego_vehicle['road_id'] or
+ actors['lane_id'][i] != ego_vehicle['lane_id'] + lane_offset):
+
+ next_road_id = route_plan['road_id'][5]
+ next_lane_id = route_plan['lane_id'][5]
+
+ if actors['road_id'][i] != next_road_id or actors['lane_id'][i] != next_lane_id + lane_offset:
+ continue
+
+ if is_within_distance_ahead(actors['location'][i], proximity_th, up_angle_th):
+ return i
+ return None
+
+
+def behavior_hazard_walker(ego_vehicle, actors, route_plan, proximity_th, up_angle_th, lane_offset=0, at_junction=False):
+ '''
+ ego_vehicle: input_data['ego_vehicle']
+    actors: input_data['surrounding_pedestrians']
+ route_plan: input_data['route_plan']
+ '''
+ # Get the right offset
+ if ego_vehicle['lane_id'] < 0 and lane_offset != 0:
+ lane_offset *= -1
+
+ for i, is_valid in enumerate(actors['binary_mask']):
+ if not is_valid:
+ continue
+
+ if int(actors['on_sidewalk'][i]) == 1:
+ continue
+
+ if not at_junction and (actors['road_id'][i] != ego_vehicle['road_id'] or
+ actors['lane_id'][i] != ego_vehicle['lane_id'] + lane_offset):
+
+ next_road_id = route_plan['road_id'][5]
+ next_lane_id = route_plan['lane_id'][5]
+
+ if actors['road_id'][i] != next_road_id or actors['lane_id'][i] != next_lane_id + lane_offset:
+ continue
+
+ if is_within_distance_ahead(actors['location'][i], proximity_th, up_angle_th):
+ return i
+ return None
diff --git a/carla_gym/utils/traffic_light.py b/carla_gym/utils/traffic_light.py
new file mode 100644
index 0000000..1aa391e
--- /dev/null
+++ b/carla_gym/utils/traffic_light.py
@@ -0,0 +1,201 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+from collections import deque
+import carla
+import numpy as np
+import carla_gym.utils.transforms as trans_utils
+
+
+def _get_traffic_light_waypoints(traffic_light, carla_map):
+ """
+ get area of a given traffic light
+ adapted from "carla-simulator/scenario_runner/srunner/scenariomanager/scenarioatomics/atomic_criteria.py"
+ """
+ base_transform = traffic_light.get_transform()
+ tv_loc = traffic_light.trigger_volume.location
+ tv_ext = traffic_light.trigger_volume.extent
+
+ # Discretize the trigger box into points
+ x_values = np.arange(-0.9 * tv_ext.x, 0.9 * tv_ext.x, 1.0) # 0.9 to avoid crossing to adjacent lanes
+ area = []
+ for x in x_values:
+ point_location = base_transform.transform(tv_loc + carla.Location(x=x))
+ area.append(point_location)
+
+ # Get the waypoints of these points, removing duplicates
+ ini_wps = []
+ for pt in area:
+ wpx = carla_map.get_waypoint(pt)
+ # As x_values are arranged in order, only the last one has to be checked
+ if not ini_wps or ini_wps[-1].road_id != wpx.road_id or ini_wps[-1].lane_id != wpx.lane_id:
+ ini_wps.append(wpx)
+
+ # Leaderboard: Advance them until the intersection
+ stopline_wps = []
+ stopline_vertices = []
+ junction_wps = []
+ for wpx in ini_wps:
+ # Below: just use trigger volume, otherwise it's on the zebra lines.
+ # stopline_wps.append(wpx)
+ # vec_forward = wpx.transform.get_forward_vector()
+ # vec_right = carla.Vector3D(x=-vec_forward.y, y=vec_forward.x, z=0)
+
+ # loc_left = wpx.transform.location - 0.4 * wpx.lane_width * vec_right
+ # loc_right = wpx.transform.location + 0.4 * wpx.lane_width * vec_right
+ # stopline_vertices.append([loc_left, loc_right])
+
+ while not wpx.is_intersection:
+ next_wp = wpx.next(0.5)[0]
+ if next_wp and not next_wp.is_intersection:
+ wpx = next_wp
+ else:
+ break
+ junction_wps.append(wpx)
+
+ stopline_wps.append(wpx)
+ vec_forward = wpx.transform.get_forward_vector()
+ vec_right = carla.Vector3D(x=-vec_forward.y, y=vec_forward.x, z=0)
+
+ loc_left = wpx.transform.location - 0.4 * wpx.lane_width * vec_right
+ loc_right = wpx.transform.location + 0.4 * wpx.lane_width * vec_right
+ stopline_vertices.append([loc_left, loc_right])
+
+ # all paths at junction for this traffic light
+ junction_paths = []
+ path_wps = []
+ wp_queue = deque(junction_wps)
+ while len(wp_queue) > 0:
+ current_wp = wp_queue.pop()
+ path_wps.append(current_wp)
+ next_wps = current_wp.next(1.0)
+ for next_wp in next_wps:
+ if next_wp.is_junction:
+ wp_queue.append(next_wp)
+ else:
+ junction_paths.append(path_wps)
+ path_wps = []
+
+ return carla.Location(base_transform.transform(tv_loc)), stopline_wps, stopline_vertices, junction_paths
+
+
+class TrafficLightHandler:
+ num_tl = 0
+ list_tl_actor = []
+ list_tv_loc = []
+ list_stopline_wps = []
+ list_stopline_vtx = []
+ list_junction_paths = []
+ carla_map = None
+
+ @staticmethod
+ def reset(world):
+ TrafficLightHandler.carla_map = world.get_map()
+
+ TrafficLightHandler.num_tl = 0
+ TrafficLightHandler.list_tl_actor = []
+ TrafficLightHandler.list_tv_loc = []
+ TrafficLightHandler.list_stopline_wps = []
+ TrafficLightHandler.list_stopline_vtx = []
+ TrafficLightHandler.list_junction_paths = []
+
+ all_actors = world.get_actors()
+ for _actor in all_actors:
+ if 'traffic_light' in _actor.type_id:
+ tv_loc, stopline_wps, stopline_vtx, junction_paths = _get_traffic_light_waypoints(
+ _actor, TrafficLightHandler.carla_map)
+
+ TrafficLightHandler.list_tl_actor.append(_actor)
+ TrafficLightHandler.list_tv_loc.append(tv_loc)
+ TrafficLightHandler.list_stopline_wps.append(stopline_wps)
+ TrafficLightHandler.list_stopline_vtx.append(stopline_vtx)
+ TrafficLightHandler.list_junction_paths.append(junction_paths)
+
+ TrafficLightHandler.num_tl += 1
+
+ @staticmethod
+ def get_light_state(vehicle, offset=0.0, dist_threshold=15.0):
+ '''
+ vehicle: carla.Vehicle
+ '''
+ vec_tra = vehicle.get_transform()
+ veh_dir = vec_tra.get_forward_vector()
+
+ hit_loc = vec_tra.transform(carla.Location(x=offset))
+ hit_wp = TrafficLightHandler.carla_map.get_waypoint(hit_loc)
+
+ light_loc = None
+ light_state = None
+ light_id = None
+ for i in range(TrafficLightHandler.num_tl):
+ traffic_light = TrafficLightHandler.list_tl_actor[i]
+ tv_loc = 0.5*TrafficLightHandler.list_stopline_wps[i][0].transform.location \
+ + 0.5*TrafficLightHandler.list_stopline_wps[i][-1].transform.location
+
+ distance = np.sqrt((tv_loc.x-hit_loc.x)**2 + (tv_loc.y-hit_loc.y)**2)
+ if distance > dist_threshold:
+ continue
+
+ for wp in TrafficLightHandler.list_stopline_wps[i]:
+
+ wp_dir = wp.transform.get_forward_vector()
+ dot_ve_wp = veh_dir.x * wp_dir.x + veh_dir.y * wp_dir.y + veh_dir.z * wp_dir.z
+
+ wp_1 = wp.previous(4.0)[0]
+ same_road = (hit_wp.road_id == wp.road_id) and (hit_wp.lane_id == wp.lane_id)
+ same_road_1 = (hit_wp.road_id == wp_1.road_id) and (hit_wp.lane_id == wp_1.lane_id)
+
+ # if (wp.road_id != wp_1.road_id) or (wp.lane_id != wp_1.lane_id):
+ # print(f'Traffic Light Problem: {wp.road_id}={wp_1.road_id}, {wp.lane_id}={wp_1.lane_id}')
+
+ if (same_road or same_road_1) and dot_ve_wp > 0:
+                    # This traffic light governs the ego lane; record its state and location
+ loc_in_ev = trans_utils.loc_global_to_ref(wp.transform.location, vec_tra)
+ light_loc = np.array([loc_in_ev.x, loc_in_ev.y, loc_in_ev.z], dtype=np.float32)
+ light_state = traffic_light.state
+ light_id = traffic_light.id
+ break
+
+ return light_state, light_loc, light_id
+
+ @staticmethod
+ def get_junctoin_paths(veh_loc, color=0, dist_threshold=50.0):
+ if color == 0:
+ tl_state = carla.TrafficLightState.Green
+ elif color == 1:
+ tl_state = carla.TrafficLightState.Yellow
+ elif color == 2:
+ tl_state = carla.TrafficLightState.Red
+
+ junctoin_paths = []
+ for i in range(TrafficLightHandler.num_tl):
+ traffic_light = TrafficLightHandler.list_tl_actor[i]
+ tv_loc = TrafficLightHandler.list_tv_loc[i]
+ if tv_loc.distance(veh_loc) > dist_threshold:
+ continue
+ if traffic_light.state != tl_state:
+ continue
+
+ junctoin_paths += TrafficLightHandler.list_junction_paths[i]
+
+ return junctoin_paths
+
+ @staticmethod
+ def get_stopline_vtx(veh_loc, color, dist_threshold=50.0):
+ if color == 0:
+ tl_state = carla.TrafficLightState.Green
+ elif color == 1:
+ tl_state = carla.TrafficLightState.Yellow
+ elif color == 2:
+ tl_state = carla.TrafficLightState.Red
+
+ stopline_vtx = []
+ for i in range(TrafficLightHandler.num_tl):
+ traffic_light = TrafficLightHandler.list_tl_actor[i]
+ tv_loc = TrafficLightHandler.list_tv_loc[i]
+ if tv_loc.distance(veh_loc) > dist_threshold:
+ continue
+ if traffic_light.state != tl_state:
+ continue
+ stopline_vtx += TrafficLightHandler.list_stopline_vtx[i]
+
+ return stopline_vtx
diff --git a/carla_gym/utils/transforms.py b/carla_gym/utils/transforms.py
new file mode 100644
index 0000000..a3cb575
--- /dev/null
+++ b/carla_gym/utils/transforms.py
@@ -0,0 +1,104 @@
+"""Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license."""
+
+import numpy as np
+import carla
+
+
+def loc_global_to_ref(target_loc_in_global, ref_trans_in_global):
+ """
+ :param target_loc_in_global: carla.Location in global coordinate (world, actor)
+ :param ref_trans_in_global: carla.Transform in global coordinate (world, actor)
+ :return: carla.Location in ref coordinate
+ """
+ x = target_loc_in_global.x - ref_trans_in_global.location.x
+ y = target_loc_in_global.y - ref_trans_in_global.location.y
+ z = target_loc_in_global.z - ref_trans_in_global.location.z
+ vec_in_global = carla.Vector3D(x=x, y=y, z=z)
+ vec_in_ref = vec_global_to_ref(vec_in_global, ref_trans_in_global.rotation)
+
+ target_loc_in_ref = carla.Location(x=vec_in_ref.x, y=vec_in_ref.y, z=vec_in_ref.z)
+ return target_loc_in_ref
+
+
+def vec_global_to_ref(target_vec_in_global, ref_rot_in_global):
+ """
+ :param target_vec_in_global: carla.Vector3D in global coordinate (world, actor)
+ :param ref_rot_in_global: carla.Rotation in global coordinate (world, actor)
+ :return: carla.Vector3D in ref coordinate
+ """
+ R = carla_rot_to_mat(ref_rot_in_global)
+ np_vec_in_global = np.array([[target_vec_in_global.x],
+ [target_vec_in_global.y],
+ [target_vec_in_global.z]])
+ np_vec_in_ref = R.T.dot(np_vec_in_global)
+ target_vec_in_ref = carla.Vector3D(x=np_vec_in_ref[0, 0], y=np_vec_in_ref[1, 0], z=np_vec_in_ref[2, 0])
+ return target_vec_in_ref
+
+
+def rot_global_to_ref(target_rot_in_global, ref_rot_in_global):
+ target_roll_in_ref = cast_angle(target_rot_in_global.roll - ref_rot_in_global.roll)
+ target_pitch_in_ref = cast_angle(target_rot_in_global.pitch - ref_rot_in_global.pitch)
+ target_yaw_in_ref = cast_angle(target_rot_in_global.yaw - ref_rot_in_global.yaw)
+
+ target_rot_in_ref = carla.Rotation(roll=target_roll_in_ref, pitch=target_pitch_in_ref, yaw=target_yaw_in_ref)
+ return target_rot_in_ref
+
+def rot_ref_to_global(target_rot_in_ref, ref_rot_in_global):
+ target_roll_in_global = cast_angle(target_rot_in_ref.roll + ref_rot_in_global.roll)
+ target_pitch_in_global = cast_angle(target_rot_in_ref.pitch + ref_rot_in_global.pitch)
+ target_yaw_in_global = cast_angle(target_rot_in_ref.yaw + ref_rot_in_global.yaw)
+
+ target_rot_in_global = carla.Rotation(roll=target_roll_in_global, pitch=target_pitch_in_global, yaw=target_yaw_in_global)
+ return target_rot_in_global
+
+
+def carla_rot_to_mat(carla_rotation):
+ """
+ Transform rpy in carla.Rotation to rotation matrix in np.array
+
+ :param carla_rotation: carla.Rotation
+ :return: np.array rotation matrix
+ """
+ roll = np.deg2rad(carla_rotation.roll)
+ pitch = np.deg2rad(carla_rotation.pitch)
+ yaw = np.deg2rad(carla_rotation.yaw)
+
+ yaw_matrix = np.array([
+ [np.cos(yaw), -np.sin(yaw), 0],
+ [np.sin(yaw), np.cos(yaw), 0],
+ [0, 0, 1]
+ ])
+ pitch_matrix = np.array([
+ [np.cos(pitch), 0, -np.sin(pitch)],
+ [0, 1, 0],
+ [np.sin(pitch), 0, np.cos(pitch)]
+ ])
+ roll_matrix = np.array([
+ [1, 0, 0],
+ [0, np.cos(roll), np.sin(roll)],
+ [0, -np.sin(roll), np.cos(roll)]
+ ])
+
+ rotation_matrix = yaw_matrix.dot(pitch_matrix).dot(roll_matrix)
+ return rotation_matrix
+
+def get_loc_rot_vel_in_ev(actor_list, ev_transform):
+ location, rotation, absolute_velocity = [], [], []
+ for actor in actor_list:
+ # location
+ location_in_world = actor.get_transform().location
+ location_in_ev = loc_global_to_ref(location_in_world, ev_transform)
+ location.append([location_in_ev.x, location_in_ev.y, location_in_ev.z])
+ # rotation
+ rotation_in_world = actor.get_transform().rotation
+ rotation_in_ev = rot_global_to_ref(rotation_in_world, ev_transform.rotation)
+ rotation.append([rotation_in_ev.roll, rotation_in_ev.pitch, rotation_in_ev.yaw])
+ # velocity
+ vel_in_world = actor.get_velocity()
+ vel_in_ev = vec_global_to_ref(vel_in_world, ev_transform.rotation)
+ absolute_velocity.append([vel_in_ev.x, vel_in_ev.y, vel_in_ev.z])
+ return location, rotation, absolute_velocity
+
+def cast_angle(x):
+ # cast angle to [-180, +180)
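+    # e.g. cast_angle(190.0) -> -170.0, cast_angle(-190.0) -> 170.0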
+ return (x+180.0)%360.0-180.0
\ No newline at end of file
diff --git a/config/agent/mile.yaml b/config/agent/mile.yaml
new file mode 100644
index 0000000..3d67095
--- /dev/null
+++ b/config/agent/mile.yaml
@@ -0,0 +1,9 @@
+defaults:
+ - mile/obs_configs: central_rgb_wide
+
+mile:
+ entry_point: agents.muvo.mile_agent:MileAgent
+ ckpt: ''
+ online_deployment: false
+ env_wrapper:
+ entry_point: agents.muvo.mile_wrapper:MileWrapper
diff --git a/config/agent/mile/obs_configs/central_rgb_wide.yaml b/config/agent/mile/obs_configs/central_rgb_wide.yaml
new file mode 100644
index 0000000..dd6b2b5
--- /dev/null
+++ b/config/agent/mile/obs_configs/central_rgb_wide.yaml
@@ -0,0 +1,22 @@
+speed:
+ module: actor_state.speed
+gnss:
+ module: navigation.gnss
+central_rgb:
+ module: camera.rgb
+ fov: 100
+ width: 960
+ height: 600
+ location: [-1.5, 0.0, 2.0]
+ rotation: [0.0, 0.0, 0.0]
+route_plan:
+ module: navigation.waypoint_plan
+ steps: 20
+birdview:
+ module: birdview.chauffeurnet
+ width_in_pixels: 192
+ pixels_ev_to_bottom: 32
+ pixels_per_meter: 5.0
+ history_idx: [-16, -11, -6, -1]
+ scale_bbox: true
+ scale_mask_col: 1.0
diff --git a/config/agent/my.yaml b/config/agent/my.yaml
new file mode 100644
index 0000000..a1bcf2a
--- /dev/null
+++ b/config/agent/my.yaml
@@ -0,0 +1,2 @@
+defaults:
+ - my/obs_configs: camera_lidar_semantic
diff --git a/config/agent/my/obs_configs/camera_lidar.yaml b/config/agent/my/obs_configs/camera_lidar.yaml
new file mode 100644
index 0000000..7b1cae4
--- /dev/null
+++ b/config/agent/my/obs_configs/camera_lidar.yaml
@@ -0,0 +1,48 @@
+speed:
+ module: actor_state.speed
+gnss:
+ module: navigation.gnss
+route_plan:
+ module: navigation.waypoint_plan
+ steps: 20
+central_rgb:
+ module: camera.rgb
+ fov: 100
+ width: 960
+ height: 600
+ location: [0.5, 0.0, 2.0]
+ rotation: [0.0, 0.0, 0.0]
+left_rgb:
+ module: camera.rgb
+ fov: 100
+ width: 960
+ height: 600
+ location: [-2.0, -0.7, 2.0]
+ rotation: [0.0, 0.0, -60.0]
+right_rgb:
+ module: camera.rgb
+ fov: 100
+ width: 960
+ height: 600
+ location: [-2.0, 0.7, 2.0]
+ rotation: [0.0, 0.0, 60.0]
+lidar_points:
+ module: lidar.ray_cast
+ location: [-1.0, 0.0, 2.0]
+ rotation: [0.0, 0.0, 0.0]
+ render_o3d: False
+ show_axis: False
+ no_noise: False
+ lidar_options:
+ channels: 64
+ range: 100
+ rotation_frequency: 25
+ points_per_second: 600000
+ upper_fov: 10.0
+ lower_fov: -30.0 # -30.0
+ atmosphere_attenuation_rate: 0.004
+ # if no_noise
+ dropoff_general_rate: 0.45
+ dropoff_intensity_limit: 0.8
+ dropoff_zero_intensity: 0.4
+
diff --git a/config/agent/my/obs_configs/camera_lidar_semantic.yaml b/config/agent/my/obs_configs/camera_lidar_semantic.yaml
new file mode 100644
index 0000000..0b358ac
--- /dev/null
+++ b/config/agent/my/obs_configs/camera_lidar_semantic.yaml
@@ -0,0 +1,98 @@
+speed:
+ module: actor_state.speed
+gnss:
+ module: navigation.gnss
+route_plan:
+ module: navigation.waypoint_plan
+ steps: 20
+birdview:
+ module: birdview.chauffeurnet
+ width_in_pixels: 192
+ pixels_ev_to_bottom: 32
+ pixels_per_meter: 5.0
+ history_idx: [-16, -11, -6, -1]
+ scale_bbox: true
+ scale_mask_col: 1.0
+central_rgb:
+ module: camera.rgb
+ fov: 110
+ width: 960
+ height: 600
+ location: [1.0, 0.0, 2.0]
+ rotation: [0.0, 0.0, 0.0]
+depth_semantic:
+ module: camera.depth_semantic
+ fov: 110
+ width: 960
+ height: 600
+ location: [1.0, 0.0, 2.0]
+ rotation: [0.0, 0.0, 0.0]
+#left_rgb:
+# module: camera.rgb
+# fov: 100
+# width: 960
+# height: 600
+# location: [-2.0, -0.7, 2.0]
+# rotation: [0.0, 0.0, -60.0]
+#right_rgb:
+# module: camera.rgb
+# fov: 100
+# width: 960
+# height: 600
+# location: [-2.0, 0.7, 2.0]
+# rotation: [0.0, 0.0, 60.0]
+#lidar_points:
+# module: lidar.ray_cast
+# location: [1.0, 0.0, 2.0]
+# rotation: [0.0, 0.0, 0.0]
+# render_o3d: False
+# show_axis: False
+# no_noise: False
+# lidar_options:
+# channels: 64
+# range: 100
+# rotation_frequency: 10
+# points_per_second: 600000
+# upper_fov: 10.0
+# lower_fov: -30.0 # -30.0
+# atmosphere_attenuation_rate: 0.004
+# # if no_noise
+# dropoff_general_rate: 0.45
+# dropoff_intensity_limit: 0.8
+# dropoff_zero_intensity: 0.4
+lidar_points_semantic:
+ module: lidar.ray_cast_semantic
+ location: [1.0, 0.0, 2.0]
+ rotation: [0.0, 0.0, 0.0]
+ render_o3d: False
+ show_axis: False
+ lidar_options:
+ channels: 64
+ range: 100
+ rotation_frequency: 10
+ points_per_second: 600000
+ upper_fov: 10.0
+ lower_fov: -30.0 # -30.0
+#lidar_points_multi:
+# module: lidar.ray_cast_multi
+# location: [-1.0, 0.0, 2.0]
+# rotation: [0.0, 0.0, 0.0]
+# box_size: [20, 12, 20]
+# render_o3d: False
+# show_axis: False
+# lidar_options:
+# channels: 64
+# range: 100
+# rotation_frequency: 25
+# points_per_second: 600000
+# upper_fov: 30.0
+# lower_fov: -30.0 # -30.0
+
+#depth_semantic_m:
+# module: camera.depth_semantic_m
+# fov: 90
+# width: 320
+# height: 320
+# box_size: [10, 10, 10]
+# sensor_num: [2, 1]
+
diff --git a/config/agent/ppo.yaml b/config/agent/ppo.yaml
new file mode 100644
index 0000000..428cb51
--- /dev/null
+++ b/config/agent/ppo.yaml
@@ -0,0 +1,13 @@
+defaults:
+ - ppo/obs_configs: birdview
+ - ppo/policy: xtma_beta
+
+ppo:
+ entry_point: agents.rl_birdview.rl_birdview_agent:RlBirdviewAgent
+ wb_run_path: null
+ wb_ckpt_step: null
+ env_wrapper:
+ entry_point: agents.rl_birdview.utils.rl_birdview_wrapper:RlBirdviewWrapper
+ kwargs:
+ input_states: [control, vel_xy]
+ acc_as_action: True
\ No newline at end of file
diff --git a/config/agent/ppo/obs_configs/birdview.yaml b/config/agent/ppo/obs_configs/birdview.yaml
new file mode 100644
index 0000000..48c0ae9
--- /dev/null
+++ b/config/agent/ppo/obs_configs/birdview.yaml
@@ -0,0 +1,14 @@
+birdview:
+ module: birdview.chauffeurnet
+ width_in_pixels: 192
+ pixels_ev_to_bottom: 32
+ pixels_per_meter: 5.0
+ history_idx: [-16, -11, -6, -1]
+ scale_bbox: true
+ scale_mask_col: 1.0
+speed:
+ module: actor_state.speed
+control:
+ module: actor_state.control
+velocity:
+ module: actor_state.velocity
\ No newline at end of file
diff --git a/config/agent/ppo/policy/xtma_beta.yaml b/config/agent/ppo/policy/xtma_beta.yaml
new file mode 100644
index 0000000..dd3580c
--- /dev/null
+++ b/config/agent/ppo/policy/xtma_beta.yaml
@@ -0,0 +1,10 @@
+entry_point: agents.rl_birdview.models.ppo_policy:PpoPolicy
+kwargs:
+ policy_head_arch: [256, 256]
+ value_head_arch: [256, 256]
+ features_extractor_entry_point: agents.rl_birdview.models.torch_layers:XtMaCNN
+ features_extractor_kwargs:
+ states_neurons: [256,256]
+ distribution_entry_point: agents.rl_birdview.models.distributions:BetaDistribution
+ distribution_kwargs:
+ dist_init: null
\ No newline at end of file
diff --git a/config/data_collect.yaml b/config/data_collect.yaml
new file mode 100644
index 0000000..1becda5
--- /dev/null
+++ b/config/data_collect.yaml
@@ -0,0 +1,63 @@
+hydra:
+ run:
+ dir: ${work_dir}/${now:%Y-%m-%d}/${now:%H-%M-%S}
+# dir: /mnt/d/python/dataset/outputs
+
+work_dir: /mnt/d/python/dataset/outputs
+
+defaults:
+ - agent:
+# - muvo
+ - ppo
+ - my
+ - test_suites: lb_data
+ - _self_
+
+dataset_root: 'dataset'
+carla_sh_path: '/mnt/d/software/carla/0.9.14/CarlaUE4.exe'
+port: 2000
+n_episodes: 1
+run_time: ${actors.hero.terminal.kwargs.max_time}
+
+ev_id: hero
+resume: false
+log_level: INFO
+host: localhost
+seed: 0
+no_rendering: false
+render_off_screen: true
+kill_running: true
+remove_final_steps: true
+cml_project: MasterThesis/world_model
+cml_task_name: carla_data
+cml_task_type: application
+cml_tags: null
+log_video: false
+
+actors:
+ hero:
+ driver: ppo
+ reward:
+ entry_point: reward.valeo_action:ValeoAction
+ terminal:
+ entry_point: terminal.leaderboard_dagger:LeaderboardDagger
+ kwargs:
+ max_time: 2
+ no_collision: true
+ no_run_rl: true
+ no_run_stop: true
+
+agent:
+ ppo:
+ wb_run_path: iccv21-roach/trained-models/1929isj0
+ wb_ckpt_step: null
+ my:
+ obs_configs:
+ birdview_label:
+ module: birdview.chauffeurnet_label
+ width_in_pixels: 192
+ pixels_ev_to_bottom: 32
+ pixels_per_meter: 5.0
+ history_idx: [-16, -11, -6, -1]
+ scale_bbox: true
+ scale_mask_col: 1.0
diff --git a/config/evaluate.yaml b/config/evaluate.yaml
new file mode 100644
index 0000000..f72b75d
--- /dev/null
+++ b/config/evaluate.yaml
@@ -0,0 +1,26 @@
+defaults:
+ - test_suites: lb_test
+ - agent: muvo
+ - _self_
+
+carla_sh_path: ''
+port: 2000
+
+log_level: INFO
+host: localhost
+seed: 2021
+no_rendering: false
+kill_running: true
+resume: true
+wb_project: muvo
+wb_group: evaluation
+wb_tags: null
+log_video: true
+
+actors:
+ hero:
+ agent: muvo
+ reward:
+ entry_point: reward.valeo_action:ValeoAction
+ terminal:
+ entry_point: terminal.leaderboard:Leaderboard
diff --git a/config/test_suites/lb_data.yaml b/config/test_suites/lb_data.yaml
new file mode 100644
index 0000000..2a7bc7a
--- /dev/null
+++ b/config/test_suites/lb_data.yaml
@@ -0,0 +1,24 @@
+- env_id: Endless-v0
+ env_configs:
+ carla_map: Town01
+ num_zombie_vehicles: [80, 160]
+ num_zombie_walkers: [80, 160]
+ weather_group: train
+- env_id: Endless-v0
+ env_configs:
+ carla_map: Town03
+ num_zombie_vehicles: [40, 100]
+ num_zombie_walkers: [40, 100]
+ weather_group: train
+- env_id: Endless-v0
+ env_configs:
+ carla_map: Town04
+ num_zombie_vehicles: [100, 200]
+ num_zombie_walkers: [40, 120]
+ weather_group: train
+- env_id: Endless-v0
+ env_configs:
+ carla_map: Town06
+ num_zombie_vehicles: [80, 160]
+ num_zombie_walkers: [40, 120]
+ weather_group: train
diff --git a/config/test_suites/lb_test.yaml b/config/test_suites/lb_test.yaml
new file mode 100644
index 0000000..088ca2c
--- /dev/null
+++ b/config/test_suites/lb_test.yaml
@@ -0,0 +1,31 @@
+# new town / new weather
+- env_id: LeaderBoard-v0
+ env_configs:
+ carla_map: Town02
+ routes_group: null
+ weather_group: new
+- env_id: LeaderBoard-v0
+ env_configs:
+ carla_map: Town05
+ routes_group: null
+ weather_group: new
+- env_id: LeaderBoard-v0
+ env_configs:
+ carla_map: Town01
+ routes_group: null
+ weather_group: new
+- env_id: LeaderBoard-v0
+ env_configs:
+ carla_map: Town03
+ routes_group: null
+ weather_group: new
+- env_id: LeaderBoard-v0
+ env_configs:
+ carla_map: Town04
+ routes_group: null
+ weather_group: new
+- env_id: LeaderBoard-v0
+ env_configs:
+ carla_map: Town06
+ routes_group: null
+ weather_group: new
diff --git a/constants.py b/constants.py
new file mode 100644
index 0000000..c5791ba
--- /dev/null
+++ b/constants.py
@@ -0,0 +1,229 @@
+import numpy as np
+
+CARLA_FPS = 10
+DISPLAY_SEGMENTATION = True
+DISTORT_IMAGES = False
+WHEEL_BASE = 2.8711279296875
+# Ego-vehicle is 4.902m long and 2.128m wide. See `self._parent_actor.vehicle.bounding_box` in chauffeurnet_label
+EGO_VEHICLE_DIMENSION = [4.902, 2.128, 1.511]
+
+# https://github.com/carla-simulator/carla/blob/master/PythonAPI/carla/agents/navigation/local_planner.py
+# However, when processed (see the "process_obs" function), UNKNOWN becomes LANEFOLLOW and the remaining
+# commands are mapped to values in [0, 5] by subtracting 1.
+ROUTE_COMMANDS = {0: 'UNKNOWN',
+ 1: 'LEFT',
+ 2: 'RIGHT',
+ 3: 'STRAIGHT',
+ 4: 'LANEFOLLOW',
+ 5: 'CHANGELANELEFT',
+ 6: 'CHANGELANERIGHT',
+ }
+
+BIRDVIEW_COLOURS = np.array([[255, 255, 255], # Background
+ [225, 225, 225], # Road
+ [160, 160, 160], # Lane marking
+ [0, 83, 138], # Vehicle
+ [127, 255, 212], # Pedestrian
+ [50, 205, 50], # Green light
+ [255, 215, 0], # Yellow light
+ [220, 20, 60], # Red light and stop sign
+ ], dtype=np.uint8)
+
+# Obtained with sqrt of inverse frequency
+SEMANTIC_SEG_WEIGHTS = np.array([1.0, 1.0, 1.0, 2.0, 3.0, 1.0, 1.0, 1.0])
+
+# VOXEL_SEG_WEIGHTS = np.ones(23, dtype=float)
+# VOXEL_SEG_WEIGHTS[4] = 3.0
+# VOXEL_SEG_WEIGHTS[10] = 2.0
+
+VOXEL_SEG_WEIGHTS = np.array([1.0, 1.0, 1.0, 1.5, 2.0, 3.0, 1.0, 1.0, 1.0])
+
+VOXEL_LABEL_CARLA = {
+ 0: 'Background', # None
+ 1: 'Building', # Building
+ 2: 'Fences', # Fences
+ 3: 'Other', # Other
+ 4: 'Pedestrian', # Pedestrian
+ 5: 'Pole', # Pole
+ 6: 'RoadLines', # RoadLines
+ 7: 'Road', # Road
+ 8: 'Sidewalk', # Sidewalk
+ 9: 'Vegetation', # Vegetation
+ 10: 'Vehicle', # Vehicle
+ 11: 'Wall', # Wall
+ 12: 'TrafficSign', # TrafficSign
+ 13: 'Sky', # Sky
+ 14: 'Ground', # Ground
+ 15: 'Bridge', # Bridge
+ 16: 'RailTrack', # RailTrack
+ 17: 'GuardRail', # GuardRail
+ 18: 'TrafficLight', # TrafficLight
+ 19: 'Static', # Static
+ 20: 'Dynamic', # Dynamic
+ 21: 'Water', # Water
+ 22: 'Terrain', # Terrain
+}
+
+# VOXEL_LABEL = {
+# 0: 'Background',
+# 1: 'Road',
+# 2: 'RoadLines',
+# 3: 'Sidewalk',
+# 4: 'Vehicle',
+# 5: 'Pedestrian',
+# 6: 'TrafficSign',
+# 7: 'TrafficLight',
+# 8: 'Others'
+# }
+VOXEL_LABEL = {
+ 0: 'Background',
+ 1: 'Occupancy',
+}
+# VOXEL_LABEL = VOXEL_LABEL_CARLA
+
+# VOXEL_COLOURS = np.array([[255, 255, 255], # Background
+# [150, 150, 150], # Road
+# [200, 200, 20], # Road Lines
+# [200, 200, 200], # Sidewalk
+# [0, 83, 138], # Vehicle
+# [127, 255, 212], # Pedestrian
+# [220, 20, 60], # Traffic Sign
+# [100, 150, 35], # Traffic light
+# [0, 0, 0], # Others
+# ], dtype=np.uint8)
+VOXEL_COLOURS = np.array([[255, 255, 255], # Background
+ [115, 115, 115], # Others
+ ], dtype=np.uint8)
+# VOXEL_COLOURS = np.array([[255, 255, 255], # None
+# [70, 70, 70], # Building
+# [100, 40, 40], # Fences
+# [55, 90, 80], # Other
+# [220, 20, 60], # Pedestrian
+# [153, 153, 153], # Pole
+# [157, 234, 50], # RoadLines
+# [128, 64, 128], # Road
+# [244, 35, 232], # Sidewalk
+# [107, 142, 35], # Vegetation
+# [0, 0, 142], # Vehicle
+# [102, 102, 156], # Wall
+# [220, 220, 0], # TrafficSign
+# [70, 130, 180], # Sky
+# [81, 0, 81], # Ground
+# [150, 100, 100], # Bridge
+# [230, 150, 140], # RailTrack
+# [180, 165, 180], # GuardRail
+# [250, 170, 30], # TrafficLight
+# [110, 190, 160], # Static
+# [170, 120, 50], # Dynamic
+# [45, 60, 150], # Water
+# [145, 170, 100], # Terrain
+# ], dtype=np.uint8)
+
+# VOXEL_COLOURS = np.array([[0, 0, 0], # unlabeled
+# # cityscape
+# [128, 64, 128], # road = 1
+# [244, 35, 232], # sidewalk = 2
+# [70, 70, 70], # building = 3
+# [102, 102, 156], # wall = 4
+# [190, 153, 153], # fence = 5
+# [153, 153, 153], # pole = 6
+# [250, 170, 30], # traffic light = 7
+# [220, 220, 0], # traffic sign = 8
+# [107, 142, 35], # vegetation = 9
+# [152, 251, 152], # terrain = 10
+# [70, 130, 180], # sky = 11
+# [220, 20, 60], # pedestrian = 12
+# [255, 0, 0], # rider = 13
+# [0, 0, 142], # Car = 14
+# [0, 0, 70], # truck = 15
+# [0, 60, 100], # bs = 16
+# [0, 80, 100], # train = 17
+# [0, 0, 230], # motorcycle = 18
+# [119, 11, 32], # bicycle = 19
+# # custom
+# [110, 190, 160], # static = 20
+# [170, 120, 50], # dynamic = 21
+# [55, 90, 80], # other = 22
+# [45, 60, 150], # water = 23
+# [157, 234, 50], # road line = 24
+# [81, 0, 81], # grond = 25
+# [150, 100, 100], # bridge = 26
+# [230, 150, 140], # rail track = 27
+# [180, 165, 180], # gard rail = 28
+# ], dtype=np.uint8)
+
+# LABEL_MAP = {
+# 0: 0, # None
+# 1: 8, # Building
+# 2: 8, # Fences
+# 3: 8, # Other
+# 4: 5, # Pedestrian
+# 5: 8, # Pole
+# 6: 2, # RoadLines
+# 7: 1, # Road
+# 8: 3, # Sidewalk
+# 9: 8, # Vegetation
+# 10: 4, # Vehicle
+# 11: 8, # Wall
+# 12: 6, # TrafficSign
+# 13: 0, # Sky
+# 14: 8, # Ground
+# 15: 8, # Bridge
+# 16: 8, # RailTrack
+# 17: 8, # GuardRail
+# 18: 7, # TrafficLight
+# 19: 8, # Static
+# 20: 8, # Dynamic
+# 21: 8, # Water
+# 22: 8, # Terrain
+# }
+LABEL_MAP = {
+ 0: 0, # None
+ 1: 1, # Building
+ 2: 1, # Fences
+ 3: 1, # Other
+ 4: 1, # Pedestrian
+ 5: 1, # Pole
+ 6: 1, # RoadLines
+ 7: 1, # Road
+ 8: 1, # Sidewalk
+ 9: 1, # Vegetation
+ 10: 1, # Vehicle
+ 11: 1, # Wall
+ 12: 1, # TrafficSign
+ 13: 0, # Sky
+ 14: 1, # Ground
+ 15: 1, # Bridge
+ 16: 1, # RailTrack
+ 17: 1, # GuardRail
+ 18: 1, # TrafficLight
+ 19: 1, # Static
+ 20: 1, # Dynamic
+ 21: 1, # Water
+ 22: 1, # Terrain
+}
+# LABEL_MAP = {
+# 0: 0, # None
+# 1: 1, # Building
+# 2: 2, # Fences
+# 3: 3, # Other
+# 4: 4, # Pedestrian
+# 5: 5, # Pole
+# 6: 6, # RoadLines
+# 7: 7, # Road
+# 8: 8, # Sidewalk
+# 9: 9, # Vegetation
+# 10: 10, # Vehicle
+# 11: 11, # Wall
+# 12: 12, # TrafficSign
+# 13: 13, # Sky
+# 14: 14, # Ground
+# 15: 15, # Bridge
+# 16: 16, # RailTrack
+# 17: 17, # GuardRail
+# 18: 18, # TrafficLight
+# 19: 19, # Static
+# 20: 20, # Dynamic
+# 21: 21, # Water
+# 22: 22, # Terrain
+# }
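
A minimal sketch (not part of the patch, run from the repository root) of how `LABEL_MAP` is typically applied: it collapses the 23 CARLA semantic classes of `VOXEL_LABEL_CARLA` into the binary occupancy labels of `VOXEL_LABEL` through a simple lookup table.

```python
import numpy as np
from constants import LABEL_MAP, VOXEL_LABEL, VOXEL_LABEL_CARLA

# build a lookup table indexed by the original CARLA class id
lut = np.zeros(len(VOXEL_LABEL_CARLA), dtype=np.uint8)
for carla_id, target_id in LABEL_MAP.items():
    lut[carla_id] = target_id

carla_labels = np.array([7, 10, 13, 4])          # Road, Vehicle, Sky, Pedestrian
occupancy = lut[carla_labels]                    # -> [1, 1, 0, 1]
print([VOXEL_LABEL[int(v)] for v in occupancy])  # ['Occupancy', 'Occupancy', 'Background', 'Occupancy']
```
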
diff --git a/data/data_preprocess.yaml b/data/data_preprocess.yaml
new file mode 100644
index 0000000..bd83e21
--- /dev/null
+++ b/data/data_preprocess.yaml
@@ -0,0 +1,22 @@
+root: /mnt/d/python/dataset/test/trainval/train/
+#data_path: ${root}/2023-04-19/17-32-11/dataset/Town01/0000/depth_semantic
+
+n_process: 8
+#date: 2023-04-19
+#time: 17-32-11
+#town: Town01
+#run_name: 0000
+camera_position: [1.0, 0.0, 2.0]
+lidar_position: [1.0, 0.0, 2.0]
+fov: 110
+voxel_resolution: 0.5
+voxel_size: [192, 192, 64]
+bev_offset_forward: 0 # in px
+bev_resolution: 0.2
+offset_z: -20 # in px
+
+
+cml_project: MasterThesis/world_model
+cml_task_name: voxelization
+cml_task_type: data_processing
+cml_tags: null
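
For reference, a small arithmetic sketch (following how `generate_voxels.voxelize_one` consumes these settings; the numbers are the defaults above) of the metric volume the voxel grid covers:

```python
voxel_resolution = 0.5                      # metres per voxel
voxel_size = [192, 192, 64]                 # voxels in (x, y, z)
bev_offset_forward, bev_resolution = 0, 0.2
offset_z = -20                              # in voxels

extent = [s * voxel_resolution for s in voxel_size]   # [96.0, 96.0, 32.0] metres
offset = [bev_offset_forward * bev_resolution,        # 0.0 m forward shift
          0.0,
          offset_z * voxel_resolution]                # -10.0 m vertical shift
# voxel_filter() re-centres the grid, so with these values the covered volume
# (in the sensor frame) is roughly x, y in [-48, 48) m and z in [-6, 26) m.
print(extent, offset)
```
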
diff --git a/data/data_preprocessing.py b/data/data_preprocessing.py
new file mode 100644
index 0000000..bff5806
--- /dev/null
+++ b/data/data_preprocessing.py
@@ -0,0 +1,247 @@
+import numpy as np
+# import open3d as o3d
+import cv2
+
+EGO_VEHICLE_DIMENSION = [4.902, 2.128, 1.511]
+# import matplotlib.pyplot as plt
+
+
+# LABEL = np.array([
+# (0, 0, 0, 'ego'), # unlabeled
+# # cityscape
+# (128, 64, 128, 'road'), # road = 1
+# (244, 35, 232, 'sidewalk'), # sidewalk = 2
+# (70, 70, 70, 'building'), # building = 3
+# (102, 102, 156, 'wall'), # wall = 4
+# (190, 153, 153, 'fence'), # fence = 5
+# (153, 153, 153, 'pole'), # pole = 6
+# (250, 170, 30, 'traffic'), # traffic light = 7
+# (220, 220, 0, 'traffic'), # traffic sign = 8
+# (107, 142, 35, 'vegetation'), # vegetation = 9
+# (152, 251, 152, 'terrain'), # terrain = 10
+# (70, 130, 180, 'sky'), # sky = 11
+# (220, 20, 60, 'pedestrian'), # pedestrian = 12
+# (255, 0, 0, 'rider'), # rider = 13
+# (0, 0, 142, 'Car'), # Car = 14
+# (0, 0, 70, 'truck'), # truck = 15
+# (0, 60, 100, 'bs'), # bs = 16
+# (0, 80, 100, 'train'), # train = 17
+# (0, 0, 230, 'motorcycle'), # motorcycle = 18
+# (119, 11, 32, 'bicycle'), # bicycle = 19
+# # custom
+# (110, 190, 160, 'static'), # static = 20
+# (170, 120, 50, 'dynamic'), # dynamic = 21
+# (55, 90, 80, 'other'), # other = 22
+# (45, 60, 150, 'water'), # water = 23
+# (157, 234, 50, 'roadlines'), # road line = 24
+# (81, 0, 81, 'grond'), # grond = 25
+# (150, 100, 100, 'bridge'), # bridge = 26
+# (230, 150, 140, 'rail'), # rail track = 27
+# (180, 165, 180, 'gard') # gard rail = 28
+# ])
+LABEL = np.array([
+ (255, 255, 255, 'Ego'), # None
+ (70, 70, 70, 'Building'), # Building
+ (100, 40, 40, 'Fences'), # Fences
+ (55, 90, 80, 'Other'), # Other
+ (220, 20, 60, 'Pedestrian'), # Pedestrian
+ (153, 153, 153, 'Pole'), # Pole
+ (157, 234, 50, 'RoadLines'), # RoadLines
+ (128, 64, 128, 'Road'), # Road
+ (244, 35, 232, 'Sidewalk'), # Sidewalk
+ (107, 142, 35, 'Vegetation'), # Vegetation
+ (0, 0, 142, 'Vehicle'), # Vehicle
+ (102, 102, 156, 'Wall'), # Wall
+ (220, 220, 0, 'TrafficSign'), # TrafficSign
+ (70, 130, 180, 'Sky'), # Sky
+ (81, 0, 81, 'Ground'), # Ground
+ (150, 100, 100, 'Bridge'), # Bridge
+ (230, 150, 140, 'RailTrack'), # RailTrack
+ (180, 165, 180, 'GuardRail'), # GuardRail
+ (250, 170, 30, 'TrafficLight'), # TrafficLight
+ (110, 190, 160, 'Static'), # Static
+ (170, 120, 50, 'Dynamic'), # Dynamic
+ (45, 60, 150, 'Water'), # Water
+ (145, 170, 100, 'Terrain'), # Terrain
+])
+LABEL_COLORS = LABEL[:, :-1].astype(np.uint8) / 255.0
+LABEL_CLASS = np.char.lower(LABEL[:, -1])
+
+
+def read_img(file):
+ img = cv2.imread(file, -1)
+ depth_color = img[..., :-1].astype(float)
+ semantic = img[..., -1]
+ depth = 1000 * ((256 ** 2 * depth_color[..., 2] + 256 * depth_color[..., 1] + depth_color[..., 0]) / (256 ** 3 - 1))
+ return depth, semantic, depth_color
+
+
+def load_lidar(file):
+ data = np.load(file, allow_pickle=True).item()
+ pcd = data['points_xyz']
+ semantic = data['ObjTag']
+ return data, pcd, semantic
+
+
+def depth2pcd(depth, semantic, fov, range=100):
+ h, w = depth.shape
+ f = w / (2.0 * np.tan(fov * np.pi / 360.0))
+ cx, cy = w / 2.0, h / 2.0
+
+ depth = depth.reshape((-1, 1))
+ valid = (depth < 1000).squeeze()
+ depth = depth[valid]
+
+ g_x = np.arange(0, w)
+ g_y = np.arange(0, h)
+ xx, yy = np.meshgrid(g_x, g_y)
+ xx, yy = xx.reshape((-1, 1))[valid], yy.reshape((-1, 1))[valid]
+ x, y = (xx - cx) * depth / f, (yy - cy) * depth / f
+ points_list = np.concatenate([x, y, depth], axis=1)
+ sem_list = semantic.reshape((-1, 1))[valid]
+ valid_ = (np.linalg.norm(points_list, axis=1) < range).squeeze()
+ return points_list[valid_], sem_list[valid_]
+
+
+def convert_coor_img(pcd, camera_pos):
+ forward, right, up = camera_pos
+ mat = np.float32([
+ [0, 0, 1, forward],
+ [-1, 0, 0, -right],
+ [0, -1, 0, up],
+ [0, 0, 0, 1],
+ ])
+ pcd = np.insert(pcd, 3, 1, axis=1)
+ pcd = (mat @ pcd.T).T
+ return pcd[..., :-1]
+
+
+def convert_coor_lidar(pcd, lidar_pos):
+ pcd += np.asarray(lidar_pos)
+ pcd[:, 1] *= -1
+ return pcd
+
+
+def merge_pcd(depth_file, lidar_file, camera_pos, lidar_pos, fov=110, mask_ego=True):
+ depth, semantic, _ = read_img(depth_file)
+ img_pcd, img_semantic = depth2pcd(depth, semantic, fov)
+ img_pcd = convert_coor_img(img_pcd, camera_pos)
+ _, lidar_pcd, lidar_semantic = load_lidar(lidar_file)
+ lidar_pcd = convert_coor_lidar(lidar_pcd, lidar_pos)
+ pcd = np.concatenate([img_pcd, lidar_pcd], axis=0)
+ semantic = np.concatenate([img_semantic, lidar_semantic[:, None]], axis=0)
+ if mask_ego:
+ x, y, z = EGO_VEHICLE_DIMENSION
+ ego_box = np.array([[-x/2, -y/2, 0], [x/2, y/2, z]])
+ ego_idx = ((ego_box[0] < pcd) & (pcd < ego_box[1])).all(axis=1)
+ semantic = semantic[~ego_idx]
+ pcd = pcd[~ego_idx]
+ return pcd, semantic
+
+
+def get_all_points(depth, semantic, fov=90, size=(320, 320), offset=(10, 10, 10), mask_ego=True):
+ h, w = size
+ num_sensor = (int(depth.shape[0] / size[1]), int(depth.shape[1] / size[0]))
+ points_list = []
+ sem_list = []
+ for i in range(num_sensor[0]):
+ for j in range(num_sensor[1]):
+ y0, y1 = i * h, (i + 1) * h
+ x0, x1 = j * w, (j + 1) * w
+ pl, sl = depth2pcd(depth[y0: y1, x0: x1], semantic[y0: y1, x0: x1], fov)
+ x, y, z = offset[1] * (j - int(num_sensor[1] / 2)), offset[0] * (i - int(num_sensor[0] / 2)), -offset[2]
+ pl += np.array([x, y, z])
+ points_list.append(pl)
+ sem_list.append(sl)
+ points_list = np.concatenate(points_list, axis=0)
+ sem_list = np.concatenate(sem_list, axis=0)
+ points_list[:, 2] = -points_list[:, 2]
+ points_list[:, :2] = -points_list[:, :2][:, ::-1]
+ if mask_ego is not False:
+ if type(mask_ego) is not bool:
+ mask_ego = np.asarray(mask_ego)
+ ego_box = np.array([-mask_ego, mask_ego])
+ ego_box[0, 2] = 0
+ else:
+ ego_box = np.array([[-2.5, -1.1, 0], [2.5, 1.1, 2]])
+ ego_idx = ((ego_box[0] < points_list) & (points_list < ego_box[1])).all(axis=1)
+ sem_list[ego_idx] = 255
+ return points_list, sem_list
+
+
+def voxel_filter(pcd, sem, voxel_resolution, voxel_size, offset):
+ voxel_size = np.asarray(voxel_size)
+ offset = np.asarray(offset)
+ voxel_resolution = np.asarray(voxel_resolution)
+ offset += voxel_resolution * voxel_size / 2
+ pcd_b = pcd + offset
+ idx = ((0 <= pcd_b) & (pcd_b < voxel_size * voxel_resolution)).all(axis=1)
+ pcd_b, sem_b = pcd_b[idx], sem[idx]
+
+ Dx, Dy, Dz = voxel_size
+ # compute index for every point in a voxel
+ hxyz, hmod = np.divmod(pcd_b, voxel_resolution)
+ h = hxyz[:, 0] + hxyz[:, 1] * Dx + hxyz[:, 2] * Dx * Dy
+
+ # h_n = np.nonzero(np.bincount(h.astype(np.int32)))
+ h_idx = np.argsort(h) # sort the h.
+ # h, hxyz, pcd_b, hmod are all arranged in ascending order according to the value of 'h'
+ h, hxyz, sem_b, pcd_b, hmod = h[h_idx], hxyz[h_idx], sem_b[h_idx], pcd_b[h_idx], hmod[h_idx]
+ # Retrieve all unique values, 'h_n', in 'h' along with their 'indices'.
+ # i.e. Obtain all points residing within the same voxel grid.
+ h_n, indices = np.unique(h, return_index=True)
+ n_f = h_n.shape[0] # number of filtered points (i.e. occupied voxel grids.)
+ n_all = h.shape[0] # number of all points
+ voxels = np.zeros((n_f, 3), dtype=np.uint16) # coordinates of occupied voxel grids.
+ semantics = np.zeros((n_f, ), dtype=np.uint8) # labels of occupied voxel grids.
+ # points_f = np.zeros((n_f, 3))
+ road_idx = np.where(LABEL_CLASS == 'roadlines')[0][0] # label idx of 'roadline'
+ # voxels = []
+ # semantics = []
+ # points_f = []
+ for i in range(n_f): # for each filtered point. i.e. occupied voxel grid.
+
+ # Retrieve the indices of all points within 'h' (all original points)
+ # that are located in this voxel grid (sharing the same voxel grid as the filtered points).
+ # idx_ = (h == h_n[i])
+
+ # the same as previous line.
+ # Since 'h' has been sorted beforehand, all points between indices[i] and indices[i+1] are in the same voxel grid.
+ idx_ = np.arange(indices[i], indices[i+1]) if i < n_f - 1 else np.arange(indices[i], n_all)
+ # The distances from all points located in this voxel grid to the center of this voxel grid.
+ dis = np.sum(hmod[idx_] ** 2, axis=1)
+ # The label of the nearest point becomes the label of this voxel grid
+ # (alternatively, one could choose the label that occurs most often).
+ # 'roadline' is very thin and easily missed, so if this voxel grid contains any 'roadline' point,
+ # the voxel grid is labelled as 'roadline' directly.
+ semantic = sem_b[idx_][np.argmin(dis)] if not np.isin(sem_b[idx_], road_idx).any() else road_idx
+ # semantic = np.bincount(sem_b.squeeze()[idx_]).argmax() if not np.isin(sem_b[idx_], road_idx).any() else road_idx
+ voxels[i] = hxyz[idx_][0] # get the coordinate of this voxel grid.
+ semantics[i] = semantic
+
+ # get the coordinates of filtered points. (not coordinate of voxels)
+ # points_f[i] = pcd_b[idx_].mean(axis=0) - offset
+ # voxels.append(hxyz[idx_][0])
+ # semantics.append(semantic)
+ # points_f.append(pcd_b[idx_].mean(axis=0) - center)
+
+ return voxels, semantics
+
+
+def transform_pcd(xyz, transition):
+ xyz[:, 1] *= -1
+ xyz += np.asarray(transition)
+ return xyz
+
+
+def process_pcd(lidar_unprocessed, transition):
+ xyz = lidar_unprocessed['data']['points_xyz']
+ xyz = transform_pcd(xyz, transition)
+ if len(lidar_unprocessed['data'].keys()) == 2:
+ intensity = lidar_unprocessed['data']['intensity']
+ return np.concatenate([xyz, intensity[:, None]], axis=1)
+ elif len(lidar_unprocessed['data'].keys()) == 4:
+ sem = lidar_unprocessed['data']['ObjTag']
+ idx = lidar_unprocessed['data']['ObjIdx']
+ cos = lidar_unprocessed['data']['CosAngel']
+ return {'points': xyz, 'semantics': sem, 'ObjIdx': idx, 'CosAngel': cos}
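
A minimal, self-contained usage sketch of `voxel_filter` (synthetic points and hypothetical offsets; run from the `data/` directory so the import resolves). It illustrates the grouping idea from the comments above: the flat index `h = x + y*Dx + z*Dx*Dy` identifies each point's voxel, sorting by `h` collects the points that share a voxel, and the point with the smallest offset inside the voxel (or any 'roadlines' hit) decides the voxel's label.

```python
import numpy as np
from data_preprocessing import voxel_filter, LABEL_CLASS

pcd = np.array([[0.1, 0.1, 0.1],    # two points that fall into the same voxel
                [0.3, 0.2, 0.4],
                [5.0, -3.0, 1.0]])  # a third point elsewhere
road_idx = np.where(LABEL_CLASS == 'road')[0][0]
sem = np.array([road_idx, road_idx, road_idx], dtype=np.uint8)

voxels, semantics = voxel_filter(pcd, sem,
                                 voxel_resolution=0.5,
                                 voxel_size=[192, 192, 64],
                                 offset=[0.0, 0.0, -10.0])
print(voxels.shape, semantics.shape)  # (2, 3) (2,) -> two occupied voxels
```
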
diff --git a/data/dataset_utils.py b/data/dataset_utils.py
new file mode 100644
index 0000000..00f5997
--- /dev/null
+++ b/data/dataset_utils.py
@@ -0,0 +1,121 @@
+import torch
+import numpy as np
+
+import carla
+import carla_gym.utils.transforms as trans_utils
+import carla_gym.core.task_actor.common.navigation.route_manipulation as gps_util
+
+
+def binary_to_integer(binary_array, n_bits):
+ """
+ Parameters
+ ----------
+ binary_array: shape (n, n_bits)
+ n_bits: int
+
+ Returns
+ -------
+ integer_array: shape (n,) np.int32
+ """
+ return (binary_array @ 2 ** np.arange(n_bits, dtype=binary_array.dtype)).astype(np.int32)
+
+
+def integer_to_binary(integer_array, n_bits):
+ """
+ Parameters
+ ----------
+ integer_array: np.ndarray (n,)
+ n_bits: int
+
+ Returns
+ -------
+ binary_array: np.ndarray (n, n_bits)
+
+ """
+ return (((integer_array[:, None] & (1 << np.arange(n_bits)))) > 0).astype(np.float32)
+
+
+def calculate_birdview_labels(birdview, n_classes, has_time_dimension=False):
+ """
+ Parameters
+ ----------
+ birdview: torch.Tensor (C, H, W)
+ n_classes: int
+ number of total classes
+ has_time_dimension: bool
+
+ Returns
+ -------
+ birdview_label: (H, W)
+ """
+ # When a pixel contains two labels, argmax will output the first one that is encountered.
+ # By reversing the order, we prioritise traffic lights over road.
+ dim = 0
+ if has_time_dimension:
+ dim = 1
+ birdview_label = torch.argmax(birdview.flip(dims=[dim]), dim=dim)
+ # We then re-normalise the classes in the normal order.
+ birdview_label = (n_classes - 1) - birdview_label
+ return birdview_label
+
+
+def preprocess_measurements(route_command, ego_gps, target_gps, imu):
+ # preprocess measurements
+ route_command = route_command.copy()
+ route_command[route_command < 0] = 4
+ route_command -= 1
+ route_command = np.array(route_command[0], dtype=np.int64)
+
+ loc_in_ev = preprocess_gps(ego_gps, target_gps, imu)
+ gps_vector = np.array([loc_in_ev.x, loc_in_ev.y], dtype=np.float32)
+ return route_command, gps_vector
+
+
+def preprocess_gps(ego_gps, target_gps, imu):
+ # imu nan bug
+ compass = 0.0 if np.isnan(imu[-1]) else imu[-1]
+ target_vec_in_global = gps_util.gps_to_location(target_gps) - gps_util.gps_to_location(ego_gps)
+ ref_rot_in_global = carla.Rotation(yaw=np.rad2deg(compass) - 90.0)
+ loc_in_ev = trans_utils.vec_global_to_ref(target_vec_in_global, ref_rot_in_global)
+ return loc_in_ev
+
+
+def preprocess_birdview_and_routemap(birdview):
+ ROUTE_MAP_INDEX = 1
+ # road, lane markings, vehicles, pedestrians
+ relevant_indices = [0, 2, 6, 10]
+
+ if isinstance(birdview, np.ndarray):
+ birdview = torch.from_numpy(birdview)
+ has_time_dimension = True
+ if len(birdview.shape) == 3:
+ birdview = birdview.unsqueeze(0)
+ has_time_dimension = False
+ # Birdview has values in {0, 255}. Convert to {0, 1}
+
+ # lights and stops
+ light_and_stop_channel = birdview[:, -1:]
+ green_light = (light_and_stop_channel == 80).float()
+ yellow_light = (light_and_stop_channel == 170).float()
+ red_light_and_stop = (light_and_stop_channel == 255).float()
+
+ remaining = birdview[:, relevant_indices]
+ remaining[remaining > 0] = 1
+ remaining = remaining.float()
+
+ # Traffic light and stop.
+ processed_birdview = torch.cat([remaining, green_light, yellow_light, red_light_and_stop], dim=1)
+ # background
+ tmp = processed_birdview.sum(dim=1, keepdim=True)
+ background = (tmp == 0).float()
+
+ processed_birdview = torch.cat([background, processed_birdview], dim=1)
+
+ # Route map
+ route_map = birdview[:, ROUTE_MAP_INDEX]
+ route_map[route_map > 0] = 255
+ route_map = route_map.to(torch.uint8)
+
+ if not has_time_dimension:
+ processed_birdview = processed_birdview[0]
+ route_map = route_map[0]
+ return processed_birdview, route_map
\ No newline at end of file
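
A round-trip sketch for the bit-packing helpers above (note the module imports carla and carla_gym at the top, so the CARLA client must be importable for this to run):

```python
import numpy as np
from data.dataset_utils import binary_to_integer, integer_to_binary

ints = np.array([0, 5, 255])
bits = integer_to_binary(ints, n_bits=8)   # (3, 8) float32, least-significant bit first
back = binary_to_integer(bits, n_bits=8)   # (3,) int32
assert (back == ints).all()
```
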
diff --git a/data/generate_voxels.py b/data/generate_voxels.py
new file mode 100644
index 0000000..f26d462
--- /dev/null
+++ b/data/generate_voxels.py
@@ -0,0 +1,168 @@
+import numpy as np
+import hydra
+import pandas as pd
+from omegaconf import DictConfig, OmegaConf
+from pathlib import Path
+import shutil
+import scipy.sparse as sp
+import re
+from tqdm import tqdm
+from clearml import Task
+import logging
+from multiprocessing import Pool, RLock, Pipe
+from threading import Thread
+
+from data_preprocessing import *
+
+log = logging.getLogger(__name__)
+
+
+def progress_bar_total(parent, total_len, desc):
+ desc = desc if desc else "Main"
+ pbar_main = tqdm(total=total_len, desc=desc, position=0)
+ nums = 0
+ while True:
+ msg = parent.recv()[0]
+ if msg is not None:
+ pbar_main.update()
+ nums += 1
+ if nums == total_len:
+ break
+ pbar_main.close()
+
+
+def voxelize_dir(data_path, cfg, task_idx, all_task, pipe):
+ log.info(f'Converting Depth Image to Voxels in {data_path}.')
+ save_path = data_path.parent
+ # if not save_path.exists():
+ # save_path.mkdir()
+ file_list = sorted([str(f) for f in data_path.glob('*.png')])
+ data_dict = {}
+ voxels_dict = {}
+ for file in tqdm(file_list, desc=f'{task_idx + 1:04} / {all_task:04}', position=task_idx % cfg.n_process + 1):
+ depth, semantic, _ = read_img(file)
+ points_list, sem_list = get_all_points(
+ depth, semantic, fov=cfg.fov, size=cfg.size, offset=cfg.offset, mask_ego=cfg.mask_ego)
+ voxel_points, semantics = voxel_filter(points_list, sem_list, cfg.voxel_size, cfg.center)
+ data = np.concatenate([voxel_points, semantics[:, None]], axis=1)
+ voxels = np.zeros(shape=(2 * np.asarray(cfg.center) / cfg.voxel_size).astype(int), dtype=np.uint8)
+ voxels[voxel_points[:, 0], voxel_points[:, 1], voxel_points[:, 2]] = semantics
+ name = re.match(r'.*/.*_(\d{9})\.png', file).group(1)
+ # np.savez_compressed(f"{save_path}/{name}.npz", data=data)
+ coo_voxels = sp.coo_matrix(voxels.reshape(voxels.shape[0], -1))
+ # np.savez_compressed(f"{save_path}/v{name}.npz", data=coo_voxels)
+ voxels_dict[name] = coo_voxels
+ data_dict[name] = data
+
+ if pipe is not None:
+ pipe.send(['x'])
+
+ np.savez_compressed(f"{save_path}/voxels.npz", coo_voxels=voxels_dict, voxel_points=data_dict)
+ log.info(f"Saved Voxels Data in {save_path}/voxels.npz.")
+
+
+def voxelize_one(depth_file, lidar_file, cfg, save_name, pipe=None):
+ pcd, sem = merge_pcd(depth_file, lidar_file, cfg.camera_position, cfg.lidar_position, cfg.fov)
+ offset_x = cfg.bev_offset_forward * cfg.bev_resolution
+ offset_z = cfg.offset_z * cfg.voxel_resolution
+ voxel_points, semantics = voxel_filter(pcd, sem, cfg.voxel_resolution, cfg.voxel_size, [offset_x, 0, offset_z])
+ data = np.concatenate([voxel_points, semantics[:, None]], axis=1)
+ # voxels = np.zeros(shape=cfg.voxel_size, dtype=np.uint8)
+ # voxels[voxel_points[:, 0], voxel_points[:, 1], voxel_points[:, 2]] = semantics
+ # csr_voxels = sp.csr_matrix(voxels.reshape(voxels.shape[0], -1))
+ np.save(f'{save_name}', data)
+ # np.save(f'{save_path}/voxel_coo/voxel_coo_{name}.npy', csr_voxels)
+
+ if pipe is not None:
+ pipe.send(['x'])
+
+
+@hydra.main(config_path='./', config_name='data_preprocess')
+def main_(cfg: DictConfig):
+ task = Task.init(project_name=cfg.cml_project, task_name=cfg.cml_task_name, task_type=cfg.cml_task_type,
+ tags=cfg.cml_tags)
+ task.connect(cfg)
+ cml_logger = task.get_logger()
+
+ root_path = Path(cfg.root)
+ data_paths = sorted([p for p in root_path.glob('**/depth_semantic') if p.is_dir()])
+ n_files = len([f for f in root_path.glob('**/depth_semantic/*.png')])
+
+ if not root_path.exists() or n_files == 0:
+ print('Root Path does not EXIST or there are NO LEGAL files!!!')
+ return
+
+ log.info(f'{n_files} files will be processed.')
+
+ parent, child = Pipe()
+ main_thread = Thread(target=progress_bar_total, args=(parent, n_files, "Total"))
+ main_thread.start()
+ p = Pool(cfg.n_process, initializer=tqdm.set_lock, initargs=(RLock(),))
+ for i, path in enumerate(data_paths):
+ p.apply_async(func=voxelize_dir, args=(path, cfg, i, len(data_paths), child))
+ p.close()
+ p.join()
+ main_thread.join()
+
+ log.info("Finished Voxelization!")
+
+
+@hydra.main(config_path='./', config_name='data_preprocess')
+def main(cfg: DictConfig):
+ # task = Task.init(project_name=cfg.cml_project, task_name=cfg.cml_task_name, task_type=cfg.cml_task_type,
+ # tags=cfg.cml_tags)
+ # task.connect(cfg)
+ # cml_logger = task.get_logger()
+ root_path = Path(cfg.root)
+ data_paths = sorted([p for p in root_path.glob('**/Town*/*/') if p.is_dir()])
+
+ if not root_path.exists() or len(data_paths) == 0:
+ print('Root Path does not EXIST or there are NO LEGAL files!!!')
+ return
+
+ log.info(f'{len(data_paths)} runs will be voxelized.')
+ log.info(f'{data_paths}')
+
+ for i, path in enumerate(data_paths):
+ pd_file = f'{path}/pd_dataframe.pkl'
+ pd_dataframe = pd.read_pickle(pd_file)
+ data_len = len(pd_dataframe)
+
+ # parent, child = Pipe()
+ # main_thread = Thread(target=progress_bar_total, args=(parent, data_len, f'{i+1}/{len(data_paths)}'))
+ # main_thread.start()
+ p = Pool(cfg.n_process)
+
+ log.info(f'start voxelizing in dir {path}.')
+
+ save_path = path.joinpath('voxel')
+ if save_path.exists():
+ shutil.rmtree(f'{save_path}')
+ save_path.mkdir()
+
+ voxel_paths = []
+ pbar = tqdm(total=data_len, desc=f'{i+1}/{len(data_paths)}')
+ for j in range(data_len):
+ # voxelize_one(depth_file, lidar_file, cfg, save_path)
+ data_row = pd_dataframe.iloc[j]
+ depth_file = str(path.joinpath(data_row['depth_semantic_path']))
+ lidar_file = str(path.joinpath(data_row['points_semantic_path']))
+ name = re.match(r'.*/.*_(\d{9})\.png', depth_file).group(1)
+ name_ = re.match(r'.*/.*_(\d{9})\.npy', lidar_file).group(1)
+ assert name == name_, 'depth and lidar file indices do not match.'
+ file_name = f'{save_path.name}/voxel_{name}.npy'
+ save_name = f'{path}/{file_name}'
+ p.apply_async(func=voxelize_one, args=(depth_file, lidar_file, cfg, save_name),
+ callback=lambda _: pbar.update())
+ voxel_paths.append(file_name)
+ p.close()
+ p.join()
+ # main_thread.join()
+ pbar.close()
+ pd_dataframe['voxel_path'] = voxel_paths
+ pd_dataframe.to_pickle(pd_file)
+ log.info(f'finished, saved in {save_path}')
+
+
+if __name__ == '__main__':
+ main()
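
A small read-back sketch (hypothetical file path) for the files this script writes: `voxelize_one` stores, per frame, an (N, 4) array of occupied voxel coordinates plus their semantic label, which can be re-densified like this:

```python
import numpy as np

voxel_size = (192, 192, 64)                      # must match data_preprocess.yaml
data = np.load('voxel/voxel_000000000.npy')      # (N, 4): x, y, z, semantic label
dense = np.zeros(voxel_size, dtype=np.uint8)
xyz = data[:, :3].astype(int)
dense[xyz[:, 0], xyz[:, 1], xyz[:, 2]] = data[:, 3].astype(np.uint8)
print(int((dense > 0).sum()), 'occupied voxels')
```
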
diff --git a/data/pcd.py b/data/pcd.py
new file mode 100644
index 0000000..2cf57b4
--- /dev/null
+++ b/data/pcd.py
@@ -0,0 +1,80 @@
+import pandas as pd
+
+from data_preprocessing import *
+import hydra
+from omegaconf import DictConfig, OmegaConf
+from pathlib import Path
+from tqdm import tqdm
+# from clearml import Task
+import logging
+from multiprocessing import Pool, RLock
+
+log = logging.getLogger(__name__)
+
+
+def process_one_file(data_path, transition, task_idx, all_task, position, is_semantic=True):
+ if is_semantic:
+ save_path = data_path.joinpath('points_semantic')
+ file = data_path.joinpath('point_clouds_semantic.npy')
+ prefix = 'points_semantic'
+ else:
+ save_path = data_path.joinpath('points')
+ file = data_path.joinpath('point_clouds.npy')
+ prefix = 'points'
+ if not save_path.exists():
+ save_path.mkdir()
+ path_list = []
+ try:
+ pcd_list, _, _ = load_lidar(file)
+ pbar = tqdm(total=len(pcd_list), desc=f'{task_idx + 1:04} / {all_task:04}',
+ position=position, postfix='semantic' if is_semantic else 'points')
+ for name, lidar_unprocessed in pcd_list.items():
+ lidar_processed = process_pcd(lidar_unprocessed, transition)
+ np.save(f'{save_path}/{prefix}_{name}.npy', lidar_processed)
+ path_list.append(f'{save_path.name}/{prefix}_{name}.npy')
+ pbar.update()
+ except Exception as e:
+ log.error(f'{e}')
+
+ log.info(f"Saved processed points clouds in {save_path}.")
+ return path_list
+
+
+def process_dir(data_path, cfg, task_idx, all_task):
+ log.info(f'Processing point clouds in {data_path}.')
+ pd_file = f'{data_path}/pd_dataframe.pkl'
+ pd_dataframe = pd.read_pickle(pd_file)
+ points_semantic_path = process_one_file(data_path, cfg.lidar_transition, task_idx, all_task, task_idx % cfg.n_process)
+ pd_dataframe['points_semantic_path'] = points_semantic_path
+ points_path = process_one_file(data_path, cfg.lidar_transition, task_idx, all_task, task_idx % cfg.n_process, False)
+ pd_dataframe['points_path'] = points_path
+ pd_dataframe.to_pickle(pd_file)
+
+
+@hydra.main(config_path='./', config_name='data_preprocess')
+def main(cfg: DictConfig):
+ # task = Task.init(project_name=cfg.cml_project, task_name=cfg.cml_task_name, task_type=cfg.cml_task_type,
+ # tags=cfg.cml_tags)
+ # task.connect(cfg)
+ # cml_logger = task.get_logger()
+
+ root_path = Path(cfg.root)
+ data_paths = sorted([p for p in root_path.glob('**/Town*/*/') if p.is_dir()])
+
+ if not root_path.exists() or len(data_paths) == 0:
+ print('Root Path does not EXIST or there are NO LEGAL files!!!')
+ return
+
+ log.info(f'{len(data_paths)} point cloud sequences will be processed.')
+
+ p = Pool(cfg.n_process, initializer=tqdm.set_lock, initargs=(RLock(),))
+ for i, path in enumerate(data_paths):
+ p.apply_async(func=process_dir, args=(path, cfg, i, len(data_paths)))
+ p.close()
+ p.join()
+
+ log.info("Finished!")
+
+
+if __name__ == '__main__':
+ main()
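
pcd.py unpacks the raw `point_clouds.npy` / `point_clouds_semantic.npy` dumps through `process_pcd` from data_preprocessing.py. A minimal sketch (synthetic records, run from the `data/` directory) of the two record layouts that function handles: a plain lidar record yields an (N, 4) array of xyz + intensity, while a semantic-lidar record yields a dict.

```python
import numpy as np
from data_preprocessing import process_pcd

xyz = np.random.rand(100, 3).astype(np.float32)
plain = {'data': {'points_xyz': xyz.copy(),
                  'intensity': np.random.rand(100).astype(np.float32)}}
semantic = {'data': {'points_xyz': xyz.copy(),
                     'ObjTag': np.random.randint(0, 23, 100),
                     'ObjIdx': np.zeros(100, dtype=np.int32),
                     'CosAngel': np.random.rand(100)}}

print(process_pcd(plain, transition=[1.0, 0.0, 2.0]).shape)      # (100, 4)
print(process_pcd(semantic, transition=[1.0, 0.0, 2.0]).keys())  # points, semantics, ObjIdx, CosAngel
```
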
diff --git a/data_collect.py b/data_collect.py
new file mode 100644
index 0000000..2ca5ffc
--- /dev/null
+++ b/data_collect.py
@@ -0,0 +1,301 @@
+import carla
+import gym
+import math
+from pathlib import Path
+import json
+import numpy as np
+from tqdm import tqdm
+# import wandb
+from clearml import Task
+import pandas as pd
+import hydra
+from omegaconf import DictConfig, OmegaConf
+import logging
+import subprocess
+import os
+import sys
+from constants import CARLA_FPS
+
+from stable_baselines3.common.vec_env.base_vec_env import tile_images
+
+from carla_gym.utils import config_utils
+from utils import saving_utils, server_utils
+from agents.rl_birdview.utils.wandb_callback import WandbCallback
+
+log = logging.getLogger(__name__)
+
+
+def run_single(run_name, env, data_writer, driver_dict, driver_log_dir, log_video, remove_final_steps, pbar):
+ list_debug_render = []
+ list_data_render = []
+ ep_stat_dict = {}
+ ep_event_dict = {}
+
+ for actor_id, driver in driver_dict.items():
+ log_dir = driver_log_dir / actor_id
+ log_dir.mkdir(parents=True, exist_ok=True)
+ driver.reset(log_dir / f'{run_name}.log')
+
+ obs = env.reset()
+ timestamp = env.timestamp
+ done = {'__all__': False}
+ valid = True
+ # pbar = tqdm(total=CARLA_FPS*2)
+ while not done['__all__']:
+ driver_control = {}
+ driver_supervision = {}
+
+ for actor_id, driver in driver_dict.items():
+ driver_control[actor_id] = driver.run_step(obs[actor_id], timestamp)
+ driver_supervision[actor_id] = driver.supervision_dict
+ # control = carla.VehicleControl(throttle=1.0, steer=0.0, brake=0.0)
+ # driver_control[actor_id] = control
+ # driver_supervision[actor_id] = {'action': np.array([1.0, 0.0, 0.0]),
+ # 'speed': obs[actor_id]['speed']['forward_speed']
+ # }
+
+ new_obs, reward, done, info = env.step(driver_control)
+
+ im_rgb = data_writer.write(timestamp=timestamp, obs=obs,
+ supervision=driver_supervision, reward=reward, control_diff=None,
+ weather=env.world.get_weather())
+
+ obs = new_obs
+
+ # debug_imgs = []
+ for actor_id, driver in driver_dict.items():
+ # if log_video:
+ # debug_imgs.append(driver.render(info[actor_id]['reward_debug'], info[actor_id]['terminal_debug']))
+ if done[actor_id] and (actor_id not in ep_stat_dict):
+ episode_stat = info[actor_id]['episode_stat']
+ ep_stat_dict[actor_id] = episode_stat
+ ep_event_dict[actor_id] = info[actor_id]['episode_event']
+
+ valid = data_writer.close(
+ info[actor_id]['terminal_debug'],
+ remove_final_steps, None)
+ log.info(f'Episode {run_name} done, valid={valid}')
+
+ # if log_video:
+ # list_debug_render.append(tile_images(debug_imgs))
+ # list_data_render.append(im_rgb)
+ timestamp = env.timestamp
+ pbar.update(1)
+
+ return valid, list_debug_render, list_data_render, ep_stat_dict, ep_event_dict, timestamp
+
+
+@hydra.main(config_path='config', config_name='data_collect')
+def main(cfg: DictConfig):
+ # if cfg.host == 'localhost' and cfg.kill_running:
+ # server_utils.kill_carla(cfg.port)
+ log.setLevel(getattr(logging, cfg.log_level.upper()))
+
+ # start carla servers
+ # server_manager = server_utils.CarlaServerManager(
+ # cfg.carla_sh_path, port=cfg.port, render_off_screen=cfg.render_off_screen)
+ # server_manager.start()
+
+ driver_dict = {}
+ obs_configs = {}
+ reward_configs = {}
+ terminal_configs = {}
+ for ev_id, ev_cfg in cfg.actors.items():
+ # initiate driver agent
+ cfg_driver = cfg.agent[ev_cfg.driver]
+ OmegaConf.save(config=cfg_driver, f='config_driver.yaml')
+ DriverAgentClass = config_utils.load_entry_point(cfg_driver.entry_point)
+ driver_dict[ev_id] = DriverAgentClass('config_driver.yaml')
+ obs_configs[ev_id] = driver_dict[ev_id].obs_configs
+ # driver_dict[ev_id] = 'hero'
+ # obs_configs[ev_id] = OmegaConf.to_container(cfg_driver.obs_configs)
+
+ for k, v in OmegaConf.to_container(cfg.agent.my.obs_configs).items():
+ if k not in obs_configs[ev_id]:
+ obs_configs[ev_id][k] = v
+
+ # get obs_configs from agent
+ reward_configs[ev_id] = OmegaConf.to_container(ev_cfg.reward)
+ terminal_configs[ev_id] = OmegaConf.to_container(ev_cfg.terminal)
+
+ OmegaConf.save(config=obs_configs, f='obs_config.yaml')
+
+ # check h5 birdview maps have been generated
+ config_utils.check_h5_maps(cfg.test_suites, obs_configs, cfg.carla_sh_path)
+
+ last_checkpoint_path = f'{cfg.work_dir}/port_{cfg.port}_checkpoint.txt'
+ if cfg.resume and os.path.isfile(last_checkpoint_path):
+ with open(last_checkpoint_path, 'r') as f:
+ env_idx = int(f.read())
+ else:
+ env_idx = 0
+
+ # resume task_idx from ep_stat_buffer_{env_idx}.json
+ ep_state_buffer_json = f'{cfg.work_dir}/port_{cfg.port}_ep_stat_buffer_{env_idx}.json'
+ if cfg.resume and os.path.isfile(ep_state_buffer_json):
+ ep_stat_buffer = json.load(open(ep_state_buffer_json, 'r'))
+ ckpt_task_idx = len(ep_stat_buffer['hero'])
+ else:
+ ckpt_task_idx = 0
+ ep_stat_buffer = {}
+ for actor_id in driver_dict.keys():
+ ep_stat_buffer[actor_id] = []
+
+ # resume clearml task
+ cml_checkpoint_path = f'{cfg.work_dir}/port_{cfg.port}_cml_task_id.txt'
+ if cfg.resume and os.path.isfile(cml_checkpoint_path):
+ with open(cml_checkpoint_path, 'r') as f:
+ cml_task_id = f.read()
+ else:
+ cml_task_id = False
+ # env_idx = 0
+ # ckpt_task_idx = 0
+
+ log.info(f'Start from env_idx: {env_idx}, task_idx {ckpt_task_idx}')
+
+ # make save directories
+ dataset_root = Path(cfg.dataset_root)
+ dataset_root.mkdir(parents=True, exist_ok=True)
+ im_stack_idx = [-1]
+ # cml_task_name = f'{dataset_root.name}'
+
+ dataset_dir = Path(os.path.join(cfg.dataset_root, cfg.test_suites[env_idx]['env_configs']['carla_map']))
+ dataset_dir.mkdir(parents=True, exist_ok=True)
+
+ diags_dir = Path('diagnostics')
+ driver_log_dir = Path('driver_log')
+ video_dir = Path('videos')
+ diags_dir.mkdir(parents=True, exist_ok=True)
+ driver_log_dir.mkdir(parents=True, exist_ok=True)
+ video_dir.mkdir(parents=True, exist_ok=True)
+
+ # init ClearML task (wandb logging is not used here)
+ task = Task.init(project_name=cfg.cml_project, task_name=cfg.cml_task_name, task_type=cfg.cml_task_type,
+ tags=cfg.cml_tags, continue_last_task=cml_task_id)
+ task.connect(cfg)
+ cml_logger = task.get_logger()
+ with open(cml_checkpoint_path, 'w') as f:
+ f.write(task.task_id)
+
+ # This handles the case where the data_collect job is re-run, for example after an interruption.
+ if env_idx >= len(cfg.test_suites):
+ log.info(f'Finished! env_idx {env_idx} is beyond the number of test suites, nothing left to collect.')
+ # server_manager.stop()
+ return
+
+ env_setup = OmegaConf.to_container(cfg.test_suites[env_idx])
+
+ env = gym.make(env_setup['env_id'], obs_configs=obs_configs, reward_configs=reward_configs,
+ terminal_configs=terminal_configs, host=cfg.host, port=cfg.port,
+ seed=cfg.seed, no_rendering=cfg.no_rendering, **env_setup['env_configs'])
+
+ # main loop
+ n_episodes_per_env = math.ceil(cfg.n_episodes / len(cfg.test_suites))
+
+ for task_idx in range(ckpt_task_idx, n_episodes_per_env):
+ idx_episode = task_idx + n_episodes_per_env * env_idx
+ run_name = f'{idx_episode:04}'
+ log.info(f"Start data collection env_idx {env_idx}, task_idx {task_idx}, run_name {run_name}")
+
+ while True:
+ pbar = tqdm(
+ total=CARLA_FPS*cfg.run_time,
+ desc=f"Env {env_idx:02} / {len(cfg.test_suites):02} - Task {task_idx:04} / {n_episodes_per_env:04}")
+ env.set_task_idx(np.random.choice(env.num_tasks))
+
+ run_info = {
+ 'is_expert': True,
+ 'weather': env.task['weather'],
+ 'town': cfg.test_suites[env_idx]['env_configs']['carla_map'],
+ 'n_vehicles': env.task['num_zombie_vehicles'],
+ 'n_walkers': env.task['num_zombie_walkers'],
+ 'route_id': env.task['route_id'],
+ 'env_id': cfg.test_suites[env_idx]['env_id'],
+
+ }
+ save_birdview_label = 'birdview_label' in obs_configs['hero']
+ data_writer = saving_utils.DataWriter(dataset_dir / f'{run_name}', cfg.ev_id, im_stack_idx,
+ run_info=run_info,
+ save_birdview_label=save_birdview_label,
+ render_image=cfg.log_video)
+
+ valid, list_debug_render, list_data_render, ep_stat_dict, ep_event_dict, timestamp = run_single(
+ run_name, env, data_writer, driver_dict, driver_log_dir,
+ cfg.log_video,
+ cfg.remove_final_steps,
+ pbar)
+
+ if valid:
+ break
+
+ diags_json_path = (diags_dir / f'{run_name}.json').as_posix()
+ with open(diags_json_path, 'w') as fd:
+ json.dump(ep_event_dict, fd, indent=4, sort_keys=False)
+
+ # save time
+ cml_logger.report_table(
+ title='time',
+ series=run_name,
+ iteration=idx_episode,
+ table_plot=pd.DataFrame({'total_step': timestamp['step'],
+ 'fps': timestamp['step'] / timestamp['relative_wall_time']
+ }, index=['time']))
+
+ # save statistics
+ # for actor_id, ep_stat in ep_stat_dict.items():
+ # ep_stat_buffer[actor_id].append(ep_stat)
+ # log_dict = {}
+ # for k, v in ep_stat.items():
+ # k_actor = f'{actor_id}/{k}'
+ # log_dict[k_actor] = v
+ # wandb.log(log_dict, step=idx_episode)
+ cml_logger.report_table(
+ title='statistics', series=run_name, iteration=idx_episode, table_plot=pd.DataFrame(ep_stat_dict))
+
+ with open(ep_state_buffer_json, 'w') as fd:
+ json.dump(ep_stat_buffer, fd, indent=4, sort_keys=True)
+
+ # clean up
+ list_debug_render.clear()
+ list_data_render.clear()
+ ep_stat_dict = None
+ ep_event_dict = None
+
+ saving_utils.report_dataset_size(dataset_dir)
+ dataset_size = subprocess.check_output(['du', '-sh', dataset_dir]).split()[0].decode('utf-8')
+ log.warning(f'{dataset_dir}: dataset_size {dataset_size}')
+
+ env.close()
+ env = None
+ # server_manager.stop()
+
+ # log after all episodes are completed
+ table_data = []
+ ep_stat_keys = None
+ for actor_id, list_ep_stat in json.load(open(ep_state_buffer_json, 'r')).items():
+ avg_ep_stat = WandbCallback.get_avg_ep_stat(list_ep_stat)
+ data = [actor_id, cfg.actors[actor_id].driver, env_idx, str(len(list_ep_stat))]
+ if ep_stat_keys is None:
+ ep_stat_keys = list(avg_ep_stat.keys())
+ data += [f'{avg_ep_stat[k]:.4f}' for k in ep_stat_keys]
+ table_data.append(data)
+
+ table_columns = ['actor_id', 'driver', 'env_idx', 'n_episode'] + ep_stat_keys
+ # wandb.log({'table/summary': wandb.Table(data=table_data, columns=table_columns)})
+ cml_logger.report_table(title='table', series='summary', iteration=env_idx,
+ table_plot=pd.DataFrame(table_data, columns=table_columns))
+
+ with open(last_checkpoint_path, 'w') as f:
+ f.write(f'{env_idx + 1}')
+
+ log.info(f"Finished data collection env_idx {env_idx}, {env_setup['env_id']}.")
+ if env_idx + 1 == len(cfg.test_suites):
+ log.info(f"Finished, {env_idx + 1}/{len(cfg.test_suites)}")
+ return
+ else:
+ log.info(f"Not finished, {env_idx + 1}/{len(cfg.test_suites)}")
+ sys.exit(1)
+
+
+if __name__ == '__main__':
+ main()
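
A small sketch (hypothetical `work_dir`/`port` values) of the bookkeeping files this script reads and writes so that collection can resume: the next `env_idx`, the per-episode statistics buffer, and the ClearML task id used with `continue_last_task`.

```python
import json, os

work_dir, port = '.', 2000                                     # hypothetical values
checkpoint = f'{work_dir}/port_{port}_checkpoint.txt'          # stores the next env_idx
if os.path.isfile(checkpoint):
    env_idx = int(open(checkpoint).read())
    stats = f'{work_dir}/port_{port}_ep_stat_buffer_{env_idx}.json'
    done = len(json.load(open(stats))['hero']) if os.path.isfile(stats) else 0
    print(f'would resume at env_idx={env_idx}, task_idx={done}')
else:
    print('no checkpoint, collection starts at env_idx=0, task_idx=0')
```
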
diff --git a/muvo/config.py b/muvo/config.py
new file mode 100644
index 0000000..3870c27
--- /dev/null
+++ b/muvo/config.py
@@ -0,0 +1,369 @@
+import argparse
+from fvcore.common.config import CfgNode as _CfgNode
+
+
+def convert_to_dict(cfg_node, key_list=[]):
+ """Convert a config node to dictionary."""
+ _VALID_TYPES = {tuple, list, str, int, float, bool}
+ if not isinstance(cfg_node, _CfgNode):
+ if type(cfg_node) not in _VALID_TYPES:
+ print(
+ 'Key {} with value {} is not a valid type; valid types: {}'.format(
+ '.'.join(key_list), type(cfg_node), _VALID_TYPES
+ ),
+ )
+ return cfg_node
+ else:
+ cfg_dict = dict(cfg_node)
+ for k, v in cfg_dict.items():
+ cfg_dict[k] = convert_to_dict(v, key_list + [k])
+ return cfg_dict
+
+
+class CfgNode(_CfgNode):
+ """Remove once https://github.com/rbgirshick/yacs/issues/19 is merged."""
+
+ def convert_to_dict(self):
+ return convert_to_dict(self)
+
+CN = CfgNode
+
+_C = CN()
+_C.LOG_DIR = 'tensorboard_logs'
+_C.TAG = 'default'
+_C.CML_PROJECT = ''
+_C.CML_TASK = ''
+_C.CML_TYPE = ''
+_C.CML_DATASET = ''
+
+_C.GPUS = 1 # how many gpus to use
+_C.PRECISION = '16-mixed' # 16bit or 32bit
+_C.BATCHSIZE = 3
+_C.STEPS = 50000
+_C.N_WORKERS = 4
+
+_C.VAL_CHECK_INTERVAL = 5000
+_C.LOGGING_INTERVAL = 500
+_C.LIMIT_VAL_BATCHES = 1
+_C.LOG_VIDEO_INTERVAL = 5000
+
+_C.RECEPTIVE_FIELD = 1
+_C.FUTURE_HORIZON = 1
+
+_C.PREDICTION = CN()
+_C.PREDICTION.N_SAMPLES = 2
+
+###########
+# Optimizer
+###########
+_C.OPTIMIZER = CN()
+_C.OPTIMIZER.LR = 1e-4
+_C.OPTIMIZER.WEIGHT_DECAY = 0.01
+_C.OPTIMIZER.ACCUMULATE_GRAD_BATCHES = 1
+_C.OPTIMIZER.FROZEN = CN()
+_C.OPTIMIZER.FROZEN.ENABLED = False
+_C.OPTIMIZER.FROZEN.TRAIN_LIST = []
+
+_C.SCHEDULER = CN()
+_C.SCHEDULER.NAME = 'OneCycleLR'
+_C.SCHEDULER.PCT_START = 0.2
+
+#########
+# Dataset
+#########
+_C.DATASET = CN()
+_C.DATASET.DATAROOT = ''
+_C.DATASET.VERSION = 'trainval'
+_C.DATASET.STRIDE_SEC = 0.2 # stride between two frames
+_C.DATASET.FILTER_BEGINNING_OF_RUN_SEC = 1.0 # in seconds. the beginning of the run is stationary.
+_C.DATASET.FILTER_NORM_REWARD = 0.6 # filter runs that have a normalised reward below this value.
+
+#############
+# Input lidar points
+#############
+_C.POINTS = CN()
+_C.POINTS.LIDAR_POSITION = [1.0, 0.0, 2.0]
+_C.POINTS.LIDAR_ROTATION = [0.0, 0.0, 0.0]
+_C.POINTS.FOV = [-30, 10]
+_C.POINTS.CHANNELS = 64
+_C.POINTS.N_PER_SECOND = 600000
+_C.POINTS.HORIZON_RESOLUTION = 1024
+
+_C.POINTS.HISTOGRAM = CN()
+_C.POINTS.HISTOGRAM.RESOLUTION = 10 # pixels per meter
+_C.POINTS.HISTOGRAM.HIST_MAX = 5 # max histogram per pixel
+_C.POINTS.HISTOGRAM.X_RANGE = 384 # in pxs
+_C.POINTS.HISTOGRAM.Y_RANGE = 384 # in pxs
+_C.POINTS.HISTOGRAM.Z_RANGE = 128 # in pxs
+
+#############
+# Input Voxels
+#############
+_C.VOXEL = CN()
+_C.VOXEL.SIZE = [192, 192, 64]
+_C.VOXEL.RESOLUTION = 0.2
+_C.VOXEL.EV_POSITION = [32, 96, 12]
+
+#############
+# Input image
+#############
+_C.IMAGE = CN()
+_C.IMAGE.SIZE = (600, 960)
+_C.IMAGE.CROP = [64, 138, 896, 458] # (left, top, right, bottom)
+_C.IMAGE.FOV = 100
+_C.IMAGE.CAMERA_POSITION = [1.0, 0.0, 2.0] # (forward, right, up)
+
+# carla defines as (pitch, yaw, roll)
+# /!\ roach defines as (roll, pitch, yaw)
+# this is fine for now because all of them are equal to zero.
+_C.IMAGE.CAMERA_ROTATION = [0.0, 0.0, 0.0]
+_C.IMAGE.IMAGENET_MEAN = (0.485, 0.456, 0.406)
+_C.IMAGE.IMAGENET_STD = (0.229, 0.224, 0.225)
+
+_C.IMAGE.AUGMENTATION = CN() # image augmentations
+_C.IMAGE.AUGMENTATION.BLUR_PROB = .3
+_C.IMAGE.AUGMENTATION.BLUR_WINDOW = 5
+_C.IMAGE.AUGMENTATION.BLUR_STD = [.1, 1.7]
+_C.IMAGE.AUGMENTATION.SHARPEN_PROB = .3
+_C.IMAGE.AUGMENTATION.SHARPEN_FACTOR = [1, 5]
+_C.IMAGE.AUGMENTATION.COLOR_PROB = .3
+_C.IMAGE.AUGMENTATION.COLOR_JITTER_BRIGHTNESS = .3
+_C.IMAGE.AUGMENTATION.COLOR_JITTER_CONTRAST = .3
+_C.IMAGE.AUGMENTATION.COLOR_JITTER_SATURATION = .3
+_C.IMAGE.AUGMENTATION.COLOR_JITTER_HUE = .1
+
+_C.BEV = CN()
+_C.BEV.SIZE = [192, 192] # width, height. note that the bev is rotated, so width corresponds to forward direction.
+_C.BEV.RESOLUTION = 0.2 # pixel size in m
+_C.BEV.OFFSET_FORWARD = -64 # offset of the center of gravity of the egocar relative to the center of bev in px
+_C.BEV.FEATURE_DOWNSAMPLE = 4 # Downsample factor for bev features
+
+_C.BEV.FRUSTUM_POOL = CN()
+_C.BEV.FRUSTUM_POOL.D_BOUND = [1.0, 38.0, 1.0]
+_C.BEV.FRUSTUM_POOL.SPARSE = True
+_C.BEV.FRUSTUM_POOL.SPARSE_COUNT = 10
+
+###########
+# Route map
+###########
+_C.ROUTE = CN()
+_C.ROUTE.SIZE = 64 # spatial resolution
+
+_C.ROUTE.AUGMENTATION_DROPOUT = .025
+_C.ROUTE.AUGMENTATION_END_OF_ROUTE = .025
+_C.ROUTE.AUGMENTATION_SMALL_ROTATION = .025
+_C.ROUTE.AUGMENTATION_LARGE_ROTATION = .025
+_C.ROUTE.AUGMENTATION_DEGREES = 8.
+_C.ROUTE.AUGMENTATION_TRANSLATE = (.1, .1)
+_C.ROUTE.AUGMENTATION_SCALE = (.95, 1.05)
+_C.ROUTE.AUGMENTATION_SHEAR = (.1, .1)
+
+#######
+# Speed
+#######
+_C.SPEED = CN()
+_C.SPEED.NOISE_STD = 1.4 # in m/s
+_C.SPEED.NORMALISATION = 5.0 # in m/s
+
+#######
+# Model
+#######
+_C.MODEL = CN()
+
+_C.MODEL.ACTION_DIM = 2
+
+_C.MODEL.TRANSFORMER = CN()
+_C.MODEL.TRANSFORMER.CHANNELS = 256
+_C.MODEL.TRANSFORMER.ENABLED = False
+_C.MODEL.TRANSFORMER.BEV = False
+_C.MODEL.TRANSFORMER.LARGE = False
+
+_C.MODEL.ENCODER = CN()
+_C.MODEL.ENCODER.NAME = 'resnet18'
+_C.MODEL.ENCODER.OUT_CHANNELS = 64
+
+_C.MODEL.BEV = CN()
+_C.MODEL.BEV.BACKBONE = 'resnet18'
+_C.MODEL.BEV.CHANNELS = 64
+
+_C.MODEL.LIDAR = CN()
+_C.MODEL.LIDAR.ENABLED = True
+_C.MODEL.LIDAR.MULTI_VIEW = False
+_C.MODEL.LIDAR.ENCODER = 'resnet18'
+_C.MODEL.LIDAR.OUT_CHANNELS = 64
+_C.MODEL.LIDAR.BACKBONE = 'resnet18'
+
+_C.MODEL.LIDAR.POINT_PILLAR = CN()
+_C.MODEL.LIDAR.POINT_PILLAR.ENABLED = False
+
+_C.MODEL.SPEED = CN()
+_C.MODEL.SPEED.CHANNELS = 16
+
+_C.MODEL.ROUTE = CN()
+_C.MODEL.ROUTE.ENABLED = True
+_C.MODEL.ROUTE.BACKBONE = 'resnet18'
+_C.MODEL.ROUTE.CHANNELS = 16
+
+_C.MODEL.MEASUREMENTS = CN()
+_C.MODEL.MEASUREMENTS.ENABLED = False
+_C.MODEL.MEASUREMENTS.COMMAND_CHANNELS = 8
+_C.MODEL.MEASUREMENTS.GPS_CHANNELS = 16
+
+_C.MODEL.EMBEDDING_DIM = 512
+
+_C.MODEL.TRANSITION = CN()
+_C.MODEL.TRANSITION.ENABLED = True
+_C.MODEL.TRANSITION.HIDDEN_STATE_DIM = 1024 # Dimension of the RNN hidden representation
+_C.MODEL.TRANSITION.STATE_DIM = 512 # Dimension of prior/posterior
+_C.MODEL.TRANSITION.ACTION_LATENT_DIM = 64 # Latent dimension of action
+_C.MODEL.TRANSITION.USE_DROPOUT = True
+_C.MODEL.TRANSITION.DROPOUT_PROBABILITY = 0.15
+
+###########
+# LOSSES
+###########
+_C.SEMANTIC_SEG = CN()
+_C.SEMANTIC_SEG.ENABLED = True
+_C.SEMANTIC_SEG.N_CHANNELS = 8
+_C.SEMANTIC_SEG.USE_TOP_K = True # backprop only top-k hardest pixels
+_C.SEMANTIC_SEG.TOP_K_RATIO = 0.25
+_C.SEMANTIC_SEG.USE_WEIGHTS = True
+
+# Always enabled with seg
+_C.INSTANCE_SEG = CN()
+_C.INSTANCE_SEG.CENTER_LABEL_SIGMA_PX = 4
+_C.INSTANCE_SEG.IGNORE_INDEX = 255
+_C.INSTANCE_SEG.CENTER_LOSS_WEIGHT = 200.0
+_C.INSTANCE_SEG.OFFSET_LOSS_WEIGHT = 0.1
+
+# Voxels reconstruction
+_C.VOXEL_SEG = CN()
+_C.VOXEL_SEG.ENABLED = True
+_C.VOXEL_SEG.DIMENSION = 256
+_C.VOXEL_SEG.N_CLASSES = 9
+_C.VOXEL_SEG.USE_TOP_K = False
+_C.VOXEL_SEG.TOP_K_RATIO = 0.5
+_C.VOXEL_SEG.USE_WEIGHTS = True
+
+# lidar reconstruction
+_C.LIDAR_RE = CN()
+_C.LIDAR_RE.ENABLED = True
+_C.LIDAR_RE.N_CHANNELS = 4
+_C.LIDAR_RE.SCALE = 50.0
+
+# lidar segmentation
+_C.LIDAR_SEG = CN()
+_C.LIDAR_SEG.ENABLED = True
+_C.LIDAR_SEG.N_CLASSES = 9
+_C.LIDAR_SEG.USE_TOP_K = True
+_C.LIDAR_SEG.TOP_K_RATIO = 0.5
+_C.LIDAR_SEG.USE_WEIGHTS = True
+
+# semantic image
+_C.SEMANTIC_IMAGE = CN()
+_C.SEMANTIC_IMAGE.ENABLED = False
+_C.SEMANTIC_IMAGE.N_CLASSES = 9
+_C.SEMANTIC_IMAGE.USE_TOP_K = False
+_C.SEMANTIC_IMAGE.TOP_K_RATIO = 0.5
+_C.SEMANTIC_IMAGE.USE_WEIGHTS = True
+
+# depth
+_C.DEPTH = CN()
+_C.DEPTH.ENABLED = False
+_C.DEPTH.N_CHANNELS = 1
+
+_C.LOSSES = CN()
+_C.LOSSES.WEIGHT_ACTION = 1.0
+_C.LOSSES.WEIGHT_SEGMENTATION = 0.1
+_C.LOSSES.WEIGHT_INSTANCE = 0.1
+_C.LOSSES.WEIGHT_REWARD = 0.1
+_C.LOSSES.WEIGHT_PROBABILISTIC = 1e-3
+_C.LOSSES.KL_BALANCING_ALPHA = 0.75
+_C.LOSSES.WEIGHT_LIDAR_RE = 0.1
+_C.LOSSES.WEIGHT_LIDAR_SEG = 0.1
+_C.LOSSES.WEIGHT_SEM_IMAGE = 0.1
+_C.LOSSES.WEIGHT_DEPTH = 0.1
+_C.LOSSES.WEIGHT_VOXEL = 0.1
+_C.LOSSES.RGB_INSTANCE = False
+_C.LOSSES.SSIM = False
+
+# pre_trained ckpt path
+_C.PRETRAINED = CN()
+_C.PRETRAINED.PATH = ''
+_C.PRETRAINED.CML_MODEL = ''
+
+# These parameters are only used to benchmark other models.
+_C.EVAL = CN()
+_C.EVAL.MASK_VIEW = False
+_C.EVAL.RGB_SUPERVISION = False
+_C.EVAL.CHECKPOINT_PATH = ''
+_C.EVAL.NO_LIFTING = False
+# Dataset size experiments
+_C.EVAL.DATASET_REDUCTION = False
+_C.EVAL.DATASET_REDUCTION_FACTOR = 1
+# Image resolution experiments
+_C.EVAL.RESOLUTION = CN()
+_C.EVAL.RESOLUTION.ENABLED = False
+_C.EVAL.RESOLUTION.FACTOR = 1
+
+#########
+# Sampler
+#########
+_C.SAMPLER = CN()
+_C.SAMPLER.ENABLED = False
+_C.SAMPLER.WITH_ACCELERATION = False
+_C.SAMPLER.WITH_STEERING = False
+_C.SAMPLER.N_BINS = 5
+_C.SAMPLER.WITH_ROUTE_COMMAND = False # not used
+_C.SAMPLER.COMMAND_WEIGHTS = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
+
+_C.MODEL.POLICY = CN()
+
+_C.MODEL.REWARD = CN()
+_C.MODEL.REWARD.ENABLED = False
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(description='World model training')
+ parser.add_argument('--config-file', default='', metavar='FILE', help='path to config file')
+ parser.add_argument(
+ 'opts', help='Modify config options using the command-line', default=None, nargs=argparse.REMAINDER,
+ )
+ return parser
+
+
+def _find_extra_keys(dict1, dict2, path=''):
+ """
+ Recursively finds keys that exist in dict2 but not in dict1.
+ Returns the full path of the missing keys, including the parent key names.
+ """
+ results = []
+ for key in dict2.keys():
+ new_path = f"{path}.{key}" if path else key
+ if key in dict1:
+ if isinstance(dict1[key], dict) and isinstance(dict2[key], dict):
+ results.extend(_find_extra_keys(dict1[key], dict2[key], new_path))
+ else:
+ results.append(new_path)
+ results.sort()
+ return results
+
+
+def get_cfg(args=None, cfg_dict=None):
+ """ First get default config. Then merge cfg_dict. Then merge according to args. """
+
+ cfg = _C.clone()
+
+ if cfg_dict is not None:
+ extra_keys = _find_extra_keys(cfg, cfg_dict)
+ if len(extra_keys) > 0:
+ print(f"Warning - the cfg_dict merging into the main cfg has keys that do not exist in main: {extra_keys}")
+ cfg.set_new_allowed(True)
+ cfg.merge_from_other_cfg(CfgNode(cfg_dict))
+
+ if args is not None:
+ if args.config_file:
+ cfg.merge_from_file(args.config_file)
+ cfg.merge_from_list(args.opts)
+ cfg.freeze()
+ return cfg
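
A minimal usage sketch (the YAML path is a placeholder) of how a training or test script is expected to build the config with these helpers; note that `merge_from_file` is generally strict, so keys in the YAML should exist in the defaults declared above.

```python
from muvo.config import get_parser, get_cfg

# trailing tokens are collected into args.opts and applied as key/value overrides
args = get_parser().parse_args(
    ['--config-file', 'muvo/configs/your_config.yml', 'GPUS', '2', 'BATCHSIZE', '4'])
cfg = get_cfg(args)   # defaults -> YAML file -> command-line opts, then frozen
print(cfg.GPUS, cfg.BATCHSIZE, cfg.MODEL.ENCODER.NAME)
```
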
diff --git a/muvo/configs/debug.yml b/muvo/configs/debug.yml
new file mode 100644
index 0000000..0c594fe
--- /dev/null
+++ b/muvo/configs/debug.yml
@@ -0,0 +1,78 @@
+_BASE_: 'muvo.yml'
+
+LOG_DIR: 'tensorboard_logs'
+TAG: 'debug'
+
+GPUS: 1
+BATCHSIZE: 1
+
+STEPS: 8
+LOGGING_INTERVAL: 2
+LOG_VIDEO_INTERVAL: 2
+VAL_CHECK_INTERVAL: 3
+LIMIT_VAL_BATCHES: 2
+N_WORKERS: 0
+
+PREDICTION:
+ N_SAMPLES: 1
+
+OPTIMIZER:
+ ACCUMULATE_GRAD_BATCHES: 1
+ FROZEN:
+ ENABLED: False
+ TRAIN_LIST: ['bev_decoder', 'voxel_decoder']
+
+SEMANTIC_SEG:
+ ENABLED: False
+
+VOXEL_SEG:
+ ENABLED: True
+ DIMENSION: 64
+ N_CLASSES: 2
+ USE_WEIGHTS: False
+
+LIDAR_SEG:
+ ENABLED: False
+ N_CLASSES: 9
+ USE_WEIGHTS: True
+
+LIDAR_RE:
+ ENABLED: True
+
+SEMANTIC_IMAGE:
+ ENABLED: False
+ N_CLASSES: 9
+ USE_WEIGHTS: True
+
+DEPTH:
+ ENABLED: False
+
+LOSSES:
+ RGB_INSTANCE: False
+ SSIM: False
+
+MODEL:
+ EMBEDDING_DIM: 512
+ TRANSFORMER:
+ ENABLED: True
+ BEV: False
+ LARGE: False
+ LIDAR:
+ POINT_PILLAR:
+ ENABLED: False
+
+EVAL:
+ RGB_SUPERVISION: True
+
+RECEPTIVE_FIELD: 2
+FUTURE_HORIZON: 1
+
+PRETRAINED:
+# PATH: 'weights/epoch=37-step=57600.ckpt'
+# PATH: 'weights/epoch=65-step=99000.ckpt'
+ PATH: 'weights/epoch=3-step=50000.ckpt'
+# PATH: ''
+#
+#DATASET:
+# DATAROOT: ''
+# VERSION: 'mini'
diff --git a/muvo/configs/muvo.yml b/muvo/configs/muvo.yml
new file mode 100644
index 0000000..1e312ce
--- /dev/null
+++ b/muvo/configs/muvo.yml
@@ -0,0 +1,98 @@
+LOG_DIR: 'tensorboard_logs'
+TAG: '2d, resnet18, w/o bev, range view, transformer, with voxel'
+
+CML_PROJECT: 'muvo'
+CML_TASK: 'muvo'
+CML_TYPE: 'training'
+CML_DATASET: 'carla_dataset'
+CML_DATASET_VERSION: '2.0.0'
+
+GPUS: 1
+BATCHSIZE: 1
+STEPS: 100000
+VAL_CHECK_INTERVAL: 5000
+LOGGING_INTERVAL: 200
+LOG_VIDEO_INTERVAL: 2500
+N_WORKERS: 16
+
+OPTIMIZER:
+ ACCUMULATE_GRAD_BATCHES: 16
+ FROZEN:
+ ENABLED: False
+ TRAIN_LIST: ['voxel_decoder']
+
+PREDICTION:
+ N_SAMPLES: 1
+
+MODEL:
+ TRANSFORMER_TRANSITION:
+ ENABLED: True
+ ROUTE:
+ ENABLED: True
+ TRANSFORMER:
+ ENABLED: True
+ CHANNELS: 384
+ BEV: False
+ LARGE: False
+ LIDAR:
+ POINT_PILLAR:
+ ENABLED: False
+ ENCODER:
+ NAME: 'resnet18'
+ BEV:
+ BACKBONE: 'resnet18'
+ LIDAR:
+ ENCODER: 'resnet18'
+ ROUTE:
+ BACKBONE: 'resnet18'
+
+EVAL:
+ RGB_SUPERVISION: True
+ NO_LIFTING: False
+
+SEMANTIC_SEG:
+ ENABLED: False
+
+VOXEL_SEG:
+ ENABLED: True
+ DIMENSION: 64
+ N_CLASSES: 2
+ USE_WEIGHTS: False
+
+LIDAR_SEG:
+ ENABLED: False
+ N_CLASSES: 9
+ USE_WEIGHTS: True
+
+LIDAR_RE:
+ ENABLED: True
+
+SEMANTIC_IMAGE:
+ ENABLED: False
+ N_CLASSES: 9
+ USE_WEIGHTS: True
+
+DEPTH:
+ ENABLED: False
+
+LOSSES:
+ SSIM: False
+ RGB_INSTANCE: False
+ PERCEPTUAL:
+ ENABLED: False
+ MODEL: 'resnet18'
+
+RECEPTIVE_FIELD: 4
+FUTURE_HORIZON: 2
+
+PRETRAINED:
+ PATH: 'path to weight'
+ CML_MODEL: ''
+
+DATASET:
+# DATAROOT: '/disk/vanishing_data/qw825/carla_dataset'
+ DATAROOT: 'path to dataset'
+
+IMAGE:
+ CAMERA_POSITION: [1.0, 0.0, 2.0] # (forward, right, up)
+
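A small arithmetic note on the temporal settings above: the data module in muvo/data/dataset.py builds clips of RECEPTIVE_FIELD + FUTURE_HORIZON frames, so with this file's values each training sample is a 6-frame sequence (4 observed + 2 future).

```python
RECEPTIVE_FIELD, FUTURE_HORIZON = 4, 2              # values from this file
sequence_length = RECEPTIVE_FIELD + FUTURE_HORIZON
print(sequence_length)                              # 6 frames per training sample
```
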
diff --git a/muvo/configs/one_frame.yml b/muvo/configs/one_frame.yml
new file mode 100644
index 0000000..cf8a09e
--- /dev/null
+++ b/muvo/configs/one_frame.yml
@@ -0,0 +1,16 @@
+LOG_DIR: 'tensorboard_logs'
+TAG: 'one_frame_baseline'
+
+GPUS: 8
+BATCHSIZE: 8
+STEPS: 50000
+
+DATASET:
+ DATAROOT: ''
+
+RECEPTIVE_FIELD: 1
+FUTURE_HORIZON: 0
+
+MODEL:
+ TRANSITION:
+ ENABLED: False
diff --git a/muvo/configs/test_base_1d.yml b/muvo/configs/test_base_1d.yml
new file mode 100644
index 0000000..e4c6e6b
--- /dev/null
+++ b/muvo/configs/test_base_1d.yml
@@ -0,0 +1,42 @@
+_BASE_: 'muvo.yml'
+
+LOG_DIR: 'tensorboard_logs'
+TAG: 'test_base, 1d, resnet18, w/o bev, range view, transformer, with voxel'
+
+CML_TYPE: 'application'
+
+PREDICTION:
+ N_SAMPLES: 1
+
+MODEL:
+ TRANSFORMER_TRANSITION:
+ ENABLED: False
+ ROUTE:
+ ENABLED: True
+ TRANSFORMER:
+ ENABLED: True
+ CHANNELS: 384
+ BEV: False
+ LARGE: False
+ LIDAR:
+ POINT_PILLAR:
+ ENABLED: False
+ ENCODER:
+ NAME: 'resnet18'
+ BEV:
+ BACKBONE: 'resnet18'
+ LIDAR:
+ ENCODER: 'resnet18'
+ ROUTE:
+ BACKBONE: 'resnet18'
+
+RECEPTIVE_FIELD: 6
+FUTURE_HORIZON: 10
+
+PRETRAINED:
+ PATH: 'path to weights'
+# CML_MODEL: ''
+
+DATASET:
+# DATAROOT: '/disk/vanishing_data/qw825/carla_dataset'
+ DATAROOT: 'path to dataset'
diff --git a/muvo/configs/test_base_1d_without_voxel.yml b/muvo/configs/test_base_1d_without_voxel.yml
new file mode 100644
index 0000000..dfb0f48
--- /dev/null
+++ b/muvo/configs/test_base_1d_without_voxel.yml
@@ -0,0 +1,45 @@
+_BASE_: 'muvo.yml'
+
+LOG_DIR: 'tensorboard_logs'
+TAG: 'test_base, 1d, resnet18, w/o bev, range view, transformer, without voxel'
+
+CML_TYPE: 'application'
+
+PREDICTION:
+ N_SAMPLES: 1
+
+MODEL:
+ TRANSFORMER_TRANSITION:
+ ENABLED: False
+ ROUTE:
+ ENABLED: True
+ TRANSFORMER:
+ ENABLED: True
+ CHANNELS: 384
+ BEV: False
+ LARGE: False
+ LIDAR:
+ POINT_PILLAR:
+ ENABLED: False
+ ENCODER:
+ NAME: 'resnet18'
+ BEV:
+ BACKBONE: 'resnet18'
+ LIDAR:
+ ENCODER: 'resnet18'
+ ROUTE:
+ BACKBONE: 'resnet18'
+
+VOXEL_SEG:
+ ENABLED: False
+
+RECEPTIVE_FIELD: 6
+FUTURE_HORIZON: 10
+
+PRETRAINED:
+ PATH: 'path to weights'
+# CML_MODEL: ''
+
+DATASET:
+# DATAROOT: '/disk/vanishing_data/qw825/carla_dataset'
+ DATAROOT: 'path to dataset'
diff --git a/muvo/configs/test_base_2d.yml b/muvo/configs/test_base_2d.yml
new file mode 100644
index 0000000..6b0b8f4
--- /dev/null
+++ b/muvo/configs/test_base_2d.yml
@@ -0,0 +1,43 @@
+_BASE_: 'muvo.yml'
+
+LOG_DIR: 'tensorboard_logs'
+TAG: 'test_base, 2d, resnet18, w/o bev, range view, transformer, with voxel'
+
+CML_TYPE: 'application'
+
+PREDICTION:
+ N_SAMPLES: 1
+
+MODEL:
+ TRANSFORMER_TRANSITION:
+ ENABLED: True
+ ROUTE:
+ ENABLED: True
+ TRANSFORMER:
+ ENABLED: True
+ CHANNELS: 384
+ BEV: False
+ LARGE: False
+ LIDAR:
+ POINT_PILLAR:
+ ENABLED: False
+ ENCODER:
+ NAME: 'resnet18'
+ BEV:
+ BACKBONE: 'resnet18'
+ LIDAR:
+ ENCODER: 'resnet18'
+ ROUTE:
+ BACKBONE: 'resnet18'
+
+RECEPTIVE_FIELD: 6
+FUTURE_HORIZON: 10
+
+PRETRAINED:
+ PATH: 'path to weights'
+# CML_MODEL: ''
+
+DATASET:
+# DATAROOT: '/disk/vanishing_data/qw825/carla_dataset'
+ DATAROOT: 'path to dataset'
+
diff --git a/muvo/configs/test_mobilevit_2d.yml b/muvo/configs/test_mobilevit_2d.yml
new file mode 100644
index 0000000..5cdf16c
--- /dev/null
+++ b/muvo/configs/test_mobilevit_2d.yml
@@ -0,0 +1,42 @@
+_BASE_: 'muvo.yml'
+
+LOG_DIR: 'tensorboard_logs'
+TAG: 'test_base, 2d, mobilevit, w/o bev, range view, transformer, with voxel'
+
+CML_TYPE: 'application'
+
+PREDICTION:
+ N_SAMPLES: 1
+
+MODEL:
+ TRANSFORMER_TRANSITION:
+ ENABLED: True
+ ROUTE:
+ ENABLED: True
+ TRANSFORMER:
+ ENABLED: True
+ CHANNELS: 384
+ BEV: False
+ LARGE: False
+ LIDAR:
+ POINT_PILLAR:
+ ENABLED: False
+ ENCODER:
+ NAME: 'mobilevitv2_100'
+ BEV:
+ BACKBONE: 'mobilevitv2_100'
+ LIDAR:
+ ENCODER: 'mobilevitv2_100'
+ ROUTE:
+ BACKBONE: 'resnet18'
+
+RECEPTIVE_FIELD: 6
+FUTURE_HORIZON: 10
+
+PRETRAINED:
+ PATH: 'path to weights'
+# CML_MODEL: ''
+
+DATASET:
+# DATAROOT: '/disk/vanishing_data/qw825/carla_dataset'
+ DATAROOT: 'path to dataset'
diff --git a/muvo/data/carlagym_utils.py b/muvo/data/carlagym_utils.py
new file mode 100644
index 0000000..86d7b04
--- /dev/null
+++ b/muvo/data/carlagym_utils.py
@@ -0,0 +1,66 @@
+import numpy as np
+import carla
+import math
+
+EARTH_RADIUS_EQUA = 6378137.0
+
+
+def vec_global_to_ref(target_vec_in_global, ref_rot_in_global):
+ """
+ :param target_vec_in_global: carla.Vector3D in global coordinate (world, actor)
+ :param ref_rot_in_global: carla.Rotation in global coordinate (world, actor)
+ :return: carla.Vector3D in ref coordinate
+ """
+ R = carla_rot_to_mat(ref_rot_in_global)
+ np_vec_in_global = np.array([[target_vec_in_global.x],
+ [target_vec_in_global.y],
+ [target_vec_in_global.z]])
+ np_vec_in_ref = R.T.dot(np_vec_in_global)
+ target_vec_in_ref = carla.Vector3D(x=np_vec_in_ref[0, 0], y=np_vec_in_ref[1, 0], z=np_vec_in_ref[2, 0])
+ return target_vec_in_ref
+
+
+def carla_rot_to_mat(carla_rotation):
+ """
+ Transform rpy in carla.Rotation to rotation matrix in np.array
+
+ :param carla_rotation: carla.Rotation
+ :return: np.array rotation matrix
+ """
+ roll = np.deg2rad(carla_rotation.roll)
+ pitch = np.deg2rad(carla_rotation.pitch)
+ yaw = np.deg2rad(carla_rotation.yaw)
+
+ yaw_matrix = np.array([
+ [np.cos(yaw), -np.sin(yaw), 0],
+ [np.sin(yaw), np.cos(yaw), 0],
+ [0, 0, 1]
+ ])
+ pitch_matrix = np.array([
+ [np.cos(pitch), 0, -np.sin(pitch)],
+ [0, 1, 0],
+ [np.sin(pitch), 0, np.cos(pitch)]
+ ])
+ roll_matrix = np.array([
+ [1, 0, 0],
+ [0, np.cos(roll), np.sin(roll)],
+ [0, -np.sin(roll), np.cos(roll)]
+ ])
+
+ rotation_matrix = yaw_matrix.dot(pitch_matrix).dot(roll_matrix)
+ return rotation_matrix
+
+
+def gps_to_location(gps):
+ lat, lon, z = gps
+ lat = float(lat)
+ lon = float(lon)
+ z = float(z)
+
+ location = carla.Location(z=z)
+
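+    # Inverse of the GNSS sensor's Mercator-style projection: longitude scales linearly along
+    # the equator, latitude goes through the Mercator formula, and y is negated to match
+    # CARLA's left-handed coordinate frame.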
+ location.x = lon / 180.0 * (math.pi * EARTH_RADIUS_EQUA)
+
+ location.y = -1.0 * math.log(math.tan((lat + 90.0) * math.pi / 360.0)) * EARTH_RADIUS_EQUA
+
+ return location
diff --git a/muvo/data/dataset.py b/muvo/data/dataset.py
new file mode 100644
index 0000000..34911b5
--- /dev/null
+++ b/muvo/data/dataset.py
@@ -0,0 +1,385 @@
+import os
+from glob import glob
+from PIL import Image
+
+import numpy as np
+import pandas as pd
+import lightning.pytorch as pl
+import scipy.ndimage
+import torch
+from torch.utils.data import Dataset, DataLoader
+
+from constants import CARLA_FPS, EGO_VEHICLE_DIMENSION, LABEL_MAP, VOXEL_LABEL, VOXEL_LABEL_CARLA
+from muvo.data.dataset_utils import integer_to_binary, calculate_birdview_labels, calculate_instance_mask
+from muvo.utils.geometry_utils import get_out_of_view_mask, calculate_geometry, lidar_to_histogram_features
+from muvo.utils.geometry_utils import PointCloud
+from data.data_preprocessing import convert_coor_lidar
+
+
+class DataModule(pl.LightningDataModule):
+ def __init__(self, cfg, dataset_root=None):
+ super().__init__()
+ self.cfg = cfg
+ self.batch_size = self.cfg.BATCHSIZE
+ self.sequence_length = self.cfg.RECEPTIVE_FIELD + self.cfg.FUTURE_HORIZON
+
+ self.dataset_root = dataset_root if dataset_root else self.cfg.DATASET.DATAROOT
+
+        # Populated in self.setup()
+ self.train_dataset, self.val_dataset_0, self.val_dataset_1, self.val_dataset_2 = None, None, None, None
+ self.test_dataset = None
+
+ def setup(self, stage=None):
+ self.train_dataset = CarlaDataset(
+ self.cfg, mode='train', sequence_length=self.sequence_length, dataset_root=self.dataset_root
+ )
+ # self.val_dataset = CarlaDataset(
+ # self.cfg, mode='train', sequence_length=self.sequence_length, dataset_root=self.dataset_root
+ # )
+        # multiple validation datasets
+ self.val_dataset_0 = CarlaDataset(
+ self.cfg, mode='val0', sequence_length=self.sequence_length, dataset_root=self.dataset_root
+ )
+ self.val_dataset_1 = CarlaDataset(
+ self.cfg, mode='val1', sequence_length=self.sequence_length, dataset_root=self.dataset_root
+ )
+ self.val_dataset_2 = CarlaDataset(
+ self.cfg, mode='val2', sequence_length=self.sequence_length, dataset_root=self.dataset_root
+ )
+ self.test_dataset = CarlaDataset(
+ self.cfg, mode='train', sequence_length=self.sequence_length, dataset_root=self.dataset_root
+ )
+
+ print(f'{len(self.train_dataset)} data points in {self.train_dataset.dataset_path}')
+ # print(f'{len(self.val_dataset)} data points in {self.val_dataset.dataset_path}')
+ print(f'{len(self.val_dataset_0)} data points in {self.val_dataset_0.dataset_path}')
+ print(f'{len(self.val_dataset_1)} data points in {self.val_dataset_1.dataset_path}')
+ print(f'{len(self.val_dataset_2)} data points in {self.val_dataset_2.dataset_path}')
+ print(f'{len(self.test_dataset)} data points in prediction')
+
+ # self.train_sampler = range(10, len(self.train_dataset))
+ self.train_sampler = None
+ self.val_sampler_0 = range(0, len(self.val_dataset_0), 50)
+ self.val_sampler_1 = range(1500, len(self.val_dataset_1), 50)
+ self.val_sampler_2 = range(3000, len(self.val_dataset_2), 50)
+ # self.val_sampler = None
+ self.test_sampler_0 = range(0, len(self.test_dataset), 900)
+ self.test_sampler_1 = range(1500, len(self.test_dataset), 600)
+ self.test_sampler_2 = range(0, len(self.test_dataset), 150)
+
+ def train_dataloader(self):
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.batch_size,
+ shuffle=self.train_sampler is None,
+ num_workers=self.cfg.N_WORKERS,
+ pin_memory=True,
+ drop_last=True,
+ sampler=self.train_sampler,
+ )
+
+ def val_dataloader(self):
+ return [
+ DataLoader(
+ self.val_dataset_0,
+ batch_size=self.batch_size,
+ shuffle=False,
+ num_workers=self.cfg.N_WORKERS,
+ pin_memory=True,
+ drop_last=True,
+ sampler=self.val_sampler_0,
+ ),
+ DataLoader(
+ self.val_dataset_1,
+ batch_size=self.batch_size,
+ shuffle=False,
+ num_workers=self.cfg.N_WORKERS,
+ pin_memory=True,
+ drop_last=True,
+ sampler=self.val_sampler_1,
+ ),
+ DataLoader(
+ self.val_dataset_2,
+ batch_size=self.batch_size,
+ shuffle=False,
+ num_workers=self.cfg.N_WORKERS,
+ pin_memory=True,
+ drop_last=True,
+ sampler=self.val_sampler_2,
+ ),
+ ]
+
+ def test_dataloader(self):
+ return [
+ DataLoader(
+ self.test_dataset,
+ batch_size=self.batch_size,
+ shuffle=False,
+ num_workers=self.cfg.N_WORKERS,
+ pin_memory=True,
+ drop_last=True,
+ sampler=self.test_sampler_0,
+ ),
+ DataLoader(
+ self.test_dataset,
+ batch_size=self.batch_size,
+ shuffle=False,
+ num_workers=self.cfg.N_WORKERS,
+ pin_memory=True,
+ drop_last=True,
+ sampler=self.test_sampler_1,
+ ),
+ DataLoader(
+ self.test_dataset,
+ batch_size=self.batch_size,
+ shuffle=False,
+ num_workers=self.cfg.N_WORKERS,
+ pin_memory=True,
+ drop_last=True,
+ sampler=self.test_sampler_2,
+ ),
+ ]
+
+
+class CarlaDataset(Dataset):
+ def __init__(self, cfg, mode='train', sequence_length=1, dataset_root=None, towns_filter='*', runs_filter='*'):
+ self.cfg = cfg
+ self.mode = mode
+ self.sequence_length = sequence_length
+
+ self.dataset_path = os.path.join(dataset_root, self.cfg.DATASET.VERSION, mode)
+ self.intrinsics, self.extrinsics = calculate_geometry_from_config(self.cfg)
+ self.bev_out_of_view_mask = get_out_of_view_mask(self.cfg)
+ self.pcd = PointCloud(
+ self.cfg.POINTS.CHANNELS,
+ self.cfg.POINTS.HORIZON_RESOLUTION,
+ *self.cfg.POINTS.FOV,
+ self.cfg.POINTS.LIDAR_POSITION,
+ )
+
+ # Iterate over all runs in the data folder
+
+ self.data = dict()
+
+ towns = sorted(glob(os.path.join(self.dataset_path, towns_filter)))
+ for town_path in towns:
+ town = os.path.basename(town_path)
+
+ runs = sorted(glob(os.path.join(self.dataset_path, town, runs_filter)))
+ for run_path in runs:
+ run = os.path.basename(run_path)
+ pd_dataframe_path = os.path.join(run_path, 'pd_dataframe.pkl')
+
+ if os.path.isfile(pd_dataframe_path):
+ self.data[f'{town}/{run}'] = pd.read_pickle(pd_dataframe_path)
+
+ self.data_pointers = self.get_data_pointers()
+
+ def get_data_pointers(self):
+ data_pointers = []
+
+ n_filtered_run = 0
+ for run, data_run in self.data.items():
+ # Calculate normalised reward of the run
+ run_length = len(data_run['reward'])
+ cumulative_reward = data_run['reward'].sum()
+ normalised_reward = cumulative_reward / run_length
+ if normalised_reward < self.cfg.DATASET.FILTER_NORM_REWARD:
+ n_filtered_run += 1
+ continue
+
+ stride = int(self.cfg.DATASET.STRIDE_SEC * CARLA_FPS)
+ # Loop across all elements in the dataset, and make all elements in a sequence belong to the same run
+ start_index = int(CARLA_FPS * self.cfg.DATASET.FILTER_BEGINNING_OF_RUN_SEC)
+ total_length = len(data_run) - stride * self.sequence_length
+ for i in range(start_index, total_length):
+ frame_indices = range(i, i + stride * self.sequence_length, stride)
+ data_pointers.append((run, list(frame_indices)))
+
+ print(f'Filtered {n_filtered_run} runs in {self.dataset_path}')
+
+ if self.cfg.EVAL.DATASET_REDUCTION:
+ import random
+ random.seed(0)
+ final_size = int(len(data_pointers) / self.cfg.EVAL.DATASET_REDUCTION_FACTOR)
+ data_pointers = random.sample(data_pointers, final_size)
+
+ return data_pointers
+
+ def __len__(self):
+ return len(self.data_pointers)
+
+ def __getitem__(self, i):
+ batch = {}
+
+ run_id, indices = self.data_pointers[i]
+ for t in indices:
+ try:
+ single_element_t = self.load_single_element_time_t(run_id, t)
+            except Exception:
+                print(f'{run_id}, frame {t} is invalid, skipping')
+ continue
+
+ for k, v in single_element_t.items():
+ batch[k] = batch.get(k, []) + [v]
+
+ for k, v in batch.items():
+ batch[k] = torch.from_numpy(np.stack(v))
+
+ return batch
+
+ def load_single_element_time_t(self, run_id, t):
+ data_row = self.data[run_id].iloc[t]
+ single_element_t = {}
+
+ # Load image
+ image = Image.open(
+ os.path.join(self.dataset_path, run_id, data_row['image_path'])
+ )
+ image = np.asarray(image).transpose((2, 0, 1))
+ single_element_t['image'] = image
+
+ # Load route map
+ route_map = Image.open(
+ os.path.join(self.dataset_path, run_id, data_row['routemap_path'])
+ )
+ route_map = np.asarray(route_map)[None]
+ # Make the grayscale image an RGB image
+ _, h, w = route_map.shape
+ route_map = np.broadcast_to(route_map, (3, h, w)).copy()
+ single_element_t['route_map'] = route_map
+
+ # Load bird's-eye view segmentation label
+ birdview = np.asarray(Image.open(
+ os.path.join(self.dataset_path, run_id, data_row['birdview_path'])
+ ))
+ h, w = birdview.shape
+ n_classes = data_row['n_classes']
+ birdview = integer_to_binary(birdview.reshape(-1), n_classes).reshape(h, w, n_classes)
+ birdview = birdview.transpose((2, 0, 1))
+ single_element_t['birdview'] = birdview
+ birdview_label = calculate_birdview_labels(torch.from_numpy(birdview), n_classes).numpy()
+ birdview_label = birdview_label[None]
+ single_element_t['birdview_label'] = birdview_label
+
+ # TODO: get person and car instance ids with json
+        instance_mask = birdview[3].astype(bool) | birdview[4].astype(bool)
+ instance_label, _ = scipy.ndimage.label(instance_mask[None].astype(np.int64))
+ single_element_t['instance_label'] = instance_label
+
+ # # Load lidar points clouds
+ # pcd = np.load(
+ # os.path.join(self.dataset_path, run_id, data_row['points_path']),
+ # allow_pickle=True).item() # n x 4, (x, y, z, intensity)
+ # single_element_t['points_intensity'] = np.concatenate([pcd['points_xyz'], pcd['intensity'][:, None]], axis=-1)
+ pcd_semantic = np.load(
+ os.path.join(self.dataset_path, run_id, data_row['points_semantic_path']),
+ allow_pickle=True).item()
+ points = convert_coor_lidar(pcd_semantic['points_xyz'], self.cfg.POINTS.LIDAR_POSITION)
+
+ # remap labels
+ remap = np.full((max(LABEL_MAP.keys()) + 1), max(LABEL_MAP.values()), dtype=np.uint8)
+ remap[list(LABEL_MAP.keys())] = list(LABEL_MAP.values())
+ semantics = remap[pcd_semantic['ObjTag']]
+
+ # mask ego-vehicle
+ x, y, z = EGO_VEHICLE_DIMENSION
+ ego_box = np.array([[-x / 2, -y / 2, 0], [x / 2, y / 2, z]])
+ ego_idx = ((ego_box[0] < points) & (points < ego_box[1])).all(axis=1)
+ semantics = semantics[~ego_idx]
+ points = points[~ego_idx]
+
+ # histogram of lidar
+ # single_element_t['points'] = points
+ # single_element_t['points_label'] = pcd_semantic['ObjTag'].astype('uint8')
+ # single_element_t['points_histogram_xy'], \
+ # single_element_t['points_histogram_xz'], \
+ # single_element_t['points_histogram_yz'] = lidar_to_histogram_features(points, self.cfg)
+
+ # range-view of lidar point cloud
+ range_view_pcd_depth, range_view_pcd_xyz, range_view_pcd_sem = self.pcd.do_range_projection(points, semantics)
+ if self.cfg.MODEL.LIDAR.ENABLED:
+ single_element_t['range_view_pcd_xyzd'] = np.concatenate(
+ [range_view_pcd_xyz, range_view_pcd_depth[..., None]], axis=-1).transpose((2, 0, 1)) # x y z d
+ if self.cfg.LIDAR_SEG.ENABLED:
+ single_element_t['range_view_pcd_seg'] = range_view_pcd_sem[None].astype(int)
+
+ # data type for point-pillar
+ if self.cfg.MODEL.LIDAR.POINT_PILLAR.ENABLED:
+ max_num_points = int(self.cfg.POINTS.N_PER_SECOND / CARLA_FPS)
+ fixed_points = np.empty((max_num_points, 3), dtype=np.float32)
+ num_points = min(points.shape[0], max_num_points)
+ fixed_points[:num_points] = points[:num_points]
+ single_element_t['points_raw'] = fixed_points
+ single_element_t['num_points'] = num_points
+
+ # Load voxels, saved as voxel coordinates.
+ if self.cfg.VOXEL_SEG.ENABLED:
+ voxel_data = np.load(
+ os.path.join(self.dataset_path, run_id, data_row['voxel_path'])
+ )
+ voxel_points = voxel_data[:, :-1]
+ voxel_semantics = voxel_data[:, -1]
+ voxel_semantics[voxel_semantics == 255] = 0
+ voxel_semantics = remap[voxel_semantics]
+ voxels = np.zeros(self.cfg.VOXEL.SIZE, dtype=np.uint8)
+ voxels[voxel_points[:, 0], voxel_points[:, 1], voxel_points[:, 2]] = voxel_semantics
+ single_element_t['voxel'] = voxels[None]
+
+ # load depth and semantic image
+ depth_semantic = Image.open(
+ os.path.join(self.dataset_path, run_id, data_row['depth_semantic_path'])
+ )
+ depth_semantic = np.asarray(depth_semantic)
+ semantic_image = depth_semantic[..., -1]
+ if self.cfg.LOSSES.RGB_INSTANCE:
+ single_element_t['image_instance_mask'] = calculate_instance_mask(
+ semantic_image[None],
+ vehicle_idx=list(VOXEL_LABEL_CARLA.keys())[list(VOXEL_LABEL_CARLA.values()).index('Vehicle')],
+ pedestrian_idx=list(VOXEL_LABEL_CARLA.keys())[list(VOXEL_LABEL_CARLA.values()).index('Pedestrian')],
+ )
+
+ # load semantic image
+ if self.cfg.SEMANTIC_IMAGE.ENABLED:
+ single_element_t['semantic_image'] = remap[semantic_image][None].astype(int)
+ # load depth
+ if self.cfg.DEPTH.ENABLED:
+ depth_color = depth_semantic[..., :-1].transpose((2, 0, 1)).astype(float)
+ single_element_t['depth_color'] = depth_color / 255.0
+            # In CARLA, depth is encoded across the three colour channels as a 24-bit value
+            # normalised to [0, 1]; pixels at (or near) the far plane are marked invalid with -1.
+            depth = (256 ** 2 * depth_color[0] + 256 * depth_color[1] + depth_color[2]) / (256 ** 3 - 1)
+            depth[depth > 0.999] = -1
+ single_element_t['depth'] = depth[None]
+
+ # Load action and reward
+ throttle, steering, brake = data_row['action']
+ throttle_brake = throttle if throttle > 0 else -brake
+
+ single_element_t['steering'] = np.array([steering], dtype=np.float32)
+ single_element_t['throttle_brake'] = np.array([throttle_brake], dtype=np.float32)
+ single_element_t['speed'] = data_row['speed']
+
+ single_element_t['reward'] = np.array([data_row['reward']], dtype=np.float32).clip(-1.0, 1.0)
+ single_element_t['value_function'] = np.array([data_row['value']], dtype=np.float32)
+
+ # Geometry
+ single_element_t['intrinsics'] = self.intrinsics.copy()
+ single_element_t['extrinsics'] = self.extrinsics.copy()
+
+ return single_element_t
+
+
+def calculate_geometry_from_config(cfg):
+ """ Intrinsics and extrinsics for a single camera.
+ See https://github.com/bradyz/carla_utils_fork/blob/dynamic-scene/carla_utils/leaderboard/camera.py
+ and https://github.com/bradyz/carla_utils_fork/blob/dynamic-scene/carla_utils/recording/sensors/camera.py
+ """
+ # Intrinsics
+ fov = cfg.IMAGE.FOV
+ h, w = cfg.IMAGE.SIZE
+
+ # Extrinsics
+ forward, right, up = cfg.IMAGE.CAMERA_POSITION
+ pitch, yaw, roll = cfg.IMAGE.CAMERA_ROTATION
+
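+    # calculate_geometry (muvo/utils/geometry_utils.py) presumably derives the pinhole
+    # intrinsics from the FOV (f = w / (2 * tan(fov / 2))) and the extrinsics from the pose above.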
+ return calculate_geometry(fov, h, w, forward, right, up, pitch, yaw, roll)
diff --git a/muvo/data/dataset_utils.py b/muvo/data/dataset_utils.py
new file mode 100644
index 0000000..0f901c7
--- /dev/null
+++ b/muvo/data/dataset_utils.py
@@ -0,0 +1,128 @@
+import torch
+import numpy as np
+
+import carla
+# import carla_gym.utils.transforms as trans_utils
+# import carla_gym.core.task_actor.common.navigation.route_manipulation as gps_util
+from muvo.data.carlagym_utils import gps_to_location, vec_global_to_ref
+
+
+def binary_to_integer(binary_array, n_bits):
+ """
+ Parameters
+ ----------
+ binary_array: shape (n, n_bits)
+
+ Returns
+ -------
+ integer_array: shape (n,) np.int32
+ """
+ return (binary_array @ 2 ** np.arange(n_bits, dtype=binary_array.dtype)).astype(np.int32)
+
+
+def integer_to_binary(integer_array, n_bits):
+ """
+ Parameters
+ ----------
+ integer_array: np.ndarray (n,)
+ n_bits: int
+
+ Returns
+ -------
+ binary_array: np.ndarray (n, n_bits)
+
+ """
+ return (((integer_array[:, None] & (1 << np.arange(n_bits)))) > 0).astype(np.float32)
+
+
+def calculate_birdview_labels(birdview, n_classes, has_time_dimension=False):
+ """
+ Parameters
+ ----------
+ birdview: torch.Tensor (C, H, W)
+ n_classes: int
+ number of total classes
+ has_time_dimension: bool
+
+ Returns
+ -------
+ birdview_label: (H, W)
+ """
+ # When a pixel contains two labels, argmax will output the first one that is encountered.
+ # By reversing the order, we prioritise traffic lights over road.
+ dim = 0
+ if has_time_dimension:
+ dim = 1
+ birdview_label = torch.argmax(birdview.flip(dims=[dim]), dim=dim)
+ # We then re-normalise the classes in the normal order.
+ birdview_label = (n_classes - 1) - birdview_label
+ return birdview_label
+
+
+def preprocess_measurements(route_command, ego_gps, target_gps, imu):
+ # preprocess measurements
+ route_command = route_command.copy()
+ route_command[route_command < 0] = 4
+ route_command -= 1
+ route_command = np.array(route_command[0], dtype=np.int64)
+
+ loc_in_ev = preprocess_gps(ego_gps, target_gps, imu)
+ gps_vector = np.array([loc_in_ev.x, loc_in_ev.y], dtype=np.float32)
+ return route_command, gps_vector
+
+
+def preprocess_gps(ego_gps, target_gps, imu):
+    # Guard against NaN compass readings from the IMU
+ compass = 0.0 if np.isnan(imu[-1]) else imu[-1]
+ target_vec_in_global = gps_to_location(target_gps) - gps_to_location(ego_gps)
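+    # The compass measures heading from north; subtracting 90 deg converts it to CARLA's yaw
+    # convention (0 deg along +x), so the target vector can be rotated into the ego frame.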
+ ref_rot_in_global = carla.Rotation(yaw=np.rad2deg(compass) - 90.0)
+ loc_in_ev = vec_global_to_ref(target_vec_in_global, ref_rot_in_global)
+ return loc_in_ev
+
+
+def preprocess_birdview_and_routemap(birdview):
+ ROUTE_MAP_INDEX = 1
+ # road, lane markings, vehicles, pedestrians
+ relevant_indices = [0, 2, 6, 10]
+
+ if isinstance(birdview, np.ndarray):
+ birdview = torch.from_numpy(birdview)
+ has_time_dimension = True
+ if len(birdview.shape) == 3:
+ birdview = birdview.unsqueeze(0)
+ has_time_dimension = False
+ # Birdview has values in {0, 255}. Convert to {0, 1}
+
+ # lights and stops
+ light_and_stop_channel = birdview[:, -1:]
+ green_light = (light_and_stop_channel == 80).float()
+ yellow_light = (light_and_stop_channel == 170).float()
+ red_light_and_stop = (light_and_stop_channel == 255).float()
+
+ remaining = birdview[:, relevant_indices]
+ remaining[remaining > 0] = 1
+ remaining = remaining.float()
+
+ # Traffic light and stop.
+ processed_birdview = torch.cat([remaining, green_light, yellow_light, red_light_and_stop], dim=1)
+ # background
+ tmp = processed_birdview.sum(dim=1, keepdim=True)
+ background = (tmp == 0).float()
+
+ processed_birdview = torch.cat([background, processed_birdview], dim=1)
+
+ # Route map
+ route_map = birdview[:, ROUTE_MAP_INDEX]
+ route_map[route_map > 0] = 255
+ route_map = route_map.to(torch.uint8)
+
+ if not has_time_dimension:
+ processed_birdview = processed_birdview[0]
+ route_map = route_map[0]
+ return processed_birdview, route_map
+
+
+def calculate_instance_mask(semantics, vehicle_idx, pedestrian_idx):
+ mask = np.zeros_like(semantics)
+ mask[(semantics == vehicle_idx) | (semantics == pedestrian_idx)] = 1
+ return mask.astype(bool)
diff --git a/muvo/layers/layers.py b/muvo/layers/layers.py
new file mode 100644
index 0000000..ee7e513
--- /dev/null
+++ b/muvo/layers/layers.py
@@ -0,0 +1,357 @@
+from collections import OrderedDict
+from functools import partial
+
+import torch
+import torch.nn as nn
+from timm.models.resnet import downsample_conv
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(
+ self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
+ reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,):
+ super(BasicBlock, self).__init__()
+
+ assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
+ assert base_width == 64, 'BasicBlock does not support changing base width'
+ first_planes = planes // reduce_first
+ outplanes = planes * self.expansion
+ first_dilation = first_dilation or dilation
+
+ self.conv1 = nn.Conv2d(
+ inplanes, first_planes, kernel_size=3, stride=stride, padding=first_dilation,
+ dilation=first_dilation, bias=False)
+ self.bn1 = norm_layer(first_planes)
+ self.act1 = act_layer(inplace=True)
+
+ self.conv2 = nn.Conv2d(
+ first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
+ self.bn2 = norm_layer(outplanes)
+
+ self.act2 = act_layer(inplace=True)
+ self.downsample = downsample
+ if self.downsample is not None:
+ self.downsample = downsample_conv(
+ in_channels=inplanes,
+ out_channels=outplanes,
+ kernel_size=1,
+ stride=2,
+ dilation=1,
+ first_dilation=first_dilation,
+ norm_layer=nn.BatchNorm2d,
+ )
+ self.stride = stride
+ self.dilation = dilation
+
+ def zero_init_last(self):
+ nn.init.zeros_(self.bn2.weight)
+
+ def forward(self, x):
+ shortcut = x
+
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.act1(x)
+
+ x = self.conv2(x)
+ x = self.bn2(x)
+
+ if self.downsample is not None:
+ shortcut = self.downsample(shortcut)
+ x += shortcut
+ x = self.act2(x)
+
+ return x
+
+
+class RestrictionActivation(nn.Module):
+ """ Constrain output to be between min_value and max_value."""
+
+ def __init__(self, min_value=0, max_value=1):
+ super().__init__()
+ self.scale = (max_value - min_value) / 2
+ self.offset = min_value
+
+ def forward(self, x):
+ x = torch.tanh(x) + 1 # in range [0, 2]
+ x = self.scale * x + self.offset # in range [min_value, max_value]
+ return x
+
+
+class ConvBlock(nn.Module):
+ """2D convolution followed by
+ - an optional normalisation (batch norm or instance norm)
+ - an optional activation (ReLU, LeakyReLU, or tanh)
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels=None,
+ kernel_size=3,
+ stride=1,
+ norm='bn',
+ activation='relu',
+ bias=False,
+ transpose=False,
+ ):
+ super().__init__()
+ out_channels = out_channels or in_channels
+ padding = int((kernel_size - 1) / 2)
+ self.conv = nn.Conv2d if not transpose else partial(nn.ConvTranspose2d, output_padding=1)
+ self.conv = self.conv(in_channels, out_channels, kernel_size, stride, padding=padding, bias=bias)
+
+ if norm == 'bn':
+ self.norm = nn.BatchNorm2d(out_channels)
+ elif norm == 'in':
+ self.norm = nn.InstanceNorm2d(out_channels)
+ elif norm == 'none':
+ self.norm = None
+ else:
+ raise ValueError('Invalid norm {}'.format(norm))
+
+ if activation == 'relu':
+ self.activation = nn.ReLU(inplace=True)
+ elif activation == 'lrelu':
+ self.activation = nn.LeakyReLU(0.1, inplace=True)
+ elif activation == 'elu':
+ self.activation = nn.ELU(inplace=True)
+ elif activation == 'tanh':
+            self.activation = nn.Tanh()  # nn.Tanh does not take an inplace argument
+ elif activation == 'none':
+ self.activation = None
+ else:
+ raise ValueError('Invalid activation {}'.format(activation))
+
+ def forward(self, x):
+ x = self.conv(x)
+
+ if self.norm:
+ x = self.norm(x)
+ if self.activation:
+ x = self.activation(x)
+ return x
+
+
+class Bottleneck(nn.Module):
+ """
+ Defines a bottleneck module with a residual connection
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels=None,
+ kernel_size=3,
+ dilation=1,
+ groups=1,
+ upsample=False,
+ downsample=False,
+ dropout=0.0,
+ ):
+ super().__init__()
+ self._downsample = downsample
+ bottleneck_channels = int(in_channels / 2)
+ out_channels = out_channels or in_channels
+ padding_size = ((kernel_size - 1) * dilation + 1) // 2
+
+ # Define the main conv operation
+ assert dilation == 1
+ if upsample:
+ assert not downsample, 'downsample and upsample not possible simultaneously.'
+ bottleneck_conv = nn.ConvTranspose2d(
+ bottleneck_channels,
+ bottleneck_channels,
+ kernel_size=kernel_size,
+ bias=False,
+ dilation=1,
+ stride=2,
+ output_padding=padding_size,
+ padding=padding_size,
+ groups=groups,
+ )
+ elif downsample:
+ bottleneck_conv = nn.Conv2d(
+ bottleneck_channels,
+ bottleneck_channels,
+ kernel_size=kernel_size,
+ bias=False,
+ dilation=dilation,
+ stride=2,
+ padding=padding_size,
+ groups=groups,
+ )
+ else:
+ bottleneck_conv = nn.Conv2d(
+ bottleneck_channels,
+ bottleneck_channels,
+ kernel_size=kernel_size,
+ bias=False,
+ dilation=dilation,
+ padding=padding_size,
+ groups=groups,
+ )
+
+ self.layers = nn.Sequential(
+ OrderedDict(
+ [
+ # First projection with 1x1 kernel
+ ('conv_down_project', nn.Conv2d(in_channels, bottleneck_channels, kernel_size=1, bias=False)),
+ ('abn_down_project', nn.Sequential(nn.BatchNorm2d(bottleneck_channels),
+ nn.ReLU(inplace=True))),
+ # Second conv block
+ ('conv', bottleneck_conv),
+ ('abn', nn.Sequential(nn.BatchNorm2d(bottleneck_channels), nn.ReLU(inplace=True))),
+ # Final projection with 1x1 kernel
+ ('conv_up_project', nn.Conv2d(bottleneck_channels, out_channels, kernel_size=1, bias=False)),
+ ('abn_up_project', nn.Sequential(nn.BatchNorm2d(out_channels),
+ nn.ReLU(inplace=True))),
+ # Regulariser
+ ('dropout', nn.Dropout2d(p=dropout)),
+ ]
+ )
+ )
+
+ if out_channels == in_channels and not downsample and not upsample:
+ self.projection = None
+ else:
+ projection = OrderedDict()
+ if upsample:
+ projection.update({'upsample_skip_proj': Interpolate(scale_factor=2)})
+ elif downsample:
+ projection.update({'upsample_skip_proj': nn.MaxPool2d(kernel_size=2, stride=2)})
+ projection.update(
+ {
+ 'conv_skip_proj': nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
+ 'bn_skip_proj': nn.BatchNorm2d(out_channels),
+ }
+ )
+ self.projection = nn.Sequential(projection)
+
+ # pylint: disable=arguments-differ
+ def forward(self, *args):
+ (x,) = args
+ x_residual = self.layers(x)
+ if self.projection is not None:
+ if self._downsample:
+ # pad h/w dimensions if they are odd to prevent shape mismatch with residual layer
+ x = nn.functional.pad(x, (0, x.shape[-1] % 2, 0, x.shape[-2] % 2), value=0)
+ return x_residual + self.projection(x)
+ return x_residual + x
+
+
+class Interpolate(nn.Module):
+ def __init__(self, scale_factor: int = 2):
+ super().__init__()
+ self._interpolate = nn.functional.interpolate
+ self._scale_factor = scale_factor
+
+ # pylint: disable=arguments-differ
+ def forward(self, x):
+ return self._interpolate(x, scale_factor=self._scale_factor, mode='bilinear', align_corners=False)
+
+
+class Upsampling(nn.Module):
+ def __init__(self, in_channels, out_channels, scale_factor=2):
+ super().__init__()
+ self.upsample_layer = nn.Sequential(
+ nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False),
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(inplace=True),
+ )
+
+ def forward(self, x):
+ x = self.upsample_layer(x)
+ return x
+
+
+class UpsamplingAdd(nn.Module):
+ def __init__(self, in_channels, action_channels, out_channels, scale_factor=2):
+ super().__init__()
+ self.upsample_layer = nn.Sequential(
+ nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False),
+ nn.Conv2d(in_channels + action_channels, out_channels, kernel_size=1, padding=0, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(inplace=True),
+ )
+
+ def forward(self, x, x_skip, action):
+ # Spatially broadcast
+ b, _, h, w = x.shape
+ action = action.view(b, -1, 1, 1).expand(b, -1, h, w)
+ x = torch.cat([x, action], dim=1)
+ x = self.upsample_layer(x)
+ return x + x_skip
+
+
+class UpsamplingConcat(nn.Module):
+ def __init__(self, in_channels, out_channels, scale_factor=2):
+ super().__init__()
+ self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)
+
+ self.conv = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(inplace=True),
+ )
+
+ def forward(self, x_to_upsample, x):
+ x_to_upsample = self.upsample(x_to_upsample)
+ x_to_upsample = torch.cat([x, x_to_upsample], dim=1)
+ return self.conv(x_to_upsample)
+
+
+class ActivatedNormLinear(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.module = nn.Sequential(nn.Linear(in_channels, out_channels),
+ nn.BatchNorm1d(out_channels),
+ nn.ReLU(inplace=True))
+
+ def forward(self, x):
+ return self.module(x)
+
+
+class Flatten(nn.Module):
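+    """Global average pooling over the spatial dimensions (despite the name)."""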
+ def forward(self, x):
+ return x.mean(dim=(-1, -2))
+
+
+class VoxelsSumming(torch.autograd.Function):
+ """Adapted from https://github.com/nv-tlabs/lift-splat-shoot/blob/master/src/tools.py#L193"""
+ @staticmethod
+ def forward(ctx, x, geometry, ranks):
+ """The features `x` and `geometry` are ranked by voxel positions."""
+ # Cumulative sum of all features.
+ x = x.cumsum(0)
+
+ # Indicates the change of voxel.
+ mask = torch.ones(x.shape[0], device=x.device, dtype=torch.bool)
+ mask[:-1] = ranks[1:] != ranks[:-1]
+
+ x, geometry = x[mask], geometry[mask]
+ # Calculate sum of features within a voxel.
+ x = torch.cat((x[:1], x[1:] - x[:-1]))
+
+ ctx.save_for_backward(mask)
+ ctx.mark_non_differentiable(geometry)
+
+ return x, geometry
+
+ @staticmethod
+ def backward(ctx, grad_x, grad_geometry):
+ (mask,) = ctx.saved_tensors
+ # Since the operation is summing, we simply need to send gradient
+ # to all elements that were part of the summation process.
+ indices = torch.cumsum(mask, 0)
+ indices[mask] -= 1
+
+ output_grad = grad_x[indices]
+
+ return output_grad, None, None
diff --git a/muvo/losses.py b/muvo/losses.py
new file mode 100644
index 0000000..8c24ca3
--- /dev/null
+++ b/muvo/losses.py
@@ -0,0 +1,375 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.cuda.amp import autocast
+
+from constants import SEMANTIC_SEG_WEIGHTS, VOXEL_SEG_WEIGHTS
+
+
+class SegmentationLoss(nn.Module):
+ def __init__(self, use_top_k=False, top_k_ratio=1.0, use_weights=False, poly_one=False, poly_one_coefficient=0.0,
+ is_bev=True):
+ super().__init__()
+ self.use_top_k = use_top_k
+ self.top_k_ratio = top_k_ratio
+ self.use_weights = use_weights
+ self.poly_one = poly_one
+ self.poly_one_coefficient = poly_one_coefficient
+
+ if self.use_weights:
+ self.weights = SEMANTIC_SEG_WEIGHTS if is_bev else VOXEL_SEG_WEIGHTS
+
+ def forward(self, prediction, target):
+ b, s, c, h, w = prediction.shape
+
+ prediction = prediction.view(b * s, c, h, w)
+ target = target.view(b * s, h, w)
+
+ if self.use_weights:
+ weights = torch.tensor(self.weights, dtype=prediction.dtype, device=prediction.device)
+ else:
+ weights = None
+ loss = F.cross_entropy(
+ prediction,
+ target,
+ reduction='none',
+ weight=weights,
+ )
+
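+        # Optional Poly-1 term (PolyLoss): add epsilon * (1 - p_t) on top of the cross-entropy,
+        # where p_t = exp(-CE) is the probability of the target class.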
+ if self.poly_one:
+ prob = torch.exp(-loss)
+ loss_poly_one = self.poly_one_coefficient * (1 - prob)
+ loss = loss + loss_poly_one
+
+ loss = loss.view(b, s, -1)
+ if self.use_top_k:
+ # Penalises the top-k hardest pixels
+ k = int(self.top_k_ratio * loss.shape[2])
+ loss = loss.topk(k, dim=-1)[0]
+
+ return torch.mean(loss)
+
+
+class RegressionLoss(nn.Module):
+ def __init__(self, norm, channel_dim=-1):
+ super().__init__()
+ self.norm = norm
+ self.channel_dim = channel_dim
+
+ if norm == 1:
+ self.loss_fn = F.l1_loss
+ elif norm == 2:
+ self.loss_fn = F.mse_loss
+ else:
+ raise ValueError(f'Expected norm 1 or 2, but got norm={norm}')
+
+ def forward(self, prediction, target):
+ loss = self.loss_fn(prediction, target, reduction='none')
+
+ # Sum channel dimension
+ loss = torch.sum(loss, dim=self.channel_dim, keepdims=True)
+ return loss.mean()
+
+
+class SpatialRegressionLoss(nn.Module):
+ def __init__(self, norm, ignore_index=255):
+ super(SpatialRegressionLoss, self).__init__()
+ self.norm = norm
+ self.ignore_index = ignore_index
+
+ if norm == 1:
+ self.loss_fn = F.l1_loss
+ elif norm == 2:
+ self.loss_fn = F.mse_loss
+ else:
+ raise ValueError(f'Expected norm 1 or 2, but got norm={norm}')
+
+ def forward(self, prediction, target, instance_mask=None):
+ assert len(prediction.shape) == 5, 'Must be a 5D tensor'
+ # ignore_index is the same across all channels
+ mask = instance_mask if instance_mask is not None else target[:, :, :1] != self.ignore_index
+ if mask.sum() == 0:
+ return prediction.new_zeros(1)[0].float()
+
+ loss = self.loss_fn(prediction, target, reduction='none')
+
+ # Sum channel dimension
+ loss = torch.sum(loss, dim=-3, keepdims=True)
+
+ return loss[mask].mean()
+
+
+class ProbabilisticLoss(nn.Module):
+ """ Given a prior distribution and a posterior distribution, this module computes KL(posterior, prior)"""
+
+ def __init__(self, remove_first_timestamp=True):
+ super().__init__()
+ self.remove_first_timestamp = remove_first_timestamp
+
+ def forward(self, prior_mu, prior_sigma, posterior_mu, posterior_sigma):
+ posterior_var = posterior_sigma[:, 1:] ** 2
+ prior_var = prior_sigma[:, 1:] ** 2
+
+ posterior_log_sigma = torch.log(posterior_sigma[:, 1:])
+ prior_log_sigma = torch.log(prior_sigma[:, 1:])
+
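+        # Closed-form KL between diagonal Gaussians:
+        # KL(q || p) = log(sigma_p / sigma_q) - 1/2 + (sigma_q^2 + (mu_q - mu_p)^2) / (2 * sigma_p^2);
+        # for the first timestep the prior is taken to be a standard normal.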
+ kl_div = (
+ prior_log_sigma - posterior_log_sigma - 0.5
+ + (posterior_var + (posterior_mu[:, 1:] - prior_mu[:, 1:]) ** 2) / (2 * prior_var)
+ )
+ first_kl = - posterior_log_sigma[:, :1] - 0.5 + (posterior_var[:, :1] + posterior_mu[:, :1] ** 2) / 2
+ kl_div = torch.cat([first_kl, kl_div], dim=1)
+
+ # Sum across channel dimension
+ # Average across batch dimension, keep time dimension for monitoring
+ kl_loss = torch.mean(torch.sum(kl_div, dim=-1))
+ return kl_loss
+
+
+class KLLoss(nn.Module):
+ def __init__(self, alpha):
+ super().__init__()
+ self.alpha = alpha
+ self.loss = ProbabilisticLoss(remove_first_timestamp=True)
+
+ def forward(self, prior, posterior):
+ prior_mu, prior_sigma = prior['mu'], prior['sigma']
+ posterior_mu, posterior_sigma = posterior['mu'], posterior['sigma']
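+        # KL balancing: train the prior towards the (detached) posterior and the posterior
+        # towards the (detached) prior, mixing both directions with weight alpha.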
+ prior_loss = self.loss(prior_mu, prior_sigma, posterior_mu.detach(), posterior_sigma.detach())
+ posterior_loss = self.loss(prior_mu.detach(), prior_sigma.detach(), posterior_mu, posterior_sigma)
+
+ return self.alpha * prior_loss + (1 - self.alpha) * posterior_loss
+
+
+class VoxelLoss(nn.Module):
+ """ 3D version of SegmentationLoss """
+
+ def __init__(self, use_top_k=False, top_k_ratio=1.0, use_weights=False, poly_one=False, poly_one_coefficient=0.0):
+ super().__init__()
+ self.use_top_k = use_top_k
+ self.top_k_ratio = top_k_ratio
+ self.use_weights = use_weights
+ self.poly_one = poly_one
+ self.poly_one_coefficient = poly_one_coefficient
+
+ if self.use_weights:
+ self.weights = VOXEL_SEG_WEIGHTS
+
+ def forward(self, prediction, target):
+ b, s, c, x, y, z = prediction.shape
+
+ prediction = prediction.view(b * s, c, x, y, z)
+ target = target.view(b * s, x, y, z)
+
+ if self.use_weights:
+ weights = torch.tensor(self.weights, dtype=prediction.dtype, device=prediction.device)
+ else:
+ weights = None
+ loss = F.cross_entropy(
+ prediction,
+ target,
+ reduction='none',
+ weight=weights,
+ )
+
+ if self.poly_one:
+ prob = torch.exp(-loss)
+ loss_poly_one = self.poly_one_coefficient * (1 - prob)
+ loss = loss + loss_poly_one
+
+ loss = loss.view(b, s, -1)
+ if self.use_top_k:
+ # Penalises the top-k hardest pixels
+ k = int(self.top_k_ratio * loss.shape[2])
+ loss = loss.topk(k, dim=-1)[0]
+
+ return torch.mean(loss)
+
+
+# Scene-Class Affinity Loss proposed in MonoScene
+# https://github.com/astra-vision/MonoScene/blob/master/monoscene/loss/ssc_loss.py
+class SemScalLoss(nn.Module):
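+    """Semantic scene-class affinity loss: for each class, push precision, recall and
+    specificity of the soft predictions towards 1 via binary cross-entropy."""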
+ def __init__(self, ignore_index=255):
+ super().__init__()
+ self.ignore_index = ignore_index
+
+ def forward(self, prediction, target):
+ b, s, c, x, y, z = prediction.shape
+
+ prediction = prediction.view(b * s, c, x, y, z)
+ target = target.view(b * s, x, y, z)
+
+ # Get softmax probabilities
+ prediction = F.softmax(prediction, dim=1)
+ loss = 0
+ count = 0
+ mask = target != self.ignore_index
+ n_classes = prediction.shape[1]
+ for i in range(0, n_classes):
+
+ # Get probability of class i
+ p = prediction[:, i, :, :, :]
+
+ # Remove unknown voxels
+ target_ori = target
+ p = p[mask]
+ target_mask = target[mask]
+
+ completion_target = torch.ones_like(target_mask).float()
+ completion_target[target_mask != i] = 0
+ completion_target_ori = torch.ones_like(target_ori).float()
+ completion_target_ori[target_ori != i] = 0
+ if torch.sum(completion_target) > 0:
+ count += 1.0
+ nominator = torch.sum(p * completion_target)
+ loss_class = 0
+ with autocast(enabled=False):
+ if torch.sum(p) > 0:
+ precision = nominator / (torch.sum(p))
+ if 0 <= precision <= 1:
+ loss_precision = F.binary_cross_entropy(
+ precision, torch.ones_like(precision)
+ )
+ loss_class += loss_precision
+ if torch.sum(completion_target) > 0:
+ recall = nominator / (torch.sum(completion_target))
+ if 0 <= recall <= 1:
+ loss_recall = F.binary_cross_entropy(
+ recall, torch.ones_like(recall)
+ )
+ loss_class += loss_recall
+ if torch.sum(1 - completion_target) > 0:
+ specificity = torch.sum((1 - p) * (1 - completion_target)) / (
+ torch.sum(1 - completion_target)
+ )
+ if 0 <= specificity <= 1:
+ loss_specificity = F.binary_cross_entropy(
+ specificity, torch.ones_like(specificity)
+ )
+ loss_class += loss_specificity
+ loss += loss_class
+ return loss / count
+
+
+class GeoScalLoss(nn.Module):
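+    """Geometric scene-class affinity loss: same idea as SemScalLoss, but applied only to the
+    binary empty/occupied split of the voxel grid."""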
+ def __init__(self, ignore_index=255):
+ super().__init__()
+ self.ignore_index = ignore_index
+
+ def forward(self, prediction, target):
+ b, s, c, x, y, z = prediction.shape
+
+ prediction = prediction.view(b * s, c, x, y, z)
+ target = target.view(b * s, x, y, z)
+
+ # Get softmax probabilities
+ prediction = F.softmax(prediction, dim=1)
+
+ # Compute empty and nonempty probabilities
+ empty_probs = prediction[:, 0, :, :, :]
+ nonempty_probs = 1 - empty_probs
+
+ # Remove unknown voxels
+ mask = target != self.ignore_index
+ nonempty_target = target != 0
+ nonempty_target = nonempty_target[mask].float()
+ nonempty_probs = nonempty_probs[mask]
+ empty_probs = empty_probs[mask]
+
+ intersection = (nonempty_target * nonempty_probs).sum()
+ precision = intersection / nonempty_probs.sum()
+ recall = intersection / nonempty_target.sum()
+ spec = ((1 - nonempty_target) * (empty_probs)).sum() / (1 - nonempty_target).sum()
+ with autocast(enabled=False):
+ loss = F.binary_cross_entropy(precision, torch.ones_like(precision)) + \
+ F.binary_cross_entropy(recall, torch.ones_like(recall)) + \
+ F.binary_cross_entropy(spec, torch.ones_like(spec))
+ return loss
+
+
+# Structural Similarity Index Measure (SSIM),
+# modified from https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py
+class SSIMLoss(nn.Module):
+ def __init__(self, channel=1, window_size=11, sigma=1.5, L=1, non_negative=False):
+ super().__init__()
+ self.window_size = window_size
+ # self.size_average = size_average
+ self.channel = channel
+ self.sigma = sigma
+ self.C1 = (0.01 * L) ** 2
+ self.C2 = (0.03 * L) ** 2
+ self.window = self.create_window()
+ self.non_negative = non_negative
+
+ def gaussian(self, window_size, sigma):
+ x = torch.arange(window_size)
+ gauss = torch.exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2))
+ return gauss / gauss.sum()
+
+ # gaussian kernel
+ def create_window(self):
+ _1D_window = self.gaussian(self.window_size, self.sigma).unsqueeze(1)
+ _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
+ window = _2D_window.expand(self.channel, 1, self.window_size, self.window_size).contiguous()
+ return window
+
+ def _ssim(self, prediction, target):
+ window = torch.as_tensor(self.window, dtype=prediction.dtype, device=prediction.device)
+
+ padd = 0
+ # padd = self.window_size // 2
+ mu1 = F.conv2d(target, window, padding=padd, groups=self.channel)
+ mu2 = F.conv2d(prediction, window, padding=padd, groups=self.channel)
+
+ mu1_sq = mu1.pow(2)
+ mu2_sq = mu2.pow(2)
+ mu1_mu2 = mu1 * mu2
+
+ sigma1_sq = F.conv2d(target * target, window, padding=padd, groups=self.channel) - mu1_sq
+ sigma2_sq = F.conv2d(prediction * prediction, window, padding=padd, groups=self.channel) - mu2_sq
+ sigma12 = F.conv2d(target * prediction, window, padding=padd, groups=self.channel) - mu1_mu2
+
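+        # SSIM(x, y) = (2*mu_x*mu_y + C1)(2*sigma_xy + C2) / ((mu_x^2 + mu_y^2 + C1)(sigma_x^2 + sigma_y^2 + C2))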
+ ssim_map = ((2 * mu1_mu2 + self.C1) * (2 * sigma12 + self.C2)) / \
+ ((mu1_sq + mu2_sq + self.C1) * (sigma1_sq + sigma2_sq + self.C2))
+
+ ssim_batch = ssim_map.mean([1, 2, 3])
+ if self.non_negative:
+ ssim_batch = F.relu(ssim_batch)
+
+ return ssim_batch
+
+ def forward(self, prediction, target):
+ b, s, c, h, w = prediction.shape
+
+ prediction = prediction.view(b * s, c, h, w)
+ target = target.view(b * s, c, h, w)
+
+ loss = self._ssim(prediction, target)
+ return loss.mean()
+
+
+# Chamfer Distance
+class CDLoss(nn.Module):
+ def __init__(self, reducer=torch.mean):
+ super().__init__()
+ self.reducer = reducer
+
+ def forward(self, prediction, target):
+ b, s, n, d = prediction.shape
+
+ prediction = prediction.view(b * s, n, d)
+ target = target.view(b * s, n, d)
+ # dist = self.batch_pairwise_dist(prediction, target)
+ # point-to-point distance
+ dist = torch.cdist(prediction.float(), target.float(), 2) # b*s, n, n
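+        # Symmetric Chamfer distance: nearest-neighbour distance from each predicted point to
+        # the target cloud and vice versa, both directions reduced (mean by default) and summed.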
+ dl, dr = dist.min(1)[0], dist.min(2)[0]
+ loss = self.reducer(dl, dim=1) + self.reducer(dr, dim=1)
+ return loss.mean()
+
+ @staticmethod
+ def batch_pairwise_dist(x: torch.Tensor, y: torch.Tensor):
+ x_norm = torch.sum(x ** 2, dim=2, keepdim=True)
+ y_norm = torch.sum(y ** 2, dim=2, keepdim=True)
+ xy = torch.bmm(x, y.transpose(1, 2))
+ dist = x_norm - 2 * xy + y_norm.transpose(1, 2)
+ return dist
diff --git a/muvo/metrics.py b/muvo/metrics.py
new file mode 100644
index 0000000..23c36f1
--- /dev/null
+++ b/muvo/metrics.py
@@ -0,0 +1,317 @@
+"""
+code is taken from https://github.com/astra-vision/MonoScene/blob/master/monoscene/loss/sscMetrics.py
+
+Part of the code is taken from https://github.com/waterljwant/SSC/blob/master/sscMetrics.py
+"""
+# import numpy as np
+import torch
+from chamferdist import ChamferDistance
+# from sklearn.metrics import accuracy_score, precision_recall_fscore_support
+
+from muvo.losses import SSIMLoss, CDLoss
+
+
+# SSCMetrics code is modified from https://github.com/astra-vision/MonoScene/blob/master/monoscene/loss/sscMetrics.py
+def get_iou(iou_sum, cnt_class):
+ _C = iou_sum.shape[0] # 12
+ iou = torch.zeros(_C, dtype=torch.float32) # iou for each class
+ for idx in range(_C):
+ iou[idx] = iou_sum[idx] / cnt_class[idx] if cnt_class[idx] else 0
+
+ mean_iou = torch.sum(iou[1:]) / torch.count_nonzero(cnt_class[1:])
+ return iou, mean_iou
+
+
+def get_accuracy(predict, target, weight=None): # 0.05s
+ _bs = predict.shape[0] # batch size
+ _C = predict.shape[1] # _C = 12
+    target = target.to(torch.int32)
+ target = target.reshape(_bs, -1) # (_bs, 60*36*60) 129600
+ predict = predict.reshape(_bs, _C, -1) # (_bs, _C, 60*36*60)
+ predict = torch.argmax(
+ predict, dim=1
+ ) # one-hot: _bs x _C x 60*36*60 --> label: _bs x 60*36*60.
+
+ correct = predict == target # (_bs, 129600)
+ if weight: # 0.04s, add class weights
+ weight_k = torch.ones(target.shape)
+ for i in range(_bs):
+ for n in range(target.shape[1]):
+ idx = 0 if target[i, n] == 255 else target[i, n]
+ weight_k[i, n] = weight[idx]
+ correct = correct * weight_k
+    acc = correct.sum() / correct.numel()
+ return acc
+
+
+class SSCMetrics:
+ def __init__(self, n_classes):
+ self.n_classes = n_classes
+ self.reset()
+
+ def hist_info(self, n_cl, pred, gt):
+ assert pred.shape == gt.shape
+ k = (gt >= 0) & (gt < n_cl) # exclude 255
+ labeled = torch.sum(k)
+ correct = torch.sum((pred[k] == gt[k]))
+
+ return (
+ torch.bincount(
+                n_cl * gt[k].long() + pred[k].long(), minlength=n_cl ** 2
+ ).reshape(n_cl, n_cl),
+ correct,
+ labeled,
+ )
+
+ @staticmethod
+ def compute_score(hist, correct, labeled):
+ iu = torch.diag(hist) / (hist.sum(1) + hist.sum(0) - torch.diag(hist))
+ mean_IU = torch.nanmean(iu)
+ mean_IU_no_back = torch.nanmean(iu[1:])
+ freq = hist.sum(1) / hist.sum()
+ freq_IU = (iu[freq > 0] * freq[freq > 0]).sum()
+ mean_pixel_acc = correct / labeled if labeled != 0 else 0
+
+ return iu, mean_IU, mean_IU_no_back, mean_pixel_acc
+
+ def add_batch(self, y_pred, y_true, nonempty=None, nonsurface=None):
+ self.count += 1
+ mask = y_true != 255
+ if nonempty is not None:
+ mask = mask & nonempty
+ if nonsurface is not None:
+ mask = mask & nonsurface
+ tp, fp, fn = self.get_score_completion(y_pred, y_true, mask)
+
+ self.completion_tp += tp
+ self.completion_fp += fp
+ self.completion_fn += fn
+
+ mask = y_true != 255
+ if nonempty is not None:
+ mask = mask & nonempty
+ tp_sum, fp_sum, fn_sum = self.get_score_semantic_and_completion(
+ y_pred, y_true, mask
+ )
+ self.tps += tp_sum
+ self.fps += fp_sum
+ self.fns += fn_sum
+
+ self.compute()
+
+ def compute(self):
+ if self.completion_tp != 0:
+ self.precision = self.completion_tp / (self.completion_tp + self.completion_fp)
+ self.recall = self.completion_tp / (self.completion_tp + self.completion_fn)
+ self.iou = self.completion_tp / (
+ self.completion_tp + self.completion_fp + self.completion_fn
+ )
+ else:
+ self.precision, self.recall, self.iou = 0, 0, 0
+
+ self.iou_ssc = self.tps / (self.tps + self.fps + self.fns + 1e-5)
+
+ def get_stats(self):
+ return {
+ "precision": self.precision,
+ "recall": self.recall,
+ "iou": self.iou,
+ "iou_ssc": self.iou_ssc,
+ "iou_ssc_mean": torch.mean(self.iou_ssc[1:]),
+ }
+
+ def reset(self):
+
+ self.completion_tp = 0
+ self.completion_fp = 0
+ self.completion_fn = 0
+ self.tps = torch.zeros(self.n_classes)
+ self.fps = torch.zeros(self.n_classes)
+ self.fns = torch.zeros(self.n_classes)
+
+ self.hist_ssc = torch.zeros((self.n_classes, self.n_classes))
+ self.labeled_ssc = 0
+ self.correct_ssc = 0
+
+ self.precision = 0
+ self.recall = 0
+ self.iou = 0
+ self.count = 1e-8
+ self.iou_ssc = torch.zeros(self.n_classes, dtype=torch.float32)
+ self.cnt_class = torch.zeros(self.n_classes, dtype=torch.float32)
+
+ def get_score_completion(self, predict, target, nonempty=None):
+ predict = predict.clone().detach()
+ target = target.clone().detach()
+
+ """for scene completion, treat the task as two-classes problem, just empty or occupancy"""
+ _bs = predict.shape[0] # batch size
+ # ---- ignore
+ predict[target == 255] = 0
+ target[target == 255] = 0
+ # ---- flatten
+ target = target.reshape(_bs, -1) # (_bs, 129600)
+ predict = predict.reshape(_bs, -1) # (_bs, _C, 129600), 60*36*60=129600
+ # ---- treat all non-empty object class as one category, set them to label 1
+ b_pred = predict.new_zeros(predict.shape)
+ b_true = target.new_zeros(target.shape)
+ b_pred[predict > 0] = 1
+ b_true[target > 0] = 1
+ p, r, iou = 0.0, 0.0, 0.0
+ tp_sum, fp_sum, fn_sum = 0, 0, 0
+ for idx in range(_bs):
+ y_true = b_true[idx, :] # GT
+ y_pred = b_pred[idx, :]
+ if nonempty is not None:
+ nonempty_idx = nonempty[idx, :].reshape(-1)
+ y_true = y_true[nonempty_idx == 1]
+ y_pred = y_pred[nonempty_idx == 1]
+
+ tp = torch.stack(torch.where(torch.logical_and(y_true == 1, y_pred == 1))).numel()
+ fp = torch.stack(torch.where(torch.logical_and(y_true != 1, y_pred == 1))).numel()
+ fn = torch.stack(torch.where(torch.logical_and(y_true == 1, y_pred != 1))).numel()
+ tp_sum += tp
+ fp_sum += fp
+ fn_sum += fn
+ return tp_sum, fp_sum, fn_sum
+
+ def get_score_semantic_and_completion(self, predict, target, nonempty=None):
+ target = target.clone().detach()
+ predict = predict.clone().detach()
+ _bs = predict.shape[0] # batch size
+ _C = self.n_classes # _C = 12
+ # ---- ignore
+ predict[target == 255] = 0
+ target[target == 255] = 0
+ # ---- flatten
+ target = target.reshape(_bs, -1) # (_bs, 129600)
+ predict = predict.reshape(_bs, -1) # (_bs, 129600), 60*36*60=129600
+
+ cnt_class = torch.zeros(_C, dtype=torch.int32) # count for each class
+ iou_sum = torch.zeros(_C, dtype=torch.float32) # sum of iou for each class
+ tp_sum = torch.zeros(_C, dtype=torch.int32) # tp
+ fp_sum = torch.zeros(_C, dtype=torch.int32) # fp
+ fn_sum = torch.zeros(_C, dtype=torch.int32) # fn
+
+ for idx in range(_bs):
+ y_true = target[idx, :] # GT
+ y_pred = predict[idx, :]
+ if nonempty is not None:
+ nonempty_idx = nonempty[idx, :].reshape(-1)
+ y_pred = y_pred[
+ torch.where(torch.logical_and(nonempty_idx == 1, y_true != 255))
+ ]
+ y_true = y_true[
+ torch.where(torch.logical_and(nonempty_idx == 1, y_true != 255))
+ ]
+ for j in range(_C): # for each class
+ tp = torch.stack(torch.where(torch.logical_and(y_true == j, y_pred == j))).numel()
+ fp = torch.stack(torch.where(torch.logical_and(y_true != j, y_pred == j))).numel()
+ fn = torch.stack(torch.where(torch.logical_and(y_true == j, y_pred != j))).numel()
+
+ tp_sum[j] += tp
+ fp_sum[j] += fp
+ fn_sum[j] += fn
+
+ return tp_sum, fp_sum, fn_sum
+
+
+class SSIMMetric:
+ def __init__(self, channel=3, window_size=11, sigma=1.5, L=1, non_negative=False):
+ self.ssim = SSIMLoss(channel=channel, window_size=window_size, sigma=sigma, L=L, non_negative=non_negative)
+ self.reset()
+
+ def add_batch(self, prediction, target):
+ self.count += 1
+ self.ssim_score += self.ssim(prediction, target)
+ self.ssim_avg = self.ssim_score / self.count
+
+ def get_stat(self):
+ return self.ssim_avg
+
+ def reset(self):
+ self.ssim_score = 0
+ self.count = 1e-8
+ self.ssim_avg = 0
+
+
+class CDMetric:
+ def __init__(self, reducer=torch.mean):
+ self.reducer = reducer
+ self.reset()
+
+ def add_batch(self, prediction, target):
+ self.count += 1
+ # dist = CDLoss.batch_pairwise_dist(prediction.float(), target.float()).cpu().numpy()
+ dist = torch.cdist(prediction.float(), target.float(), 2)
+ dl, dr = dist.min(1)[0], dist.min(2)[0]
+ cost = (self.reducer(dl, dim=1) + self.reducer(dr, dim=1)) / 2
+ self.total_cost += cost.mean()
+ self.avg_cost = self.total_cost / self.count
+
+ def get_stat(self):
+ return self.avg_cost
+
+ def reset(self):
+ self.total_cost = 0
+ self.count = 1e-8
+ self.avg_cost = 0
+
+
+class CDMetric0:
+ def __init__(self):
+ self.chamferDist = ChamferDistance()
+ self.reset()
+
+ def add_batch(self, prediction, target, valid_pred, valid_target):
+ self.count += 1
+ b = prediction.shape[0]
+ cdist = 0
+ for i in range(b):
+ # cdist += 0.5 * self.chamferDist(prediction[i][valid_pred[i]][None].float(),
+ # target[i][valid_target[i]][None].float(),
+ # bidirectional=True).detach().cpu().item()
+ pred_pcd = prediction[i][valid_pred[i]]
+ target_pcd = target[i][valid_target[i]]
+ cd_forward = self.chamferDist(pred_pcd[None].float(),
+ target_pcd[None].float(),
+ point_reduction='mean').detach().cpu().item()
+ cd_backward = self.chamferDist(target_pcd[None].float(),
+ pred_pcd[None].float(),
+ point_reduction='mean').detach().cpu().item()
+ cdist += 0.5 * (cd_forward + cd_backward)
+ self.total_cost += cdist / b
+ self.avg_cost = self.total_cost / self.count
+
+ def get_stat(self):
+ return self.avg_cost
+
+ def reset(self):
+ self.total_cost = 0
+ self.count = 1e-8
+ self.avg_cost = 0
+
+
+class PSNRMetric:
+ def __init__(self, max_pixel_val=1.0):
+ self.max_pixel_value = max_pixel_val
+ self.reset()
+
+ def add_batch(self, prediction, target):
+ self.count += 1
+ self.total_psnr += self.psnr(prediction, target).mean()
+ self.avg_psnr = self.total_psnr / self.count
+
+ def psnr(self, prediction, target):
+ # b, s, c, h, w
+ mse = torch.mean((prediction - target) ** 2, dim=(2, 3, 4))
+ psnr = 20 * torch.log10(self.max_pixel_value / torch.sqrt(mse))
+ return psnr
+
+ def get_stat(self):
+ return self.avg_psnr
+
+ def reset(self):
+ self.total_psnr = 0
+ self.count = 1e-8
+ self.avg_psnr = 0
diff --git a/muvo/models/common.py b/muvo/models/common.py
new file mode 100644
index 0000000..0c3c7d7
--- /dev/null
+++ b/muvo/models/common.py
@@ -0,0 +1,787 @@
+from typing import List
+
+import timm
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch_scatter import scatter_mean, scatter_max
+import math
+
+
+class RouteEncode(nn.Module):
+ def __init__(self, out_channels, backbone='resnet18'):
+ super().__init__()
+ self.backbone = timm.create_model(backbone, pretrained=True, features_only=True, out_indices=[4])
+ self.out_channels = out_channels
+ feature_info = self.backbone.feature_info.get_dicts(keys=['num_chs', 'reduction'])
+ self.fc = nn.Linear(feature_info[-1]['num_chs'], out_channels)
+
+ def forward(self, route):
+ x = self.backbone(route)[0]
+ x = F.adaptive_avg_pool2d(x, (1, 1)).flatten(1)
+ return self.fc(x)
+
+
+class GRUCellLayerNorm(nn.Module):
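+    """GRU cell with layer normalisation on the update, reset and proposal gates; the reset
+    gate receives a positive bias so it starts mostly open."""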
+ def __init__(self, input_size, hidden_size, reset_bias=1.0):
+ super().__init__()
+ self.reset_bias = reset_bias
+
+ self.update_layer = nn.Linear(input_size + hidden_size, hidden_size, bias=False)
+ self.update_norm = nn.LayerNorm(hidden_size)
+
+ self.reset_layer = nn.Linear(input_size + hidden_size, hidden_size, bias=False)
+ self.reset_norm = nn.LayerNorm(hidden_size)
+
+ self.proposal_layer = nn.Linear(input_size + hidden_size, hidden_size, bias=False)
+ self.proposal_norm = nn.LayerNorm(hidden_size)
+
+ def forward(self, inputs, state):
+ update = self.update_layer(torch.cat([inputs, state], -1))
+ update = torch.sigmoid(self.update_norm(update))
+
+ reset = self.reset_layer(torch.cat([inputs, state], -1))
+ reset = torch.sigmoid(self.reset_norm(reset) + self.reset_bias)
+
+ h_n = self.proposal_layer(torch.cat([inputs, reset * state], -1))
+ h_n = torch.tanh(self.proposal_norm(h_n))
+ output = (1 - update) * h_n + update * state
+ return output
+
+
+class Policy(nn.Module):
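+    """Small MLP policy head producing two actions in [-1, 1] (presumably steering and
+    combined throttle/brake, matching the dataset's action encoding)."""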
+ def __init__(self, in_channels):
+ super().__init__()
+ self.fc = nn.Sequential(
+ nn.Linear(in_channels, in_channels),
+ nn.ReLU(True),
+ nn.Linear(in_channels, in_channels),
+ nn.ReLU(True),
+ nn.Linear(in_channels, in_channels // 2),
+ nn.ReLU(True),
+ nn.Linear(in_channels // 2, 2),
+ nn.Tanh(),
+ )
+
+ def forward(self, x):
+ return self.fc(x)
+
+
+class Decoder(nn.Module):
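+    """FPN-style decoder: project the coarsest backbone feature map, then repeatedly upsample
+    and add projected skip connections from the finer stages."""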
+ def __init__(self, feature_info, out_channels):
+ super().__init__()
+ n_upsample_skip_convs = len(feature_info) - 1
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(feature_info[-1]['num_chs'], out_channels, 3, 1, 1, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(True),
+ )
+
+ self.upsample_skip_convs = nn.ModuleList(
+ nn.Sequential(
+ nn.Conv2d(feature_info[-i]['num_chs'], out_channels, 3, 1, 1, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(True),
+ )
+ for i in range(2, n_upsample_skip_convs + 2)
+ )
+
+ self.out_channels = out_channels
+
+ def forward(self, xs: List[Tensor]) -> Tensor:
+ x = self.conv1(xs[-1])
+
+ for i, conv in enumerate(self.upsample_skip_convs):
+ size = xs[-(i + 2)].shape[-2:]
+ x = conv(xs[-(i + 2)]) + F.interpolate(x, size=size, mode='bilinear', align_corners=False)
+
+ return x
+
+
+class DecoderDS(nn.Module):
+ def __init__(self, feature_info, out_channels):
+ super().__init__()
+ n_downsample_skip_convs = len(feature_info) - 1
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(feature_info[0]['num_chs'], out_channels, 3, 1, 1, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(True),
+ )
+
+ self.downsample_skip_convs = nn.ModuleList(
+ nn.Sequential(
+ nn.Conv2d(feature_info[i]['num_chs'], out_channels, 3, 1, 1, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(True),
+ )
+ for i in range(1, n_downsample_skip_convs + 1)
+ )
+
+ self.out_channels = out_channels
+
+ def forward(self, xs: List[Tensor]) -> Tensor:
+ x = self.conv1(xs[0])
+
+ for i, conv in enumerate(self.downsample_skip_convs):
+ stride = xs[i].shape[-1] // xs[i + 1].shape[-1]
+ x = conv(xs[i + 1]) + F.max_pool2d(x, stride) # avg_pool?
+
+ return x
+
+
+class DownSampleConv(nn.Module):
+ def __init__(self, in_channels, out_channels, latent_n_channels, down_sample_scale=None):
+ super().__init__()
+ self.down_sample_scale = down_sample_scale
+ self.conv1 = ConvInstanceNorm(in_channels, out_channels, latent_n_channels)
+ self.conv2 = ConvInstanceNorm(out_channels, out_channels, latent_n_channels)
+
+ def forward(self, x, w):
+ if self.down_sample_scale:
+ x = F.avg_pool2d(x, self.down_sample_scale)
+ x = self.conv1(x, w)
+ return self.conv2(x, w)
+
+
+class DecoderBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, latent_n_channels, upsample=False):
+ super().__init__()
+ self.upsample = upsample
+ self.conv1 = ConvInstanceNorm(in_channels, out_channels, latent_n_channels)
+ self.conv2 = ConvInstanceNorm(out_channels, out_channels, latent_n_channels)
+
+ def forward(self, x, w):
+ if self.upsample:
+ x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=False)
+ x = self.conv1(x, w)
+ return self.conv2(x, w)
+
+
+class DecoderBlock3d(nn.Module):
+ def __init__(self, in_channels, out_channels, latent_n_channels, upsample=False):
+ super().__init__()
+ self.upsample = upsample
+ self.conv1 = ConvInstanceNorm3d(in_channels, out_channels, latent_n_channels)
+ self.conv2 = ConvInstanceNorm3d(out_channels, out_channels, latent_n_channels)
+
+ def forward(self, x, w):
+ if self.upsample:
+ x = F.interpolate(x, scale_factor=2.0, mode='trilinear', align_corners=False)
+ x = self.conv1(x, w)
+ return self.conv2(x, w)
+
+
+class ConvInstanceNorm(nn.Module):
+ def __init__(self, in_channels, out_channels, latent_n_channels):
+ super().__init__()
+ self.conv_act = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, 3, 1, 1),
+ nn.LeakyReLU(0.2, inplace=True),
+ )
+
+ self.adaptive_norm = AdaptiveInstanceNorm(latent_n_channels, out_channels)
+
+ def forward(self, x, w):
+ x = self.conv_act(x)
+ return self.adaptive_norm(x, w)
+
+
+class ConvInstanceNorm3d(nn.Module):
+ def __init__(self, in_channels, out_channels, latent_n_channels):
+ super().__init__()
+ self.conv_act = nn.Sequential(
+ nn.Conv3d(in_channels, out_channels, 3, 1, 1),
+ nn.LeakyReLU(0.2, inplace=True),
+ )
+
+ self.adaptive_norm = AdaptiveInstanceNorm3d(latent_n_channels, out_channels)
+
+ def forward(self, x, w):
+ x = self.conv_act(x)
+ return self.adaptive_norm(x, w)
+
+
+class AdaptiveInstanceNorm(nn.Module):
+ def __init__(self, latent_n_channels, out_channels, epsilon=1e-8):
+ super().__init__()
+ self.out_channels = out_channels
+ self.epsilon = epsilon
+
+ self.latent_affine = nn.Linear(latent_n_channels, 2 * out_channels)
+
+ def forward(self, x, style):
+ # Instance norm
+ mean = x.mean(dim=(-1, -2), keepdim=True)
+ x = x - mean
+ std = torch.sqrt(torch.mean(x ** 2, dim=(-1, -2), keepdim=True) + self.epsilon)
+ x = x / std
+
+ # Normalising with the style vector
+ style = self.latent_affine(style).unsqueeze(-1).unsqueeze(-1)
+ scale, bias = torch.split(style, split_size_or_sections=self.out_channels, dim=1)
+ out = scale * x + bias
+ return out
+
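+# Usage sketch for AdaptiveInstanceNorm (illustrative shapes; the 3D variant below
+# works analogously on (B, C, D, H, W) tensors):
+#   adain = AdaptiveInstanceNorm(latent_n_channels=512, out_channels=64)
+#   x = torch.randn(2, 64, 24, 24)   # feature map
+#   w = torch.randn(2, 512)          # latent / style vector
+#   y = adain(x, w)                  # (2, 64, 24, 24): instance-normalised, then
+#                                    # scaled and shifted by an affine projection of w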
+
+class AdaptiveInstanceNorm3d(nn.Module):
+ def __init__(self, latent_n_channels, out_channels, epsilon=1e-8):
+ super().__init__()
+ self.out_channels = out_channels
+ self.epsilon = epsilon
+
+ self.latent_affine = nn.Linear(latent_n_channels, 2 * out_channels)
+
+ def forward(self, x, style):
+ # Instance norm
+ mean = x.mean(dim=(-1, -2, -3), keepdim=True)
+ x = x - mean
+ std = torch.sqrt(torch.mean(x ** 2, dim=(-1, -2, -3), keepdim=True) + self.epsilon)
+ x = x / std
+
+ # Normalising with the style vector
+ style = self.latent_affine(style).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ scale, bias = torch.split(style, split_size_or_sections=self.out_channels, dim=1)
+ out = scale * x + bias
+ return out
+
+
+class SegmentationHead(nn.Module):
+ def __init__(self, in_channels, n_classes, downsample_factor):
+ super().__init__()
+ self.downsample_factor = downsample_factor
+
+ self.segmentation_head = nn.Sequential(
+ nn.Conv2d(in_channels, n_classes, kernel_size=1, padding=0),
+ )
+ self.instance_offset_head = nn.Sequential(
+ nn.Conv2d(in_channels, 2, kernel_size=1, padding=0),
+ )
+ self.instance_center_head = nn.Sequential(
+ nn.Conv2d(in_channels, 1, kernel_size=1, padding=0),
+ nn.Sigmoid(),
+ )
+
+ def forward(self, x):
+ output = {
+ f'bev_segmentation_{self.downsample_factor}': self.segmentation_head(x),
+ f'bev_instance_offset_{self.downsample_factor}': self.instance_offset_head(x),
+ f'bev_instance_center_{self.downsample_factor}': self.instance_center_head(x),
+ }
+ return output
+
+
+class RGBHead(nn.Module):
+ def __init__(self, in_channels, n_classes, downsample_factor):
+ super().__init__()
+ self.downsample_factor = downsample_factor
+
+ self.rgb_head = nn.Sequential(
+ nn.Conv2d(in_channels, n_classes, kernel_size=1, padding=0),
+ )
+
+ def forward(self, x):
+ output = {
+ f'rgb_{self.downsample_factor}': self.rgb_head(x),
+ }
+ return output
+
+
+class LidarReHead(nn.Module):
+ def __init__(self, in_channels, n_classes, downsample_factor):
+ super().__init__()
+ self.downsample_factor = downsample_factor
+
+ self.lidar_re_head = nn.Sequential(
+ nn.Conv2d(in_channels, n_classes, kernel_size=1, padding=0),
+ )
+
+ def forward(self, x):
+ output = {
+ f'lidar_reconstruction_{self.downsample_factor}': self.lidar_re_head(x),
+ }
+ return output
+
+
+class LidarSegHead(nn.Module):
+ def __init__(self, in_channels, n_classes, downsample_factor):
+ super().__init__()
+ self.downsample_factor = downsample_factor
+
+ self.seg_head = nn.Sequential(
+ nn.Conv2d(in_channels, n_classes, kernel_size=1, padding=0),
+ )
+
+ def forward(self, x):
+ output = {
+ f'lidar_segmentation_{self.downsample_factor}': self.seg_head(x),
+ }
+ return output
+
+
+class SemHead(nn.Module):
+ def __init__(self, in_channels, n_classes, downsample_factor):
+ super().__init__()
+ self.downsample_factor = downsample_factor
+
+ self.sem_head = nn.Sequential(
+ nn.Conv2d(in_channels, n_classes, kernel_size=1, padding=0),
+ )
+
+ def forward(self, x):
+ output = {
+ f'semantic_image_{self.downsample_factor}': self.sem_head(x),
+ }
+ return output
+
+
+class DepthHead(nn.Module):
+ def __init__(self, in_channels, n_classes, downsample_factor):
+ super().__init__()
+ self.downsample_factor = downsample_factor
+
+ self.depth_head = nn.Sequential(
+ nn.Conv2d(in_channels, n_classes, kernel_size=1, padding=0),
+ )
+
+ def forward(self, x):
+ output = {
+ f'depth_{self.downsample_factor}': self.depth_head(x),
+ }
+ return output
+
+
+class VoxelSemHead(nn.Module):
+ def __init__(self, in_channels, n_classes, downsample_factor):
+ super().__init__()
+ self.downsample_factor = downsample_factor
+
+ self.segmentation_head = nn.Sequential(
+ nn.Conv3d(in_channels, n_classes, kernel_size=1, padding=0),
+ )
+
+ def forward(self, x):
+ output = {
+ f'voxel_{self.downsample_factor}': self.segmentation_head(x),
+ }
+ return output
+
+
+class BevDecoder(nn.Module):
+ def __init__(self, latent_n_channels, semantic_n_channels, constant_size=(3, 3), head='bev'):
+ super().__init__()
+ n_channels = 512
+
+ self.constant_tensor = nn.Parameter(torch.randn((n_channels, *constant_size), dtype=torch.float32))
+
+ # Input 512 x 3 x 3
+ self.first_norm = AdaptiveInstanceNorm(latent_n_channels, out_channels=n_channels)
+ self.first_conv = ConvInstanceNorm(n_channels, n_channels, latent_n_channels)
+ # 512 x 3 x 3
+
+ self.middle_conv = nn.ModuleList(
+ [DecoderBlock(n_channels, n_channels, latent_n_channels, upsample=True) for _ in range(3)]
+ )
+
+ head_modules = {'rgb': RGBHead,
+ 'bev': SegmentationHead,
+ 'depth': DepthHead,
+ 'sem_image': SemHead,
+ 'lidar_re': LidarReHead,
+ 'lidar_seg': LidarSegHead}
+ head_module = head_modules[head] if head in head_modules else RGBHead
+ # 512 x 24 x 24
+ self.conv1 = DecoderBlock(n_channels, 256, latent_n_channels, upsample=True)
+ self.head_4 = head_module(256, semantic_n_channels, downsample_factor=4)
+ # 256 x 48 x 48
+
+ self.conv2 = DecoderBlock(256, 128, latent_n_channels, upsample=True)
+ self.head_2 = head_module(128, semantic_n_channels, downsample_factor=2)
+ # 128 x 96 x 96
+
+ self.conv3 = DecoderBlock(128, 64, latent_n_channels, upsample=True)
+ self.head_1 = head_module(64, semantic_n_channels, downsample_factor=1)
+ # 64 x 192 x 192
+
+ def forward(self, w: Tensor) -> Tensor:
+ b = w.shape[0]
+ x = self.constant_tensor.unsqueeze(0).repeat([b, 1, 1, 1])
+
+ x = self.first_norm(x, w)
+ x = self.first_conv(x, w)
+
+ for module in self.middle_conv:
+ x = module(x, w)
+
+ x = self.conv1(x, w)
+ output_4 = self.head_4(x)
+ x = self.conv2(x, w)
+ output_2 = self.head_2(x)
+ x = self.conv3(x, w)
+ output_1 = self.head_1(x)
+
+ output = {**output_4, **output_2, **output_1}
+ return output
+
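+# Decoding sketch for BevDecoder (sizes follow the inline comments above): a latent w
+# of shape (B, latent_n_channels) modulates a learned 512 x 3 x 3 constant, which the
+# three middle blocks upsample to 24 x 24; conv1/conv2/conv3 then produce head
+# outputs at downsample factors 4, 2 and 1 (e.g. keys 'bev_segmentation_4',
+# 'bev_segmentation_2', 'bev_segmentation_1' when head='bev').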
+
+class VoxelDecoderScale(nn.Module):
+ def __init__(self, input_channels, n_classes, kernel_size=1, feature_channels=512):
+ super().__init__()
+
+        # weights for the xy, xz, yz view features.
+ self.weight_xy_decoder = nn.Conv2d(input_channels, 1, kernel_size, 1)
+ self.weight_xz_decoder = nn.Conv2d(input_channels, 1, kernel_size, 1)
+ self.weight_yz_decoder = nn.Conv2d(input_channels, 1, kernel_size, 1)
+
+ # self.classifier = nn.Sequential(
+ # nn.Linear(feature_channels, feature_channels),
+ # nn.Softplus(),
+ # nn.Linear(feature_channels, n_classes)
+ # )
+ self.classifier = nn.Sequential(
+ nn.Conv3d(feature_channels, feature_channels, kernel_size=3, stride=1, padding=1),
+ nn.Softplus(),
+ nn.Conv3d(feature_channels, n_classes, kernel_size=1, stride=1, padding=0)
+ )
+
+ def attention_fusion(self, t1, w1, t2, w2):
+ norm_weight = torch.softmax(torch.cat([w1, w2], dim=1), dim=1)
+ feat = t1 * norm_weight[:, 0:1] + t2 * norm_weight[:, 1:2]
+ return feat
+
+ def expand_to_XYZ(self, xy_feat, xz_feat, yz_feat):
+ B, C, X, Y, Z = *xy_feat.size(), xz_feat.size(3)
+ xy_feat = xy_feat.view(B, C, X, Y, 1)
+ xz_feat = xz_feat.view(B, C, X, 1, Z)
+ yz_feat = yz_feat.view(B, C, 1, Y, Z)
+ return torch.broadcast_tensors(xy_feat, xz_feat, yz_feat)
+
+ def forward(self, x):
+ feature_xy, feature_xz, feature_yz = x
+
+ weights_xy = self.weight_xy_decoder(feature_xy)
+ weights_xz = self.weight_xz_decoder(feature_xz)
+ weights_yz = self.weight_yz_decoder(feature_yz)
+
+ feature_xy, feature_xz, feature_yz = self.expand_to_XYZ(feature_xy, feature_xz, feature_yz)
+ weights_xy, weights_xz, weights_yz = self.expand_to_XYZ(weights_xy, weights_xz, weights_yz)
+
+ # fuse xy, xz, yz features in xyz.
+ features_xyz = self.attention_fusion(feature_xy, weights_xy, feature_xz, weights_xz) + \
+ self.attention_fusion(feature_xy, weights_xy, feature_yz, weights_yz)
+
+ # B, C, X, Y, Z = features_xyz.size()
+ # logits = self.classifier(features_xyz.view(B, C, -1).transpose(1, 2))
+ # logits = logits.permute(0, 2, 1).reshape(B, -1, X, Y, Z)
+ logits = self.classifier(features_xyz)
+
+ return logits
+
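+# Tri-plane fusion sketch (illustrative shapes): given plane features xy: (B, C, X, Y),
+# xz: (B, C, X, Z) and yz: (B, C, Y, Z), expand_to_XYZ broadcasts them to a common
+# (B, C, X, Y, Z) volume; attention_fusion combines them with per-voxel softmax
+# weights before the 3D-conv classifier produces the per-class logits.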
+
+class VoxelDecoder0(nn.Module):
+ def __init__(self, input_channels, n_classes, kernel_size=1, feature_channels=512):
+ super().__init__()
+
+ self.decoder_1 = VoxelDecoderScale(input_channels, n_classes, kernel_size, feature_channels)
+ self.decoder_2 = VoxelDecoderScale(input_channels, n_classes, kernel_size, feature_channels)
+ self.decoder_4 = VoxelDecoderScale(input_channels, n_classes, kernel_size, feature_channels)
+
+ def forward(self, xy, xz, yz):
+ output_1 = self.decoder_1((xy['rgb_1'], xz['rgb_1'], yz['rgb_1']))
+ output_2 = self.decoder_2((xy['rgb_2'], xz['rgb_2'], yz['rgb_2']))
+ output_4 = self.decoder_4((xy['rgb_4'], xz['rgb_4'], yz['rgb_4']))
+ return {'voxel_1': output_1,
+ 'voxel_2': output_2,
+ 'voxel_4': output_4}
+
+
+class VoxelDecoder1(nn.Module):
+ def __init__(self, latent_n_channels, semantic_n_channels, feature_channels=512, constant_size=(3, 3, 1)):
+ super().__init__()
+ n_channels = feature_channels
+
+ self.constant_tensor = nn.Parameter(torch.randn((2 * n_channels, *constant_size), dtype=torch.float32))
+
+ # Input 512 x 3 x 3 x 1
+ self.first_norm = AdaptiveInstanceNorm3d(latent_n_channels, out_channels=2 * n_channels)
+ self.first_conv = ConvInstanceNorm3d(2 * n_channels, n_channels, latent_n_channels)
+ # 512 x 3 x 3 x 1
+
+ self.middle_conv = nn.ModuleList(
+ [DecoderBlock3d(n_channels, n_channels, latent_n_channels, upsample=True) for _ in range(3)]
+ )
+
+ head_module = VoxelSemHead
+ # 512 x 24 x 24 x 8
+ self.conv1 = DecoderBlock3d(n_channels, n_channels // 2, latent_n_channels, upsample=True)
+ self.head_4 = head_module(n_channels // 2, semantic_n_channels, downsample_factor=4)
+ # 256 x 48 x 48 x 16
+
+ self.conv2 = DecoderBlock3d(n_channels // 2, n_channels // 4, latent_n_channels, upsample=True)
+ self.head_2 = head_module(n_channels // 4, semantic_n_channels, downsample_factor=2)
+ # 128 x 96 x 96 x 32
+
+ self.conv3 = DecoderBlock3d(n_channels // 4, n_channels // 8, latent_n_channels, upsample=True)
+ self.head_1 = head_module(n_channels // 8, semantic_n_channels, downsample_factor=1)
+ # 64 x 192 x 192 x 64
+
+ def forward(self, w: Tensor) -> Tensor:
+ b = w.shape[0]
+ x = self.constant_tensor.unsqueeze(0).repeat([b, 1, 1, 1, 1])
+
+ x = self.first_norm(x, w)
+ x = self.first_conv(x, w)
+
+ for module in self.middle_conv:
+ x = module(x, w)
+
+ x = self.conv1(x, w)
+ output_4 = self.head_4(x)
+ x = self.conv2(x, w)
+ output_2 = self.head_2(x)
+ x = self.conv3(x, w)
+ output_1 = self.head_1(x)
+
+ output = {**output_4, **output_2, **output_1}
+ return output
+
+
+class ConvDecoder(nn.Module):
+ def __init__(self, latent_n_channels, out_channels, constant_size=(5, 13), mlp_layers=0, layer_norm=True,
+ activation=nn.ELU, head='rgb'):
+ super().__init__()
+ n_channels = 512
+ if mlp_layers == 0:
+ layers = [
+ nn.Linear(latent_n_channels, n_channels), # no activation here in dreamer v2
+ ]
+ else:
+ hidden_dim = n_channels
+ norm = nn.LayerNorm if layer_norm else nn.Identity
+ layers = [
+ nn.Linear(latent_n_channels, hidden_dim),
+ norm(hidden_dim, eps=1e-3),
+ activation(),
+ ]
+ for _ in range(mlp_layers - 1):
+ layers += [
+ nn.Linear(hidden_dim, hidden_dim),
+ norm(hidden_dim, eps=1e-3),
+ activation()
+ ]
+ self.linear = nn.Sequential(*layers, nn.Unflatten(-1, (n_channels, 1, 1))) # N x n_channels
+
+ self.pre_transpose_conv = nn.Sequential(
+ # *layers,
+ # nn.Unflatten(-1, (n_channels, 1, 5)),
+ # nn.ConvTranspose2d(n_channels, n_channels, kernel_size=5, stride=2), # 5 x 13
+ nn.ConvTranspose2d(n_channels, n_channels, kernel_size=constant_size), # 5 x 13
+ activation(),
+ nn.ConvTranspose2d(n_channels, n_channels, kernel_size=5, stride=2, padding=2, output_padding=1), # 10 x 26
+ activation(),
+ nn.ConvTranspose2d(n_channels, n_channels, kernel_size=5, stride=2, padding=2, output_padding=1), # 20 x 52
+ activation(),
+ nn.ConvTranspose2d(n_channels, n_channels, kernel_size=6, stride=2, padding=2), # 40 x 104
+ activation(),
+ )
+
+ head_modules = {'rgb': RGBHead,
+ 'bev': SegmentationHead,
+ 'depth': DepthHead,
+ 'sem_image': SemHead,
+ 'lidar_re': LidarReHead,
+ 'lidar_seg': LidarSegHead}
+ head_module = head_modules[head] if head in head_modules else RGBHead
+
+ self.trans_conv1 = nn.Sequential(
+ nn.ConvTranspose2d(n_channels, 256, kernel_size=6, stride=2, padding=2),
+ activation(),
+ )
+ self.head_4 = head_module(in_channels=256, n_classes=out_channels, downsample_factor=4)
+ # 256 x 80 x 208
+
+ self.trans_conv2 = nn.Sequential(
+ nn.ConvTranspose2d(256, 128, kernel_size=6, stride=2, padding=2),
+ activation(),
+ )
+ self.head_2 = head_module(in_channels=128, n_classes=out_channels, downsample_factor=2)
+ # 128 x 160 x 416
+
+ self.trans_conv3 = nn.Sequential(
+ nn.ConvTranspose2d(128, 64, kernel_size=6, stride=2, padding=2),
+ activation()
+ )
+ self.head_1 = head_module(in_channels=64, n_classes=out_channels, downsample_factor=1)
+ # 64 x 320 x 832
+
+ def forward(self, x):
+ x = self.linear(x) # N x n_channels x 1 x 1
+
+ # x = x.repeat(1, 1, 1, 5)
+ # N x n_channels x 1 x 5
+ x = self.pre_transpose_conv(x)
+
+ x = self.trans_conv1(x)
+ output_4 = self.head_4(x)
+ x = self.trans_conv2(x)
+ output_2 = self.head_2(x)
+ x = self.trans_conv3(x)
+ output_1 = self.head_1(x)
+
+ output = {**output_4, **output_2, **output_1}
+ return output
+
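+# Resolution sketch for ConvDecoder with the default constant_size=(5, 13), following
+# the inline comments above: latent (B, latent_n_channels) -> 512 x 1 x 1 ->
+# 512 x 5 x 13 -> 10 x 26 -> 20 x 52 -> 40 x 104, with heads emitted at
+# 80 x 208 (factor 4), 160 x 416 (factor 2) and 320 x 832 (factor 1).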
+
+# https://github.com/opendilab/InterFuser/blob/e0682c350892a243cf40bf448622743f4b26d0f3/interfuser/timm/models/interfuser.py#L66
+class PositionEmbeddingSine(nn.Module):
+ """
+ This is a more standard version of the position embedding, very similar to the one
+ used by the Attention is all you need paper, generalized to work on images.
+ """
+
+ def __init__(
+ self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
+ ):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * math.pi
+ self.scale = scale
+
+ def forward(self, tensor):
+ x = tensor
+ _, _, h, w = x.shape
+ not_mask = torch.ones((1, h, w), device=x.device)
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
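+# PositionEmbeddingSine returns a (1, 2 * num_pos_feats, H, W) tensor that is added to
+# a feature map with a matching channel count before it is flattened into transformer
+# tokens; the first half of the channels encodes the y coordinate, the second half
+# the x coordinate.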
+
+# https://github.com/dotchen/LAV/blob/main/lav/models/point_pillar.py#L38
+class DynamicPointNet(nn.Module):
+ def __init__(self, num_input=9, num_features=[32, 32]):
+ super().__init__()
+
+ L = []
+ for num_feature in num_features:
+ L += [
+ nn.Linear(num_input, num_feature),
+ nn.BatchNorm1d(num_feature),
+ nn.ReLU(inplace=True),
+ ]
+
+ num_input = num_feature
+
+ self.net = nn.Sequential(*L)
+
+ def forward(self, points, inverse_indices):
+ """
+ TODO: multiple layers
+ """
+ feat = self.net(points)
+ feat_max = scatter_max(feat, inverse_indices, dim=0)[0]
+ # feat_max = scatter_max(points, inverse_indices, dim=0)[0]
+ return feat_max
+
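+# DynamicPointNet applies a shared point-wise MLP and then max-pools per pillar via
+# scatter_max (assumed to come from the torch_scatter package); inverse_indices maps
+# every point to the id of the pillar it falls into, so the result is one feature
+# vector per occupied pillar.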
+
+class PointPillarNet(nn.Module):
+ def __init__(self, num_input=9, num_features=[32, 32],
+ min_x=-10, max_x=70,
+ min_y=-40, max_y=40,
+ pixels_per_meter=4):
+ super().__init__()
+ self.point_net = DynamicPointNet(num_input, num_features)
+
+ self.nx = (max_x - min_x) * pixels_per_meter
+ self.ny = (max_y - min_y) * pixels_per_meter
+ self.min_x = min_x
+ self.min_y = min_y
+ self.max_x = max_x
+ self.max_y = max_y
+ self.pixels_per_meter = pixels_per_meter
+
+ def decorate(self, points, unique_coords, inverse_indices):
+ dtype = points.dtype
+ x_centers = unique_coords[inverse_indices][:, 2:3].to(dtype) / self.pixels_per_meter + self.min_x
+ y_centers = unique_coords[inverse_indices][:, 1:2].to(dtype) / self.pixels_per_meter + self.min_y
+
+ xyz = points[:, :3]
+
+ points_cluster = xyz - scatter_mean(xyz, inverse_indices, dim=0)[inverse_indices]
+
+ points_xp = xyz[:, :1] - x_centers
+ points_yp = xyz[:, 1:2] - y_centers
+
+ features = torch.cat([points, points_cluster, points_xp, points_yp], dim=-1)
+ return features
+
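+    # decorate() appends, to each raw lidar point, its offset from the pillar's mean
+    # point (3 values) and from the pillar's x/y cell centre (2 values), so a point
+    # with k raw features becomes k + 5 decorated features (e.g. k = 3 for plain
+    # x, y, z gives num_input = 8; k = 4 for x, y, z, intensity gives the default 9).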
+ def grid_locations(self, points):
+ keep = (points[:, 0] >= self.min_x) & (points[:, 0] < self.max_x) & \
+ (points[:, 1] >= self.min_y) & (points[:, 1] < self.max_y)
+ points = points[keep, :]
+
+ coords = (points[:, [0, 1]] - torch.tensor([self.min_x, self.min_y],
+ device=points.device)) * self.pixels_per_meter
+ coords = coords.long()
+
+ return points, coords
+
+ def pillar_generation(self, points, coords):
+ unique_coords, inverse_indices = coords.unique(return_inverse=True, dim=0)
+ decorated_points = self.decorate(points, unique_coords, inverse_indices)
+
+ return decorated_points, unique_coords, inverse_indices
+
+ def scatter_points(self, features, coords, batch_size):
+ canvas = torch.zeros(batch_size, features.shape[1], self.ny, self.nx, dtype=features.dtype,
+ device=features.device)
+ canvas[coords[:, 0], :, torch.clamp(
+ self.ny - 1 - coords[:, 1], 0, self.ny - 1), torch.clamp(coords[:, 2], 0, self.nx - 1)] = features
+ return canvas
+
+ def forward(self, lidar_list, num_points):
+ batch_size = len(lidar_list)
+ with torch.no_grad():
+ coords = []
+ filtered_points = []
+ for batch_id, points in enumerate(lidar_list):
+ points = points[:num_points[batch_id]]
+ points, grid_yx = self.grid_locations(points)
+
+ # batch indices
+ grid_byx = torch.nn.functional.pad(grid_yx,
+ (1, 0), mode='constant', value=batch_id)
+
+ coords.append(grid_byx)
+ filtered_points.append(points)
+
+ # batch_size, grid_y, grid_x
+ coords = torch.cat(coords, dim=0)
+ filtered_points = torch.cat(filtered_points, dim=0)
+
+ decorated_points, unique_coords, inverse_indices = self.pillar_generation(filtered_points, coords)
+
+ features = self.point_net(decorated_points, inverse_indices)
+
+ return self.scatter_points(features, unique_coords, batch_size)
\ No newline at end of file
diff --git a/muvo/models/frustum_pooling.py b/muvo/models/frustum_pooling.py
new file mode 100644
index 0000000..e148095
--- /dev/null
+++ b/muvo/models/frustum_pooling.py
@@ -0,0 +1,217 @@
+""" Adapted from https://github.com/nv-tlabs/lift-splat-shoot/blob/master/src/tools.py"""
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from muvo.utils.geometry_utils import bev_params_to_intrinsics, intrinsics_inverse
+
+
+def gen_dx_bx(size, scale, offsetx):
+ xbound = [-size[0] * scale / 2 - offsetx * scale, size[0] * scale / 2 - offsetx * scale, scale]
+ ybound = [-size[1] * scale / 2, size[1] * scale / 2, scale]
+ zbound = [-10.0, 10.0, 20.0]
+
+ dx = torch.Tensor([row[2] for row in [xbound, ybound, zbound]])
+ bx = torch.Tensor([row[0] + row[2] / 2.0 for row in [xbound, ybound, zbound]])
+ # nx = torch.LongTensor([(row[1] - row[0]) / row[2] for row in [xbound, ybound, zbound]])
+ nx = torch.LongTensor([np.round((row[1] - row[0]) / row[2]) for row in [xbound, ybound, zbound]])
+
+ return dx, bx, nx
+
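+# Worked example for gen_dx_bx (illustrative numbers, not a config from this repo):
+# size=(48, 48), scale=0.5, offsetx=0 gives xbound = ybound = [-12, 12, 0.5], so
+# dx = (0.5, 0.5, 20.0) (cell size), bx = (-11.75, -11.75, 0.0) (centre of the first
+# cell) and nx = (48, 48, 1) (number of cells per axis).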
+
+def cumsum_trick(x, geom_feats, ranks):
+ x = x.cumsum(0)
+ kept = torch.ones(x.shape[0], device=x.device, dtype=torch.bool)
+ kept[:-1] = (ranks[1:] != ranks[:-1])
+
+ x, geom_feats = x[kept], geom_feats[kept]
+ x = torch.cat((x[:1], x[1:] - x[:-1]))
+
+ return x, geom_feats
+
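+# The cumulative-sum trick: with the points sorted by voxel rank, taking the running
+# sum, keeping only the last entry of each equal-rank run and differencing consecutive
+# kept entries yields the sum of features per voxel without an explicit scatter-add
+# (see the lift-splat-shoot reference above).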
+
+class QuickCumsum(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, geom_feats, ranks):
+ x = x.cumsum(0)
+ kept = torch.ones(x.shape[0], device=x.device, dtype=torch.bool)
+ kept[:-1] = (ranks[1:] != ranks[:-1])
+
+ x, geom_feats = x[kept], geom_feats[kept]
+ x = torch.cat((x[:1], x[1:] - x[:-1]))
+
+ # save kept for backward
+ ctx.save_for_backward(kept)
+
+ # no gradient for geom_feats
+ ctx.mark_non_differentiable(geom_feats)
+
+ return x, geom_feats
+
+ @staticmethod
+ def backward(ctx, gradx, gradgeom):
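+        # Gradient routing: 'kept' marks the last point of every voxel run, so
+        # cumsum(kept) - 1 at the kept positions maps each input point to the index of
+        # the pooled output it contributed to; every point then simply copies the
+        # gradient of its voxel.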
+ kept, = ctx.saved_tensors
+ back = torch.cumsum(kept, 0)
+ back[kept] -= 1
+
+ val = gradx[back]
+
+ return val, None, None
+
+
+def quick_cumsum(x, geom_feats, ranks):
+ return QuickCumsum.apply(x, geom_feats, ranks)
+
+
+class FrustumPooling(nn.Module):
+ def __init__(self, size, scale, offsetx, dbound, downsample, use_quickcumsum=True):
+ """ Pools camera frustums into Birds Eye View
+
+ Args:
+ size: (width, height) size of voxel grid
+ scale: size of pixel in m
+ offsetx: egocar offset (forwards) from center of bev in px
+ dbound: depth planes in camera frustum (min, max, step)
+ downsample: fraction of the size of the feature maps (stride of backbone)
+ """
+ super().__init__()
+
+ self.register_buffer('bev_intrinsics', torch.tensor(bev_params_to_intrinsics(size, scale, offsetx)))
+
+ dx, bx, nx = gen_dx_bx(size, scale, offsetx)
+ self.nx_constant = nx.numpy().tolist()
+ self.register_buffer('dx', dx, persistent=False) # bev_resolution
+ self.register_buffer('bx', bx, persistent=False) # bev_start_position
+ self.register_buffer('nx', nx, persistent=False) # bev_dimension
+ self.use_quickcumsum = use_quickcumsum
+
+ self.dbound = dbound
+ ds = torch.arange(self.dbound[0], self.dbound[1], self.dbound[2], dtype=torch.float32)
+ self.D = len(ds)
+ self.register_buffer('ds', ds, persistent=False)
+
+ self.downsample = downsample
+ self.register_buffer('frustum', torch.zeros(0,), persistent=False)
+
+ def initialize_frustum(self, image):
+ if self.frustum.shape[0] == 0:
+ device = image.device
+ # make grid in image plane
+ fH, fW = image.shape[-3:-1]
+ ogfH, ogfW = fH * self.downsample, fW * self.downsample
+ ds = self.ds.view(-1, 1, 1).expand(-1, fH, fW)
+ xs = torch.linspace(0, ogfW - 1, fW, dtype=torch.float, device=device).view(1, 1, fW).expand(self.D, fH, fW)
+ ys = torch.linspace(0, ogfH - 1, fH, dtype=torch.float, device=device).view(1, fH, 1).expand(self.D, fH, fW)
+
+ # D x H x W x 3
+ # with the 3D coordinates being (x, y, z)
+ self.frustum = torch.stack((xs, ys, ds), -1)
+
+ def get_geometry(self, rots, trans, intrins): # , post_rots=None, post_trans=None):
+ """Determine the (x,y,z) locations (in the ego frame)
+ of the points in the point cloud.
+ Returns B x N x D x H/downsample x W/downsample x 3
+ """
+ B, N = trans.shape[:2]
+
+ points = self.frustum.unsqueeze(0).unsqueeze(0).unsqueeze(-1)
+
+ # cam_to_ego
+ points = torch.cat((points[:, :, :, :, :, :2] * points[:, :, :, :, :, 2:3],
+ points[:, :, :, :, :, 2:3]
+ ), 5)
+ # combine = rots.matmul(torch.inverse(intrins))
+ combine = rots.matmul(intrinsics_inverse(intrins))
+ points = combine.view(B, N, 1, 1, 1, 3, 3).matmul(points).squeeze(-1)
+ points += trans.view(B, N, 1, 1, 1, 3)
+
+ return points
+
+ def voxel_pooling(self, geom_feats, x, mask):
+ B, N, D, H, W, C = x.shape
+ Nprime = B * N * D * H * W
+
+ # flatten
+ x = x.reshape(Nprime, C)
+
+ # The coordinates are defined as (forward, left, up)
+ geom_feats = geom_feats.view(Nprime, 3)
+
+ # transform world points to bev coords
+ geom_feats[:, 0] = geom_feats[:, 0] * self.bev_intrinsics[0, 0] + self.bev_intrinsics[0, 2]
+ geom_feats[:, 1] = geom_feats[:, 1] * self.bev_intrinsics[1, 1] + self.bev_intrinsics[1, 2]
+ # TODO: seems like things < -10m also get projected.
+ geom_feats[:, 2] = (geom_feats[:, 2] - self.bx[2] + self.dx[2] / 2.) / self.dx[2]
+ geom_feats = geom_feats.long()
+
+ batch_ix = torch.cat([torch.full(size=(Nprime // B, 1), fill_value=ix,
+ device=x.device, dtype=torch.long) for ix in range(B)])
+ geom_feats = torch.cat((geom_feats, batch_ix), 1)
+
+ # sparse lifting for speed
+ if len(mask) > 0:
+ mask = mask.view(Nprime)
+ x = x[mask]
+ geom_feats = geom_feats[mask]
+
+ # filter out points that are outside box
+ kept = (geom_feats[:, 0] >= 0) & (geom_feats[:, 0] < self.nx[0]) \
+ & (geom_feats[:, 1] >= 0) & (geom_feats[:, 1] < self.nx[1]) \
+ & (geom_feats[:, 2] >= 0) & (geom_feats[:, 2] < self.nx[2])
+ x = x[kept]
+ geom_feats = geom_feats[kept]
+
+ # get tensors from the same voxel next to each other
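+        # (ranks encodes (x, y, z, batch index) into a single scalar id, so sorting by
+        # it places all features that fall into the same voxel of the same sample
+        # next to each other)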
+ ranks = geom_feats[:, 0] * (self.nx[1] * self.nx[2] * B) \
+ + geom_feats[:, 1] * (self.nx[2] * B) \
+ + geom_feats[:, 2] * B \
+ + geom_feats[:, 3]
+ sorts = ranks.argsort()
+ x, geom_feats, ranks = x[sorts], geom_feats[sorts], ranks[sorts]
+
+ # cumsum trick
+ if self.use_quickcumsum and self.training:
+ x, geom_feats = quick_cumsum(x, geom_feats, ranks)
+ else:
+ x, geom_feats = cumsum_trick(x, geom_feats, ranks)
+
+ # griddify (B x C x up x left x forward)
+ final = torch.zeros((B, C, self.nx_constant[2], self.nx_constant[1], self.nx_constant[0]), dtype=x.dtype,
+ device=x.device)
+ final[geom_feats[:, 3], :, geom_feats[:, 2], geom_feats[:, 1], geom_feats[:, 0]] = x
+
+ # collapse "up" dimension
+ final = torch.cat(final.unbind(dim=2), 1)
+
+ return final
+
+ def forward(self, x, intrinsics, pose, mask=torch.zeros(0)): # , post_rots=None, post_trans=None):
+ """
+ Args:
+ x: (B x N x D x H x W x C) frustum feature maps
+ intrinsics: (B x N x 3 x 3) camera intrinsics (of input image prior to downsampling by backbone)
+ pose: (B x N x 4 x 4) camera pose matrix
+ """
+
+        # the intrinsics matrix is defined as
+        #   [[f', 0,  m_x],
+        #    [0,  f', m_y],
+        #    [0,  0,  1 ]]
+        # with f' = k * f in pixel units, where k is the conversion factor in pixel/m
+        # and f the focal length in m; (m_x, m_y) is the principal point in pixels.
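+        # Example (illustrative numbers): a 1920 px wide camera with a 90 degree
+        # horizontal FOV has f' = 1920 / (2 * tan(45 degrees)) = 960 px, with the
+        # principal point roughly at (m_x, m_y) = (960, image_height / 2).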
+
+ self.initialize_frustum(x)
+ rots = pose[..., :3, :3]
+ trans = pose[..., :3, 3:]
+ geom = self.get_geometry(rots, trans, intrinsics) # , post_rots, post_trans)
+ x = self.voxel_pooling(geom, x, mask).type_as(x) # TODO: do we want to do more of frustum pooling in FP16?
+ return x
+
+ def get_depth_map(self, depth):
+ """ Convert depth probibility distribution to depth """
+ ds = self.ds.view(1, -1, 1, 1)
+ depth = (ds * depth).sum(1, keepdim=True)
+ depth = nn.functional.interpolate(depth, scale_factor=float(self.downsample), mode='bilinear',
+ align_corners=False)
+ return depth
diff --git a/muvo/models/mile.py b/muvo/models/mile.py
new file mode 100644
index 0000000..5a02c0a
--- /dev/null
+++ b/muvo/models/mile.py
@@ -0,0 +1,1032 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import timm
+
+from constants import CARLA_FPS, DISPLAY_SEGMENTATION
+from muvo.utils.network_utils import pack_sequence_dim, unpack_sequence_dim, remove_past
+from muvo.models.common import BevDecoder, Decoder, RouteEncode, Policy, VoxelDecoder1, ConvDecoder, \
+ PositionEmbeddingSine, DecoderDS, PointPillarNet, DownSampleConv
+from muvo.models.frustum_pooling import FrustumPooling
+from muvo.layers.layers import BasicBlock
+from muvo.models.transition import RSSM
+
+
+class Mile(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+ self.receptive_field = cfg.RECEPTIVE_FIELD
+
+ embedding_n_channels = self.cfg.MODEL.EMBEDDING_DIM
+ # Image feature encoder
+ if self.cfg.MODEL.ENCODER.NAME == 'resnet18':
+ self.encoder = timm.create_model(
+ cfg.MODEL.ENCODER.NAME, pretrained=True, features_only=True, out_indices=[2, 3, 4],
+ )
+ feature_info = self.encoder.feature_info.get_dicts(keys=['num_chs', 'reduction'])
+
+ if self.cfg.MODEL.TRANSFORMER.ENABLED:
+            # whether to use the transformer variant with more tokens.
+ DecoderT = Decoder if self.cfg.MODEL.TRANSFORMER.LARGE else DecoderDS
+ self.feat_decoder = DecoderT(feature_info, self.cfg.MODEL.TRANSFORMER.CHANNELS)
+ if self.cfg.MODEL.TRANSFORMER.BEV:
+ self.feat_decoder = Decoder(feature_info, self.cfg.MODEL.TRANSFORMER.CHANNELS)
+ # Frustum pooling
+ bev_downsample = cfg.BEV.FEATURE_DOWNSAMPLE
+ self.frustum_pooling = FrustumPooling(
+ size=(cfg.BEV.SIZE[0] // bev_downsample, cfg.BEV.SIZE[1] // bev_downsample),
+ scale=cfg.BEV.RESOLUTION * bev_downsample,
+ offsetx=cfg.BEV.OFFSET_FORWARD / bev_downsample,
+ dbound=cfg.BEV.FRUSTUM_POOL.D_BOUND,
+ downsample=8,
+ )
+
+ # mono depth head
+ self.depth_decoder = Decoder(feature_info, self.cfg.MODEL.TRANSFORMER.CHANNELS)
+ self.depth = nn.Conv2d(self.depth_decoder.out_channels, self.frustum_pooling.D, kernel_size=1)
+            # for speed, only lift the top-k most likely bins of the depth distribution
+ self.sparse_depth = cfg.BEV.FRUSTUM_POOL.SPARSE
+ self.sparse_depth_count = cfg.BEV.FRUSTUM_POOL.SPARSE_COUNT
+ if not self.cfg.MODEL.TRANSFORMER.LARGE:
+ # Down-sampling
+ # self.bev_down_sample_4 = nn.MaxPool2d(4)
+ bev_out_channels = self.cfg.MODEL.TRANSFORMER.CHANNELS
+ self.bev_down_sample_4 = nn.Sequential(
+ nn.Conv2d(bev_out_channels, 512, kernel_size=5, stride=2, padding=2),
+ nn.ReLU(),
+ nn.Conv2d(512, bev_out_channels, kernel_size=5, stride=2, padding=2),
+ )
+
+ if self.cfg.MODEL.LIDAR.ENABLED:
+ if self.cfg.MODEL.LIDAR.POINT_PILLAR.ENABLED:
+ # Point-Pillar net
+ self.point_pillars = PointPillarNet(
+ num_input=8,
+ num_features=[32, 32],
+ min_x=-48,
+ max_x=48,
+ min_y=-48,
+ max_y=48,
+ pixels_per_meter=5)
+ # encoder for point-pillar features
+ self.point_pillar_encoder = timm.create_model(
+ cfg.MODEL.LIDAR.ENCODER, pretrained=True, features_only=True, out_indices=[2, 3, 4], in_chans=32
+ )
+ point_pillar_feature_info = \
+ self.point_pillar_encoder.feature_info.get_dicts(keys=['num_chs', 'reduction'])
+ self.point_pillar_decoder = DecoderT(point_pillar_feature_info, self.cfg.MODEL.TRANSFORMER.CHANNELS)
+ else:
+ # range-view pcd encoder
+ self.range_view_encoder = timm.create_model(
+ cfg.MODEL.LIDAR.ENCODER, pretrained=True, features_only=True, out_indices=[2, 3, 4], in_chans=4
+ )
+ range_view_feature_info = self.range_view_encoder.feature_info.get_dicts(keys=['num_chs', 'reduction'])
+ self.range_view_decoder = DecoderT(range_view_feature_info, self.cfg.MODEL.TRANSFORMER.CHANNELS)
+
+            # 2D sinusoidal positional embedding
+ self.position_encode = PositionEmbeddingSine(
+ num_pos_feats=self.cfg.MODEL.TRANSFORMER.CHANNELS // 2,
+ normalize=True)
+
+ # sensor type embedding
+ self.type_embedding = nn.Parameter(torch.zeros(1, 1, self.cfg.MODEL.TRANSFORMER.CHANNELS, 2))
+
+ # transformer encoder
+ self.encoder_layer = nn.TransformerEncoderLayer(
+ d_model=self.cfg.MODEL.TRANSFORMER.CHANNELS,
+ nhead=8,
+ dropout=0.1,
+ )
+ self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=6)
+
+ # compress sensor features to 1D
+ self.image_feature_conv = nn.Sequential(
+ BasicBlock(self.cfg.MODEL.TRANSFORMER.CHANNELS, embedding_n_channels, stride=2, downsample=True),
+ BasicBlock(embedding_n_channels, embedding_n_channels),
+ nn.AdaptiveAvgPool2d(output_size=(1, 1)),
+ nn.Flatten(start_dim=1),
+ )
+ self.lidar_feature_conv = nn.Sequential(
+ BasicBlock(self.cfg.MODEL.TRANSFORMER.CHANNELS, embedding_n_channels, stride=2, downsample=True),
+ BasicBlock(embedding_n_channels, embedding_n_channels),
+ nn.AdaptiveAvgPool2d(output_size=(1, 1)),
+ nn.Flatten(start_dim=1),
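+# Unlike the upsampling Decoder above, DecoderDS aggregates towards the coarsest
+# feature map: the running feature is max-pooled to match each deeper
+# (lower-resolution) skip connection before being added, so the output has the
+# spatial size of the last entry in feature_info.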
+ )
+ feature_n_channels = 2 * embedding_n_channels
+
+ # Route map
+ if self.cfg.MODEL.ROUTE.ENABLED:
+ self.backbone_route = RouteEncode(self.cfg.MODEL.ROUTE.CHANNELS, cfg.MODEL.ROUTE.BACKBONE)
+ feature_n_channels += self.cfg.MODEL.ROUTE.CHANNELS
+
+ # Measurements
+ if self.cfg.MODEL.MEASUREMENTS.ENABLED:
+ self.command_encoder = nn.Sequential(
+ nn.Embedding(6, self.cfg.MODEL.MEASUREMENTS.COMMAND_CHANNELS),
+ nn.Linear(self.cfg.MODEL.MEASUREMENTS.COMMAND_CHANNELS, self.cfg.MODEL.MEASUREMENTS.COMMAND_CHANNELS),
+ nn.ReLU(True),
+ nn.Linear(self.cfg.MODEL.MEASUREMENTS.COMMAND_CHANNELS, self.cfg.MODEL.MEASUREMENTS.COMMAND_CHANNELS),
+ nn.ReLU(True),
+ )
+
+ self.command_next_encoder = nn.Sequential(
+ nn.Embedding(6, self.cfg.MODEL.MEASUREMENTS.COMMAND_CHANNELS),
+ nn.Linear(self.cfg.MODEL.MEASUREMENTS.COMMAND_CHANNELS, self.cfg.MODEL.MEASUREMENTS.COMMAND_CHANNELS),
+ nn.ReLU(True),
+ nn.Linear(self.cfg.MODEL.MEASUREMENTS.COMMAND_CHANNELS, self.cfg.MODEL.MEASUREMENTS.COMMAND_CHANNELS),
+ nn.ReLU(True),
+ )
+
+ self.gps_encoder = nn.Sequential(
+ nn.Linear(2*2, self.cfg.MODEL.MEASUREMENTS.GPS_CHANNELS),
+ nn.ReLU(True),
+ nn.Linear(self.cfg.MODEL.MEASUREMENTS.GPS_CHANNELS, self.cfg.MODEL.MEASUREMENTS.GPS_CHANNELS),
+ nn.ReLU(True),
+ )
+ feature_n_channels += 2 * self.cfg.MODEL.MEASUREMENTS.COMMAND_CHANNELS
+ feature_n_channels += self.cfg.MODEL.MEASUREMENTS.GPS_CHANNELS
+
+ # Speed as input
+ self.speed_enc = nn.Sequential(
+ nn.Linear(1, cfg.MODEL.SPEED.CHANNELS),
+ nn.ReLU(True),
+ nn.Linear(cfg.MODEL.SPEED.CHANNELS, cfg.MODEL.SPEED.CHANNELS),
+ nn.ReLU(True),
+ )
+ feature_n_channels += cfg.MODEL.SPEED.CHANNELS
+ self.speed_normalisation = cfg.SPEED.NORMALISATION
+
+ # fuse all features together
+ self.features_combine = nn.Linear(feature_n_channels, embedding_n_channels)
+
+ else:
+ self.feat_decoder = Decoder(feature_info, self.cfg.MODEL.ENCODER.OUT_CHANNELS)
+ if not self.cfg.EVAL.NO_LIFTING:
+ # Frustum pooling
+ bev_downsample = cfg.BEV.FEATURE_DOWNSAMPLE
+ self.frustum_pooling = FrustumPooling(
+ size=(cfg.BEV.SIZE[0] // bev_downsample, cfg.BEV.SIZE[1] // bev_downsample),
+ scale=cfg.BEV.RESOLUTION * bev_downsample,
+ offsetx=cfg.BEV.OFFSET_FORWARD / bev_downsample,
+ dbound=cfg.BEV.FRUSTUM_POOL.D_BOUND,
+ downsample=8,
+ )
+
+ # mono depth head
+ self.depth_decoder = Decoder(feature_info, self.cfg.MODEL.ENCODER.OUT_CHANNELS)
+ self.depth = nn.Conv2d(self.depth_decoder.out_channels, self.frustum_pooling.D, kernel_size=1)
+                # for speed, only lift the top-k most likely bins of the depth distribution
+ self.sparse_depth = cfg.BEV.FRUSTUM_POOL.SPARSE
+ self.sparse_depth_count = cfg.BEV.FRUSTUM_POOL.SPARSE_COUNT
+
+ backbone_bev_in_channels = self.cfg.MODEL.ENCODER.OUT_CHANNELS
+
+
+ # Route map
+ if self.cfg.MODEL.ROUTE.ENABLED:
+ self.backbone_route = RouteEncode(cfg.MODEL.ROUTE.CHANNELS, cfg.MODEL.ROUTE.BACKBONE)
+ backbone_bev_in_channels += self.backbone_route.out_channels
+
+ # Measurements
+ if self.cfg.MODEL.MEASUREMENTS.ENABLED:
+ self.command_encoder = nn.Sequential(
+ nn.Embedding(6, self.cfg.MODEL.MEASUREMENTS.COMMAND_CHANNELS),
+ nn.Linear(self.cfg.MODEL.MEASUREMENTS.COMMAND_CHANNELS, self.cfg.MODEL.MEASUREMENTS.COMMAND_CHANNELS),
+ nn.ReLU(True),
+ nn.Linear(self.cfg.MODEL.MEASUREMENTS.COMMAND_CHANNELS, self.cfg.MODEL.MEASUREMENTS.COMMAND_CHANNELS),
+ nn.ReLU(True),
+ )
+
+ self.command_next_encoder = nn.Sequential(
+ nn.Embedding(6, self.cfg.MODEL.MEASUREMENTS.COMMAND_CHANNELS),
+ nn.Linear(self.cfg.MODEL.MEASUREMENTS.COMMAND_CHANNELS, self.cfg.MODEL.MEASUREMENTS.COMMAND_CHANNELS),
+ nn.ReLU(True),
+ nn.Linear(self.cfg.MODEL.MEASUREMENTS.COMMAND_CHANNELS, self.cfg.MODEL.MEASUREMENTS.COMMAND_CHANNELS),
+ nn.ReLU(True),
+ )
+
+ self.gps_encoder = nn.Sequential(
+ nn.Linear(2*2, self.cfg.MODEL.MEASUREMENTS.GPS_CHANNELS),
+ nn.ReLU(True),
+ nn.Linear(self.cfg.MODEL.MEASUREMENTS.GPS_CHANNELS, self.cfg.MODEL.MEASUREMENTS.GPS_CHANNELS),
+ nn.ReLU(True),
+ )
+
+ backbone_bev_in_channels += 2*self.cfg.MODEL.MEASUREMENTS.COMMAND_CHANNELS
+ backbone_bev_in_channels += self.cfg.MODEL.MEASUREMENTS.GPS_CHANNELS
+
+ # Speed as input
+ self.speed_enc = nn.Sequential(
+ nn.Linear(1, cfg.MODEL.SPEED.CHANNELS),
+ nn.ReLU(True),
+ nn.Linear(cfg.MODEL.SPEED.CHANNELS, cfg.MODEL.SPEED.CHANNELS),
+ nn.ReLU(True),
+ )
+ backbone_bev_in_channels += cfg.MODEL.SPEED.CHANNELS
+ self.speed_normalisation = cfg.SPEED.NORMALISATION
+
+ embedding_n_channels = self.cfg.MODEL.EMBEDDING_DIM
+
+ if self.cfg.MODEL.LIDAR.ENABLED:
+ # self.lidar_encoder_xy = timm.create_model(
+ # cfg.MODEL.LIDAR.ENCODER, pretrained=True, features_only=True, out_indices=[2, 3, 4], in_chans=4
+ # )
+ # lidar_feature_info_xy = self.lidar_encoder_xy.feature_info.get_dicts(keys=['num_chs', 'reduction'])
+ # self.lidar_decoder_xy = Decoder(lidar_feature_info_xy, self.cfg.MODEL.LIDAR.OUT_CHANNELS)
+ # backbone_bev_in_channels += self.cfg.MODEL.LIDAR.OUT_CHANNELS
+ if self.cfg.MODEL.LIDAR.POINT_PILLAR.ENABLED:
+ self.point_pillars = PointPillarNet(
+ num_input=8,
+ num_features=[32, 32],
+ min_x=-48,
+ max_x=48,
+ min_y=-48,
+ max_y=48,
+ pixels_per_meter=5)
+ self.point_pillar_encoder = timm.create_model(
+ cfg.MODEL.LIDAR.ENCODER, pretrained=True, features_only=True, out_indices=[2, 3, 4], in_chans=32
+ )
+ point_pillar_feature_info = \
+ self.point_pillar_encoder.feature_info.get_dicts(keys=['num_chs', 'reduction'])
+ self.point_pillar_decoder = Decoder(point_pillar_feature_info, self.cfg.MODEL.LIDAR.OUT_CHANNELS)
+ else:
+ self.range_view_encoder = timm.create_model(
+ cfg.MODEL.LIDAR.ENCODER, pretrained=True, features_only=True, out_indices=[2, 3, 4], in_chans=4
+ )
+ range_view_feature_info = self.range_view_encoder.feature_info.get_dicts(keys=['num_chs', 'reduction'])
+ self.range_view_decoder = Decoder(range_view_feature_info, self.cfg.MODEL.LIDAR.OUT_CHANNELS)
+ self.lidar_state_conv = nn.Sequential(
+ BasicBlock(self.cfg.MODEL.LIDAR.OUT_CHANNELS, embedding_n_channels, stride=2, downsample=True),
+ BasicBlock(embedding_n_channels, embedding_n_channels, stride=2, downsample=True),
+ nn.AdaptiveAvgPool2d(output_size=(1, 1)),
+ nn.Flatten(start_dim=1),
+ )
+
+ self.embedding_combine = nn.Linear(2 * embedding_n_channels, embedding_n_channels)
+
+ # Bev network
+ self.backbone_bev = timm.create_model(
+ cfg.MODEL.BEV.BACKBONE,
+ in_chans=backbone_bev_in_channels,
+ pretrained=True,
+ features_only=True,
+ out_indices=[3],
+ )
+ feature_info_bev = self.backbone_bev.feature_info.get_dicts(keys=['num_chs', 'reduction'])
+ self.final_state_conv = nn.Sequential(
+ BasicBlock(feature_info_bev[-1]['num_chs'], embedding_n_channels, stride=2, downsample=True),
+ BasicBlock(embedding_n_channels, embedding_n_channels),
+ nn.AdaptiveAvgPool2d(output_size=(1, 1)),
+ nn.Flatten(start_dim=1),
+ )
+
+ # Recurrent model
+ self.receptive_field = self.cfg.RECEPTIVE_FIELD
+ if self.cfg.MODEL.TRANSITION.ENABLED:
+ # Recurrent state sequence module
+ self.rssm = RSSM(
+ embedding_dim=embedding_n_channels,
+ action_dim=self.cfg.MODEL.ACTION_DIM,
+ hidden_state_dim=self.cfg.MODEL.TRANSITION.HIDDEN_STATE_DIM,
+ state_dim=self.cfg.MODEL.TRANSITION.STATE_DIM,
+ action_latent_dim=self.cfg.MODEL.TRANSITION.ACTION_LATENT_DIM,
+ receptive_field=self.receptive_field,
+ use_dropout=self.cfg.MODEL.TRANSITION.USE_DROPOUT,
+ dropout_probability=self.cfg.MODEL.TRANSITION.DROPOUT_PROBABILITY,
+ )
+
+ # Policy
+ if self.cfg.MODEL.TRANSITION.ENABLED:
+ state_dim = self.cfg.MODEL.TRANSITION.HIDDEN_STATE_DIM + self.cfg.MODEL.TRANSITION.STATE_DIM
+ else:
+ state_dim = embedding_n_channels
+ self.policy = Policy(in_channels=state_dim)
+
+ # Bird's-eye view semantic segmentation
+ if self.cfg.SEMANTIC_SEG.ENABLED:
+ self.bev_decoder = BevDecoder(
+ latent_n_channels=state_dim,
+ semantic_n_channels=self.cfg.SEMANTIC_SEG.N_CHANNELS,
+ head='bev',
+ )
+
+ # RGB reconstruction
+ if self.cfg.EVAL.RGB_SUPERVISION:
+ # self.rgb_decoder = BevDecoder(
+ # latent_n_channels=state_dim,
+ # semantic_n_channels=3,
+ # constant_size=(5, 13),
+ # head='rgb',
+ # )
+ self.rgb_decoder = ConvDecoder(
+ latent_n_channels=state_dim,
+ out_channels=3,
+ constant_size=(5, 13),
+ head='rgb'
+ )
+
+ # lidar reconstruction in range-view
+ if self.cfg.LIDAR_RE.ENABLED:
+ self.lidar_re = ConvDecoder(
+ latent_n_channels=state_dim,
+ out_channels=self.cfg.LIDAR_RE.N_CHANNELS,
+ constant_size=(1, 16),
+ head='lidar_re',
+ )
+
+ # lidar semantic segmentation
+ if self.cfg.LIDAR_SEG.ENABLED:
+ self.lidar_segmentation = ConvDecoder(
+ latent_n_channels=state_dim,
+ out_channels=self.cfg.LIDAR_SEG.N_CLASSES,
+ constant_size=(1, 16),
+ head='lidar_seg',
+ )
+
+ # camera semantic segmentation
+ if self.cfg.SEMANTIC_IMAGE.ENABLED:
+ self.sem_image_decoder = ConvDecoder(
+ latent_n_channels=state_dim,
+ out_channels=self.cfg.SEMANTIC_IMAGE.N_CLASSES,
+ constant_size=(5, 13),
+ head='sem_image',
+ )
+
+ # depth camera prediction
+ if self.cfg.DEPTH.ENABLED:
+ self.depth_image_decoder = ConvDecoder(
+ latent_n_channels=state_dim,
+ out_channels=1,
+ constant_size=(5, 13),
+ head='depth',
+ )
+
+ # Voxel reconstruction
+ if self.cfg.VOXEL_SEG.ENABLED:
+ # self.voxel_feature_xy_decoder = BevDecoder(
+ # latent_n_channels=state_dim,
+ # semantic_n_channels=self.cfg.VOXEL_SEG.DIMENSION,
+ # constant_size=(3, 3),
+ # is_segmentation=False,
+ # )
+ # self.voxel_feature_xz_decoder = BevDecoder(
+ # latent_n_channels=state_dim,
+ # semantic_n_channels=self.cfg.VOXEL_SEG.DIMENSION,
+ # constant_size=(3, 1),
+ # is_segmentation=False,
+ # )
+ # self.voxel_feature_yz_decoder = BevDecoder(
+ # latent_n_channels=state_dim,
+ # semantic_n_channels=self.cfg.VOXEL_SEG.DIMENSION,
+ # constant_size=(3, 1),
+ # is_segmentation=False,
+ # )
+ # self.voxel_decoder = VoxelDecoder0(
+ # input_channels=self.cfg.VOXEL_SEG.DIMENSION,
+ # n_classes=self.cfg.VOXEL_SEG.N_CLASSES,
+ # kernel_size=1,
+ # feature_channels=self.cfg.VOXEL_SEG.DIMENSION,
+ # )
+ self.voxel_decoder = VoxelDecoder1(
+ latent_n_channels=state_dim,
+ semantic_n_channels=self.cfg.VOXEL_SEG.N_CLASSES,
+ feature_channels=self.cfg.VOXEL_SEG.DIMENSION,
+ constant_size=(3, 3, 1),
+ )
+
+ # Used during deployment to save last state
+ self.last_h = None
+ self.last_sample = None
+ self.last_action = None
+ self.count = 0
+
+ def forward(self, batch, deployment=False):
+ """
+ Parameters
+ ----------
+ batch: dict of torch.Tensor
+ keys:
+ image: (b, s, 3, h, w)
+ route_map: (b, s, 3, h_r, w_r)
+ speed: (b, s, 1)
+ intrinsics: (b, s, 3, 3)
+ extrinsics: (b, s, 4, 4)
+ throttle_brake: (b, s, 1)
+ steering: (b, s, 1)
+ """
+ # Encode RGB images, route_map, speed using intrinsics and extrinsics
+ # to a 512 dimensional vector
+ embedding = self.encode(batch)
+ b, s = batch['image'].shape[:2]
+
+ output = dict()
+ if self.cfg.MODEL.TRANSITION.ENABLED:
+ # Recurrent state sequence module
+ if deployment:
+ action = batch['action']
+ else:
+ action = torch.cat([batch['throttle_brake'], batch['steering']], dim=-1)
+ state_dict = self.rssm(embedding, action, use_sample=not deployment, policy=self.policy)
+
+ if deployment:
+ state_dict = remove_past(state_dict, s)
+ s = 1
+
+ output = {**output, **state_dict}
+ state = torch.cat([state_dict['posterior']['hidden_state'], state_dict['posterior']['sample']], dim=-1)
+ else:
+ state = embedding
+ state_dict = {}
+
+ state = pack_sequence_dim(state)
+ output_policy = self.policy(state)
+ throttle_brake, steering = torch.split(output_policy, 1, dim=-1)
+ output['throttle_brake'] = unpack_sequence_dim(throttle_brake, b, s)
+ output['steering'] = unpack_sequence_dim(steering, b, s)
+
+ # reconstruction
+ if self.cfg.SEMANTIC_SEG.ENABLED:
+ if (not deployment) or (deployment and DISPLAY_SEGMENTATION):
+ bev_decoder_output = self.bev_decoder(state)
+ bev_decoder_output = unpack_sequence_dim(bev_decoder_output, b, s)
+ output = {**output, **bev_decoder_output}
+
+ if self.cfg.EVAL.RGB_SUPERVISION:
+ rgb_decoder_output = self.rgb_decoder(state)
+ rgb_decoder_output = unpack_sequence_dim(rgb_decoder_output, b, s)
+ output = {**output, **rgb_decoder_output}
+
+ if self.cfg.LIDAR_RE.ENABLED:
+ lidar_output = self.lidar_re(state)
+ lidar_output = unpack_sequence_dim(lidar_output, b, s)
+ output = {**output, **lidar_output}
+
+ if self.cfg.LIDAR_SEG.ENABLED:
+ lidar_seg_output = self.lidar_segmentation(state)
+ lidar_seg_output = unpack_sequence_dim(lidar_seg_output, b, s)
+ output = {**output, **lidar_seg_output}
+
+ if self.cfg.SEMANTIC_IMAGE.ENABLED:
+ sem_image_output = self.sem_image_decoder(state)
+ sem_image_output = unpack_sequence_dim(sem_image_output, b, s)
+ output = {**output, **sem_image_output}
+
+ if self.cfg.DEPTH.ENABLED:
+ depth_image_output = self.depth_image_decoder(state)
+ depth_image_output = unpack_sequence_dim(depth_image_output, b, s)
+ output = {**output, **depth_image_output}
+
+ if self.cfg.VOXEL_SEG.ENABLED:
+ # voxel_feature_xy = self.voxel_feature_xy_decoder(state)
+ # voxel_feature_xz = self.voxel_feature_xz_decoder(state)
+ # voxel_feature_yz = self.voxel_feature_yz_decoder(state)
+ # voxel_decoder_output = self.voxel_decoder(voxel_feature_xy, voxel_feature_xz, voxel_feature_yz)
+ voxel_decoder_output = self.voxel_decoder(state)
+ voxel_decoder_output = unpack_sequence_dim(voxel_decoder_output, b, s)
+ output = {**output, **voxel_decoder_output}
+
+ return output, state_dict
+
+ def encode(self, batch):
+ b, s = batch['image'].shape[:2]
+ image = pack_sequence_dim(batch['image'])
+ speed = pack_sequence_dim(batch['speed'])
+ intrinsics = pack_sequence_dim(batch['intrinsics'])
+ extrinsics = pack_sequence_dim(batch['extrinsics'])
+
+ # Image encoder, multiscale
+ xs = self.encoder(image)
+
+ # Lift features to bird's-eye view.
+ # Aggregate features to output resolution (H/8, W/8)
+ x = self.feat_decoder(xs)
+
+ if self.cfg.MODEL.TRANSFORMER.ENABLED:
+ if self.cfg.MODEL.TRANSFORMER.BEV:
+ # Depth distribution
+ depth = self.depth(self.depth_decoder(xs)).softmax(dim=1)
+
+ if self.sparse_depth:
+ # only lift depth for topk most likely depth bins
+ topk_bins = depth.topk(self.sparse_depth_count, dim=1)[1]
+ depth_mask = torch.zeros(depth.shape, device=depth.device, dtype=torch.bool)
+ depth_mask.scatter_(1, topk_bins, 1)
+ else:
+ depth_mask = torch.zeros(0, device=depth.device)
+ x = (depth.unsqueeze(1) * x.unsqueeze(2)).type_as(x) # outer product
+
+ # Add camera dimension
+ x = x.unsqueeze(1)
+ x = x.permute(0, 1, 3, 4, 5, 2)
+
+ x = self.frustum_pooling(x, intrinsics.unsqueeze(1), extrinsics.unsqueeze(1), depth_mask)
+ if not self.cfg.MODEL.TRANSFORMER.LARGE:
+ x = self.bev_down_sample_4(x)
+
+ # get lidar features
+ if self.cfg.MODEL.LIDAR.POINT_PILLAR.ENABLED:
+ lidar_list = pack_sequence_dim(batch['points_raw'])
+ num_points = pack_sequence_dim(batch['num_points'])
+ pp_features = self.point_pillars(lidar_list, num_points)
+ pp_xs = self.point_pillar_encoder(pp_features)
+ lidar_features = self.point_pillar_decoder(pp_xs)
+ else:
+ range_view = pack_sequence_dim(batch['range_view_pcd_xyzd'])
+ lidar_xs = self.range_view_encoder(range_view)
+ lidar_features = self.range_view_decoder(lidar_xs)
+ bs_image, _, h_image, w_image = x.shape
+ bs_lidar, _, h_lidar, w_lidar = lidar_features.shape
+
+ # add position embedding
+ image_tokens = x + self.position_encode(x)
+ lidar_tokens = lidar_features + self.position_encode(lidar_features)
+
+ # flatten features
+ image_tokens = image_tokens.flatten(start_dim=2).permute(2, 0, 1) # B, C, W, H -> N, B, C
+ lidar_tokens = lidar_tokens.flatten(start_dim=2).permute(2, 0, 1)
+
+ # add sensor type embedding
+ image_tokens += self.type_embedding[:, :, :, 0]
+ lidar_tokens += self.type_embedding[:, :, :, 1]
+
+ L_image, _, _ = image_tokens.shape
+ L_lidar, _, _ = lidar_tokens.shape
+
+ # concatenate image and lidar tokens
+ tokens = torch.cat([image_tokens, lidar_tokens], dim=0)
+ tokens_out = self.transformer_encoder(tokens)
+ # separate image and lidar tokens and reshape to original shape
+ image_tokens_out = tokens_out[:L_image].permute(1, 2, 0).reshape((bs_image, -1, h_image, w_image))
+ lidar_tokens_out = tokens_out[L_image:].permute(1, 2, 0).reshape((bs_lidar, -1, h_lidar, w_lidar))
+
+ # compress to 1D
+ image_features_out = self.image_feature_conv(image_tokens_out)
+ lidar_features_out = self.lidar_feature_conv(lidar_tokens_out)
+
+ features = [image_features_out, lidar_features_out]
+
+ # get other features
+ if self.cfg.MODEL.ROUTE.ENABLED:
+ route_map = pack_sequence_dim(batch['route_map'])
+ route_map_features = self.backbone_route(route_map)
+ features.append(route_map_features)
+
+ if self.cfg.MODEL.MEASUREMENTS.ENABLED:
+ route_command = pack_sequence_dim(batch['route_command'])
+ gps_vector = pack_sequence_dim(batch['gps_vector'])
+ route_command_next = pack_sequence_dim(batch['route_command_next'])
+ gps_vector_next = pack_sequence_dim(batch['gps_vector_next'])
+
+ command_features = self.command_encoder(route_command)
+ features.append(command_features)
+
+ command_next_features = self.command_next_encoder(route_command_next)
+ features.append(command_next_features)
+
+ gps_features = self.gps_encoder(torch.cat([gps_vector, gps_vector_next], dim=-1))
+ features.append(gps_features)
+
+ speed_features = self.speed_enc(speed / self.speed_normalisation)
+ features.append(speed_features)
+
+ embedding = self.features_combine(torch.cat(features, dim=-1))
+
+ else:
+ if not self.cfg.EVAL.NO_LIFTING:
+ # Depth distribution
+ depth = self.depth(self.depth_decoder(xs)).softmax(dim=1)
+
+ if self.sparse_depth:
+ # only lift depth for topk most likely depth bins
+ topk_bins = depth.topk(self.sparse_depth_count, dim=1)[1]
+ depth_mask = torch.zeros(depth.shape, device=depth.device, dtype=torch.bool)
+ depth_mask.scatter_(1, topk_bins, 1)
+ else:
+ depth_mask = torch.zeros(0, device=depth.device)
+ x = (depth.unsqueeze(1) * x.unsqueeze(2)).type_as(x) # outer product
+
+ # Add camera dimension
+ x = x.unsqueeze(1)
+ x = x.permute(0, 1, 3, 4, 5, 2)
+
+ x = self.frustum_pooling(x, intrinsics.unsqueeze(1), extrinsics.unsqueeze(1), depth_mask)
+
+ if self.cfg.MODEL.ROUTE.ENABLED:
+ route_map = pack_sequence_dim(batch['route_map'])
+ route_map_features = self.backbone_route(route_map)
+ route_map_features = route_map_features.unsqueeze(2).unsqueeze(3).expand(-1, -1, x.shape[2], x.shape[3])
+ x = torch.cat([x, route_map_features], dim=1)
+
+ if self.cfg.MODEL.MEASUREMENTS.ENABLED:
+ route_command = pack_sequence_dim(batch['route_command'])
+ gps_vector = pack_sequence_dim(batch['gps_vector'])
+ route_command_next = pack_sequence_dim(batch['route_command_next'])
+ gps_vector_next = pack_sequence_dim(batch['gps_vector_next'])
+
+ command_features = self.command_encoder(route_command)
+ command_features = command_features.unsqueeze(2).unsqueeze(3).expand(-1, -1, x.shape[2], x.shape[3])
+ x = torch.cat([x, command_features], dim=1)
+
+ command_next_features = self.command_next_encoder(route_command_next)
+ command_next_features = command_next_features.unsqueeze(2).unsqueeze(3).expand(-1, -1, x.shape[2], x.shape[3])
+ x = torch.cat([x, command_next_features], dim=1)
+
+ gps_features = self.gps_encoder(torch.cat([gps_vector, gps_vector_next], dim=-1))
+ gps_features = gps_features.unsqueeze(2).unsqueeze(3).expand(-1, -1, x.shape[2], x.shape[3])
+ x = torch.cat([x, gps_features], dim=1)
+
+ speed_features = self.speed_enc(speed / self.speed_normalisation)
+ speed_features = speed_features.unsqueeze(2).unsqueeze(3).expand(-1, -1, x.shape[2], x.shape[3])
+ x = torch.cat((x, speed_features), 1)
+
+ embedding = self.backbone_bev(x)[-1]
+ embedding = self.final_state_conv(embedding)
+
+ if self.cfg.MODEL.LIDAR.ENABLED:
+ # points_histogram_xy = pack_sequence_dim(batch['points_histogram_xy'])
+ # xs_lidar_xy = self.lidar_encoder_xy(points_histogram_xy)
+ # lidar_features_xy = self.lidar_decoder_xy(xs_lidar_xy)
+ # x = torch.cat([x, lidar_features_xy], dim=1)
+ if self.cfg.MODEL.LIDAR.POINT_PILLAR.ENABLED:
+ lidar_list = pack_sequence_dim(batch['points_raw'])
+ num_points = pack_sequence_dim(batch['num_points'])
+ pp_features = self.point_pillars(lidar_list, num_points)
+ pp_xs = self.point_pillar_encoder(pp_features)
+ lidar_features = self.point_pillar_decoder(pp_xs)
+ else:
+ range_view = pack_sequence_dim(batch['range_view_pcd_xyzd'])
+ lidar_xs = self.range_view_encoder(range_view)
+ lidar_features = self.range_view_decoder(lidar_xs)
+ lidar_embedding = self.lidar_state_conv(lidar_features)
+ # embedding = (lidar_embedding + embedding) / 2
+ embedding = self.embedding_combine(torch.cat([embedding, lidar_embedding], dim=-1))
+
+ # if self.cfg.MODEL.LIDAR.MULTI_VIEW:
+ # points_histogram_xz = pack_sequence_dim(batch['points_histogram_xz'])
+ # xs_lidar_xz = self.lidar_encoder_xz(points_histogram_xz)
+ # lidar_features_xz = self.lidar_decoder_xz(xs_lidar_xz)
+ # embedding_xz = self.backbone_lidar_xz(lidar_features_xz)[-1]
+ # embedding_xz = self.state_conv_xz(embedding_xz)
+ #
+ # points_histogram_yz = pack_sequence_dim(batch['points_histogram_yz'])
+ # xs_lidar_yz = self.lidar_encoder_yz(points_histogram_yz)
+ # lidar_features_yz = self.lidar_decoder_yz(xs_lidar_yz)
+ # embedding_yz = self.backbone_lidar_xz(lidar_features_yz)[-1]
+ # embedding_yz = self.state_conv_yz(embedding_yz)
+ #
+ # embedding = torch.cat([embedding, embedding_xz, embedding_yz], dim=-1)
+ # embedding = self.embedding_combine(embedding)
+
+ embedding = unpack_sequence_dim(embedding, b, s)
+ return embedding
+
+ def observe_and_imagine(self, batch, predict_action=False, future_horizon=None):
+ """ This is only used for visualisation of future prediction"""
+ assert self.cfg.MODEL.TRANSITION.ENABLED and self.cfg.SEMANTIC_SEG.ENABLED
+ if future_horizon is None:
+ future_horizon = self.cfg.FUTURE_HORIZON
+
+ # b, s = batch['image'].shape[:2]
+ b = batch['image'].shape[0]
+ s = self.cfg.RECEPTIVE_FIELD
+
+ if not predict_action:
+ assert batch['throttle_brake'].shape[1] == s + future_horizon
+ assert batch['steering'].shape[1] == s + future_horizon
+
+ # Observe past context
+        # forward() returns (output, state_dict); the posterior states used below are
+        # already merged into output, so the second return value can be discarded here.
+        output_observe, _ = self.forward({key: value[:, :s] for key, value in batch.items()})
+
+ # Imagine future states
+ output_imagine = {
+ 'action': [],
+ 'state': [],
+ 'hidden': [],
+ 'sample': [],
+ }
+ h_t = output_observe['posterior']['hidden_state'][:, -1]
+ sample_t = output_observe['posterior']['sample'][:, -1]
+ for t in range(future_horizon):
+ if predict_action:
+ action_t = self.policy(torch.cat([h_t, sample_t], dim=-1))
+ else:
+ action_t = torch.cat([batch['throttle_brake'][:, s+t], batch['steering'][:, s+t]], dim=-1)
+ prior_t = self.rssm.imagine_step(
+ h_t, sample_t, action_t, use_sample=True, policy=self.policy,
+ )
+ sample_t = prior_t['sample']
+ h_t = prior_t['hidden_state']
+ output_imagine['action'].append(action_t)
+ output_imagine['state'].append(torch.cat([h_t, sample_t], dim=-1))
+ output_imagine['hidden'].append(h_t)
+ output_imagine['sample'].append(sample_t)
+
+ for k, v in output_imagine.items():
+ output_imagine[k] = torch.stack(v, dim=1)
+
+ state = pack_sequence_dim(output_imagine['state'])
+
+ if self.cfg.SEMANTIC_SEG.ENABLED:
+ bev_decoder_output = self.bev_decoder(pack_sequence_dim(output_imagine['state']))
+ bev_decoder_output = unpack_sequence_dim(bev_decoder_output, b, future_horizon)
+ output_imagine = {**output_imagine, **bev_decoder_output}
+
+ if self.cfg.EVAL.RGB_SUPERVISION:
+ rgb_decoder_output = self.rgb_decoder(state)
+ rgb_decoder_output = unpack_sequence_dim(rgb_decoder_output, b, future_horizon)
+ output_imagine = {**output_imagine, **rgb_decoder_output}
+
+ if self.cfg.LIDAR_RE.ENABLED:
+ lidar_output = self.lidar_re(state)
+ lidar_output = unpack_sequence_dim(lidar_output, b, future_horizon)
+ output_imagine = {**output_imagine, **lidar_output}
+
+ if self.cfg.LIDAR_SEG.ENABLED:
+ lidar_seg_output = self.lidar_segmentation(state)
+ lidar_seg_output = unpack_sequence_dim(lidar_seg_output, b, future_horizon)
+ output_imagine = {**output_imagine, **lidar_seg_output}
+
+ if self.cfg.SEMANTIC_IMAGE.ENABLED:
+ sem_image_output = self.sem_image_decoder(state)
+ sem_image_output = unpack_sequence_dim(sem_image_output, b, future_horizon)
+ output_imagine = {**output_imagine, **sem_image_output}
+
+ if self.cfg.DEPTH.ENABLED:
+ depth_image_output = self.depth_image_decoder(state)
+ depth_image_output = unpack_sequence_dim(depth_image_output, b, future_horizon)
+ output_imagine = {**output_imagine, **depth_image_output}
+
+ if self.cfg.VOXEL_SEG.ENABLED:
+ # voxel_feature_xy = self.voxel_feature_xy_decoder(state)
+ # voxel_feature_xz = self.voxel_feature_xz_decoder(state)
+ # voxel_feature_yz = self.voxel_feature_yz_decoder(state)
+ # voxel_decoder_output = self.voxel_decoder(voxel_feature_xy, voxel_feature_xz, voxel_feature_yz)
+ voxel_decoder_output = self.voxel_decoder(state)
+ voxel_decoder_output = unpack_sequence_dim(voxel_decoder_output, b, future_horizon)
+ output_imagine = {**output_imagine, **voxel_decoder_output}
+
+ return output_observe, output_imagine
+
+ def imagine(self, batch, predict_action=False, future_horizon=None):
+ """ This is only used for visualisation of future prediction"""
+ assert self.cfg.MODEL.TRANSITION.ENABLED
+ if future_horizon is None:
+ future_horizon = self.cfg.FUTURE_HORIZON
+
+ # Imagine future states
+ output_imagine = {
+ 'action': [],
+ 'state': [],
+ 'hidden': [],
+ 'sample': [],
+ }
+        h_t = batch['hidden_state']  # (b, c)
+        sample_t = batch['sample']  # (b, s)
+ b = h_t.shape[0]
+ for t in range(future_horizon):
+ if predict_action:
+ action_t = self.policy(torch.cat([h_t, sample_t], dim=-1))
+ else:
+ action_t = torch.cat([batch['throttle_brake'][:, t], batch['steering'][:, t]], dim=-1)
+ prior_t = self.rssm.imagine_step(
+ h_t, sample_t, action_t, use_sample=True, policy=self.policy,
+ )
+ sample_t = prior_t['sample']
+ h_t = prior_t['hidden_state']
+ output_imagine['action'].append(action_t)
+ output_imagine['state'].append(torch.cat([h_t, sample_t], dim=-1))
+ output_imagine['hidden'].append(h_t)
+ output_imagine['sample'].append(sample_t)
+
+ for k, v in output_imagine.items():
+ output_imagine[k] = torch.stack(v, dim=1)
+
+ state = pack_sequence_dim(output_imagine['state'])
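+ # Decode each imagined state into policy actions (throttle/brake and steering).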
+ output_policy = self.policy(state)
+ throttle_brake, steering = torch.split(output_policy, 1, dim=-1)
+ output_imagine['throttle_brake'] = unpack_sequence_dim(throttle_brake, b, future_horizon)
+ output_imagine['steering'] = unpack_sequence_dim(steering, b, future_horizon)
+
+ if self.cfg.SEMANTIC_SEG.ENABLED:
+ bev_decoder_output = self.bev_decoder(state)
+ bev_decoder_output = unpack_sequence_dim(bev_decoder_output, b, future_horizon)
+ output_imagine = {**output_imagine, **bev_decoder_output}
+
+ if self.cfg.EVAL.RGB_SUPERVISION:
+ rgb_decoder_output = self.rgb_decoder(state)
+ rgb_decoder_output = unpack_sequence_dim(rgb_decoder_output, b, future_horizon)
+ output_imagine = {**output_imagine, **rgb_decoder_output}
+
+ if self.cfg.LIDAR_RE.ENABLED:
+ lidar_output = self.lidar_re(state)
+ lidar_output = unpack_sequence_dim(lidar_output, b, future_horizon)
+ output_imagine = {**output_imagine, **lidar_output}
+
+ if self.cfg.LIDAR_SEG.ENABLED:
+ lidar_seg_output = self.lidar_segmentation(state)
+ lidar_seg_output = unpack_sequence_dim(lidar_seg_output, b, future_horizon)
+ output_imagine = {**output_imagine, **lidar_seg_output}
+
+ if self.cfg.SEMANTIC_IMAGE.ENABLED:
+ sem_image_output = self.sem_image_decoder(state)
+ sem_image_output = unpack_sequence_dim(sem_image_output, b, future_horizon)
+ output_imagine = {**output_imagine, **sem_image_output}
+
+ if self.cfg.DEPTH.ENABLED:
+ depth_image_output = self.depth_image_decoder(state)
+ depth_image_output = unpack_sequence_dim(depth_image_output, b, future_horizon)
+ output_imagine = {**output_imagine, **depth_image_output}
+
+ if self.cfg.VOXEL_SEG.ENABLED:
+ # voxel_feature_xy = self.voxel_feature_xy_decoder(state)
+ # voxel_feature_xz = self.voxel_feature_xz_decoder(state)
+ # voxel_feature_yz = self.voxel_feature_yz_decoder(state)
+ # voxel_decoder_output = self.voxel_decoder(voxel_feature_xy, voxel_feature_xz, voxel_feature_yz)
+ voxel_decoder_output = self.voxel_decoder(state)
+ voxel_decoder_output = unpack_sequence_dim(voxel_decoder_output, b, future_horizon)
+ output_imagine = {**output_imagine, **voxel_decoder_output}
+
+ return output_imagine
+
+ def deployment_forward(self, batch, is_dreaming):
+ """
+ Keep latent states in memory for fast inference.
+
+ Parameters
+ ----------
+ batch: dict of torch.Tensor
+ keys:
+ image: (b, s, 3, h, w)
+ route_map: (b, s, 3, h_r, w_r)
+ speed: (b, s, 1)
+ intrinsics: (b, s, 3, 3)
+ extrinsics: (b, s, 4, 4)
+ throttle_brake: (b, s, 1)
+ steering: (b, s, 1)
+ """
+ assert self.cfg.MODEL.TRANSITION.ENABLED
+ b = batch['image'].shape[0]
+
+ if self.count == 0:
+ # Encode RGB images, route_map, speed using intrinsics and extrinsics
+ # to a 512 dimensional vector
+ s = batch['image'].shape[1]
+ action_t = batch['action'][:, -2] # action from t-1 to t
+ batch = remove_past(batch, s)
+ embedding_t = self.encode(batch)[:, -1]  # dim (b, 512)
+
+ # Recurrent state sequence module
+ if self.last_h is None:
+ h_t = action_t.new_zeros(b, self.cfg.MODEL.TRANSITION.HIDDEN_STATE_DIM)
+ sample_t = action_t.new_zeros(b, self.cfg.MODEL.TRANSITION.STATE_DIM)
+ else:
+ h_t = self.last_h
+ sample_t = self.last_sample
+
+ if is_dreaming:
+ rssm_output = self.rssm.imagine_step(
+ h_t, sample_t, action_t, use_sample=False, policy=self.policy,
+ )
+ else:
+ rssm_output = self.rssm.observe_step(
+ h_t, sample_t, action_t, embedding_t, use_sample=False, policy=self.policy,
+ )['posterior']
+ sample_t = rssm_output['sample']
+ h_t = rssm_output['hidden_state']
+
+ self.last_h = h_t
+ self.last_sample = sample_t
+
+ game_frequency = CARLA_FPS
+ model_stride_sec = self.cfg.DATASET.STRIDE_SEC
+ n_image_per_stride = int(game_frequency * model_stride_sec)
+ self.count = n_image_per_stride - 1
+ else:
+ self.count -= 1
+ s = 1
+ state = torch.cat([self.last_h, self.last_sample], dim=-1)
+ output_policy = self.policy(state)
+ throttle_brake, steering = torch.split(output_policy, 1, dim=-1)
+ output = dict()
+ output['throttle_brake'] = unpack_sequence_dim(throttle_brake, b, s)
+ output['steering'] = unpack_sequence_dim(steering, b, s)
+
+ output['hidden_state'] = self.last_h
+ output['sample'] = self.last_sample
+
+ if self.cfg.SEMANTIC_SEG.ENABLED and DISPLAY_SEGMENTATION:
+ bev_decoder_output = self.bev_decoder(state)
+ bev_decoder_output = unpack_sequence_dim(bev_decoder_output, b, s)
+ output = {**output, **bev_decoder_output}
+
+ return output
+
+ def sim_forward(self, batch, is_dreaming):
+ """
+ Keep latent states in memory for fast inference.
+ Simulate one real run.
+ """
+ assert self.cfg.MODEL.TRANSITION.ENABLED
+ b = batch['image'].shape[0]
+
+ if self.count == 0:
+ # Encode RGB images, route_map, speed using intrinsics and extrinsics
+ # to a 512 dimensional vector
+ s = self.receptive_field
+ batch = remove_past(batch, s)
+ # action_t = batch['action'][:, 0] # action from t-1 to t
+ action_t = torch.cat([batch['throttle_brake'][:, 0], batch['steering'][:, 0]], dim=-1)
+ embedding_t = self.encode({key: value[:, :1] for key, value in batch.items()})[:, -1]  # dim (b, 512)
+
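+ # The RSSM transition is conditioned on the action taken at the previous step; fall back to zeros on the first call.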
+ if self.last_action is None:
+ action_last = torch.zeros_like(action_t)
+ else:
+ action_last = self.last_action
+
+ # Recurrent state sequence module
+ if self.last_h is None:
+ h_t = action_t.new_zeros(b, self.cfg.MODEL.TRANSITION.HIDDEN_STATE_DIM)
+ sample_t = action_t.new_zeros(b, self.cfg.MODEL.TRANSITION.STATE_DIM)
+ else:
+ h_t = self.last_h
+ sample_t = self.last_sample
+
+ if is_dreaming:
+ rssm_output = self.rssm.imagine_step(
+ h_t, sample_t, action_last, use_sample=False, policy=self.policy,
+ )
+ else:
+ rssm_output = self.rssm.observe_step(
+ h_t, sample_t, action_last, embedding_t, use_sample=False, policy=self.policy,
+ )['posterior']
+ sample_t = rssm_output['sample']
+ h_t = rssm_output['hidden_state']
+
+ self.last_h = h_t
+ self.last_sample = sample_t
+ self.last_action = action_t
+
+ game_frequency = CARLA_FPS
+ model_stride_sec = self.cfg.DATASET.STRIDE_SEC
+ n_image_per_stride = int(game_frequency * model_stride_sec)
+ self.count = n_image_per_stride - 1
+ else:
+ self.count -= 1
+ s = 1
+ state = torch.cat([self.last_h, self.last_sample], dim=-1)
+ output_policy = self.policy(state)
+ throttle_brake, steering = torch.split(output_policy, 1, dim=-1)
+ output = dict()
+ output['throttle_brake'] = unpack_sequence_dim(throttle_brake, b, s)
+ output['steering'] = unpack_sequence_dim(steering, b, s)
+
+ output['hidden_state'] = self.last_h
+ output['sample'] = self.last_sample
+
+ if self.cfg.SEMANTIC_SEG.ENABLED and DISPLAY_SEGMENTATION:
+ bev_decoder_output = self.bev_decoder(state)
+ bev_decoder_output = unpack_sequence_dim(bev_decoder_output, b, s)
+ output = {**output, **bev_decoder_output}
+
+ if self.cfg.EVAL.RGB_SUPERVISION:
+ rgb_decoder_output = self.rgb_decoder(state)
+ rgb_decoder_output = unpack_sequence_dim(rgb_decoder_output, b, s)
+ output = {**output, **rgb_decoder_output}
+
+ if self.cfg.LIDAR_RE.ENABLED:
+ lidar_output = self.lidar_re(state)
+ lidar_output = unpack_sequence_dim(lidar_output, b, s)
+ output = {**output, **lidar_output}
+
+ if self.cfg.LIDAR_SEG.ENABLED:
+ lidar_seg_output = self.lidar_segmentation(state)
+ lidar_seg_output = unpack_sequence_dim(lidar_seg_output, b, s)
+ output = {**output, **lidar_seg_output}
+
+ if self.cfg.SEMANTIC_IMAGE.ENABLED:
+ sem_image_output = self.sem_image_decoder(state)
+ sem_image_output = unpack_sequence_dim(sem_image_output, b, s)
+ output = {**output, **sem_image_output}
+
+ if self.cfg.DEPTH.ENABLED:
+ depth_image_output = self.depth_image_decoder(state)
+ depth_image_output = unpack_sequence_dim(depth_image_output, b, s)
+ output = {**output, **depth_image_output}
+
+ if self.cfg.VOXEL_SEG.ENABLED:
+ # voxel_feature_xy = self.voxel_feature_xy_decoder(state)
+ # voxel_feature_xz = self.voxel_feature_xz_decoder(state)
+ # voxel_feature_yz = self.voxel_feature_yz_decoder(state)
+ # voxel_decoder_output = self.voxel_decoder(voxel_feature_xy, voxel_feature_xz, voxel_feature_yz)
+ voxel_decoder_output = self.voxel_decoder(state)
+ voxel_decoder_output = unpack_sequence_dim(voxel_decoder_output, b, s)
+ output = {**output, **voxel_decoder_output}
+
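+ # Additionally imagine the rest of the sequence from the current latent state, driven by the recorded actions.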
+ state_imagine = {'hidden_state': self.last_h,
+ 'sample': self.last_sample,
+ 'throttle_brake': batch['throttle_brake'],
+ 'steering': batch['steering']}
+ output_imagine = self.imagine(state_imagine, predict_action=False, future_horizon=batch['image'].shape[1] - 1)
+
+ return output, output_imagine
diff --git a/muvo/models/preprocess.py b/muvo/models/preprocess.py
new file mode 100644
index 0000000..1a16f01
--- /dev/null
+++ b/muvo/models/preprocess.py
@@ -0,0 +1,367 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+import torchvision.transforms.functional as tvf
+# import skimage.transform as skt
+from typing import Dict, Tuple
+
+from muvo.utils.geometry_utils import get_out_of_view_mask
+from muvo.utils.instance_utils import convert_instance_mask_to_center_and_offset_label
+
+
+class PreProcess(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ self.cfg = cfg
+ self.crop = tuple(cfg.IMAGE.CROP)
+ self.route_map_size = cfg.ROUTE.SIZE
+
+ if self.cfg.EVAL.MASK_VIEW:
+ self.bev_out_of_view_mask = get_out_of_view_mask(cfg)
+
+ # Instance label parameters
+ self.center_sigma = cfg.INSTANCE_SEG.CENTER_LABEL_SIGMA_PX
+ self.ignore_index = cfg.INSTANCE_SEG.IGNORE_INDEX
+
+ self.min_depth = cfg.BEV.FRUSTUM_POOL.D_BOUND[0]
+ self.max_depth = cfg.BEV.FRUSTUM_POOL.D_BOUND[1]
+
+ self.pixel_augmentation = PixelAugmentation(cfg)
+ self.route_augmentation = RouteAugmentation(
+ cfg.ROUTE.AUGMENTATION_DROPOUT,
+ cfg.ROUTE.AUGMENTATION_END_OF_ROUTE,
+ cfg.ROUTE.AUGMENTATION_SMALL_ROTATION,
+ cfg.ROUTE.AUGMENTATION_LARGE_ROTATION,
+ cfg.ROUTE.AUGMENTATION_DEGREES,
+ cfg.ROUTE.AUGMENTATION_TRANSLATE,
+ cfg.ROUTE.AUGMENTATION_SCALE,
+ cfg.ROUTE.AUGMENTATION_SHEAR,
+ )
+
+ self.register_buffer('image_mean', torch.tensor(cfg.IMAGE.IMAGENET_MEAN).unsqueeze(1).unsqueeze(1))
+ self.register_buffer('image_std', torch.tensor(cfg.IMAGE.IMAGENET_STD).unsqueeze(1).unsqueeze(1))
+
+ def augmentation(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ batch = self.pixel_augmentation(batch)
+ batch = self.route_augmentation(batch)
+ return batch
+
+ def prepare_bev_labels(self, batch):
+ if 'birdview_label' in batch:
+ # Mask bird's-eye view label pixels that are not visible from the input image
+ if self.cfg.EVAL.MASK_VIEW:
+ batch['birdview_label'][:, :, :, self.bev_out_of_view_mask] = 0
+
+ # Currently the frustum pooling is set up such that the bev features are rotated by 90 degrees clockwise
+ batch['birdview_label'] = torch.rot90(batch['birdview_label'], k=-1, dims=[3, 4]).contiguous()
+
+ # Compute labels at half and quarter resolution
+ batch['birdview_label_1'] = batch['birdview_label']
+ h, w = batch['birdview_label'].shape[-2:]
+ for downsample_factor in [2, 4]:
+ size = h // downsample_factor, w // downsample_factor
+ previous_label_factor = downsample_factor // 2
+ batch[f'birdview_label_{downsample_factor}'] = functional_resize(
+ batch[f'birdview_label_{previous_label_factor}'], size, mode=tvf.InterpolationMode.NEAREST
+ )
+
+ if 'instance_label' in batch:
+ # Mask elements not visible from the input image
+ if self.cfg.EVAL.MASK_VIEW:
+ batch['instance_label'][:, :, :, self.bev_out_of_view_mask] = 0
+ # Currently the frustum pooling is set up such that the bev features are rotated by 90 degrees clockwise
+ batch['instance_label'] = torch.rot90(batch['instance_label'], k=-1, dims=[3, 4]).contiguous()
+
+ center_label, offset_label = convert_instance_mask_to_center_and_offset_label(
+ batch['instance_label'], ignore_index=self.ignore_index, sigma=self.center_sigma,
+ )
+ batch['center_label'] = center_label
+ batch['offset_label'] = offset_label
+
+ # Compute labels at half and quarter resolution
+ batch['instance_label_1'] = batch['instance_label']
+ batch['center_label_1'] = batch['center_label']
+ batch['offset_label_1'] = batch['offset_label']
+
+ h, w = batch['instance_label'].shape[-2:]
+ for downsample_factor in [2, 4]:
+ size = h // downsample_factor, w // downsample_factor
+ previous_label_factor = downsample_factor // 2
+ batch[f'instance_label_{downsample_factor}'] = functional_resize(
+ batch[f'instance_label_{previous_label_factor}'], size, mode=tvf.InterpolationMode.NEAREST
+ )
+
+ center_label, offset_label = convert_instance_mask_to_center_and_offset_label(
+ batch[f'instance_label_{downsample_factor}'], ignore_index=self.ignore_index,
+ sigma=self.center_sigma/downsample_factor,
+ )
+ batch[f'center_label_{downsample_factor}'] = center_label
+ batch[f'offset_label_{downsample_factor}'] = offset_label
+
+ if self.cfg.EVAL.RGB_SUPERVISION:
+ # Compute labels at half and quarter resolution
+ batch['rgb_label_1'] = batch['image']
+ h, w = batch['rgb_label_1'].shape[-2:]
+ for downsample_factor in [2, 4]:
+ size = h // downsample_factor, w // downsample_factor
+ previous_label_factor = downsample_factor // 2
+ batch[f'rgb_label_{downsample_factor}'] = functional_resize(
+ batch[f'rgb_label_{previous_label_factor}'],
+ size,
+ mode=tvf.InterpolationMode.BILINEAR,
+ )
+
+ if self.cfg.LOSSES.RGB_INSTANCE:
+ batch['image_instance_mask_1'] = batch['image_instance_mask']
+ h, w = batch['image_instance_mask_1'].shape[-2:]
+ for downsample_factor in [2, 4]:
+ size = h // downsample_factor, w // downsample_factor
+ previous_label_factor = downsample_factor // 2
+ batch[f'image_instance_mask_{downsample_factor}'] = functional_resize(
+ batch[f'image_instance_mask_{previous_label_factor}'],
+ size,
+ mode=tvf.InterpolationMode.NEAREST,
+ )
+
+ if self.cfg.SEMANTIC_IMAGE.ENABLED:
+ batch['semantic_image_label_1'] = batch['semantic_image']
+ h, w = batch['semantic_image_label_1'].shape[-2:]
+ for downsample_factor in [2, 4]:
+ size = h // downsample_factor, w // downsample_factor
+ previous_label_factor = downsample_factor // 2
+ batch[f'semantic_image_label_{downsample_factor}'] = functional_resize(
+ batch[f'semantic_image_label_{previous_label_factor}'],
+ size,
+ mode=tvf.InterpolationMode.NEAREST,
+ )
+
+ if self.cfg.DEPTH.ENABLED:
+ batch['depth_label_1'] = batch['depth']
+ h, w = batch['depth_label_1'].shape[-2:]
+ for downsample_factor in [2, 4]:
+ size = h // downsample_factor, w // downsample_factor
+ previous_label_factor = downsample_factor // 2
+ batch[f'depth_label_{downsample_factor}'] = functional_resize(
+ batch[f'depth_label_{previous_label_factor}'],
+ size,
+ mode=tvf.InterpolationMode.BILINEAR,
+ )
+
+ if self.cfg.LIDAR_RE.ENABLED:
+ batch['range_view_pcd_xyzd'] = batch['range_view_pcd_xyzd'].float() / self.cfg.LIDAR_RE.SCALE
+ batch['range_view_label_1'] = batch['range_view_pcd_xyzd']
+ h, w = batch['range_view_label_1'].shape[-2:]
+ for downsample_factor in [2, 4]:
+ size = h // downsample_factor, w // downsample_factor
+ previous_label_factor = downsample_factor // 2
+ batch[f'range_view_label_{downsample_factor}'] = functional_resize(
+ batch[f'range_view_label_{previous_label_factor}'],
+ size,
+ mode=tvf.InterpolationMode.NEAREST,
+ )
+
+ if self.cfg.LIDAR_SEG.ENABLED:
+ batch['range_view_seg_label_1'] = batch['range_view_pcd_seg']
+ h, w = batch['range_view_seg_label_1'].shape[-2:]
+ for downsample_factor in [2, 4]:
+ size = h // downsample_factor, w // downsample_factor
+ previous_label_factor = downsample_factor // 2
+ batch[f'range_view_seg_label_{downsample_factor}'] = functional_resize(
+ batch[f'range_view_seg_label_{previous_label_factor}'],
+ size,
+ mode=tvf.InterpolationMode.NEAREST,
+ )
+
+ if self.cfg.VOXEL_SEG.ENABLED:
+ batch['voxel_label_1'] = batch['voxel']
+ x, y, z = batch['voxel_label_1'].shape[-3:]
+ for downsample_factor in [2, 4]:
+ size = (x // downsample_factor, y // downsample_factor, z // downsample_factor)
+ previous_label_factor = downsample_factor // 2
+ batch[f'voxel_label_{downsample_factor}'] = functional_resize_voxel(
+ batch[f'voxel_label_{previous_label_factor}'],
+ size,
+ mode='nearest',
+ )
+
+ # if 'points_histogram' in batch:
+ # # mask histogram the same as bev.
+ # if self.cfg.EVAL.MASK_VIEW:
+ # scale = self.cfg.POINTS.HISTOGRAM.RESOLUTION * self.cfg.BEV.RESOLUTION
+ # bev_shape = self.bev_out_of_view_mask.shape
+ # out_shape = [int(scale * bev_shape[0]), int(scale * bev_shape[1])]
+ # view_mask = skt.resize(self.bev_out_of_view_mask, out_shape)
+ # batch['points_histogram'] = tvf.center_crop(batch['points_histogram'], out_shape)
+ # batch['points_histogram'][:, :, :, view_mask[::-1, ::-1]] = 0
+ # batch['points_histogram'] = torch.rot90(batch['points_histogram'], k=-1, dims=[3, 4]).contiguous()
+
+ return batch
+
+ def forward(self, batch: Dict[str, torch.Tensor]):
+ # Normalise from [0, 255] to [0, 1]
+ batch['image'] = batch['image'].float() / 255
+
+ if 'route_map' in batch:
+ batch['route_map'] = batch['route_map'].float() / 255
+ batch['route_map'] = functional_resize(batch['route_map'], size=(self.route_map_size, self.route_map_size))
+ batch = functional_crop(batch, self.crop)
+ if self.cfg.EVAL.RESOLUTION.ENABLED:
+ batch = functional_resize_batch(batch, scale=1/self.cfg.EVAL.RESOLUTION.FACTOR)
+
+ batch = self.prepare_bev_labels(batch)
+
+ if self.training:
+ batch = self.augmentation(batch)
+
+ # Use ImageNet mean and std normalisation, because we're loading pretrained backbones
+ batch['image'] = (batch['image'] - self.image_mean) / self.image_std
+ if 'route_map' in batch:
+ batch['route_map'] = (batch['route_map'] - self.image_mean) / self.image_std
+
+ if 'depth' in batch:
+ batch['depth_mask'] = (batch['depth'] > self.min_depth) & (batch['depth'] < self.max_depth)
+
+ return batch
+
+
+def functional_crop(batch: Dict[str, torch.Tensor], crop: Tuple[int, int, int, int]):
+ left, top, right, bottom = crop
+ height = bottom - top
+ width = right - left
+ if 'image' in batch:
+ batch['image'] = tvf.crop(batch['image'], top, left, height, width)
+ if 'depth' in batch:
+ batch['depth'] = tvf.crop(batch['depth'], top, left, height, width)
+ if 'depth_color' in batch:
+ batch['depth_color'] = tvf.crop(batch['depth_color'], top, left, height, width)
+ if 'semseg' in batch:
+ batch['semseg'] = tvf.crop(batch['semseg'], top, left, height, width)
+ if 'semantic_image' in batch:
+ batch['semantic_image'] = tvf.crop(batch['semantic_image'], top, left, height, width)
+ if 'image_instance_mask' in batch:
+ batch['image_instance_mask'] = tvf.crop(batch['image_instance_mask'], top, left, height, width)
+ if 'intrinsics' in batch:
+ intrinsics = batch['intrinsics'].clone()
+ intrinsics[..., 0, 2] -= left
+ intrinsics[..., 1, 2] -= top
+ batch['intrinsics'] = intrinsics
+
+ return batch
+
+
+def functional_resize_batch(batch, scale):
+ b, s, c, h, w = batch['image'].shape
+ h1, w1 = int(round(h * scale)), int(round(w * scale))
+ size = (h1, w1)
+ if 'image' in batch:
+ image = batch['image'].view(b*s, c, h, w)
+ image = tvf.resize(image, size, antialias=True)
+ batch['image'] = image.view(b, s, c, h1, w1)
+ if 'intrinsics' in batch:
+ intrinsics = batch['intrinsics'].clone()
+ intrinsics[..., :2, :] *= scale
+ batch['intrinsics'] = intrinsics
+ if 'image_instance_mask' in batch:
+ image = batch['image_instance_mask'].view(b*s, c, h, w)
+ image = tvf.resize(image, size, antialias=True)
+ batch['image_instance_mask'] = image.view(b, s, c, h1, w1)
+ if 'semantic_image' in batch:
+ image = batch['semantic_image'].view(b*s, c, h, w)
+ image = tvf.resize(image, size, antialias=True)
+ batch['semantic_image'] = image.view(b, s, c, h1, w1)
+
+ return batch
+
+
+def functional_resize(x, size, mode=tvf.InterpolationMode.NEAREST):
+ b, s, c, h, w = x.shape
+ x = x.view(b * s, c, h, w)
+ x = tvf.resize(x, size, interpolation=mode)
+ x = x.view(b, s, c, *size)
+
+ return x
+
+
+def functional_resize_voxel(voxel, size, mode='nearest'):
+ b, s, c, x, y, z = voxel.shape
+ voxel = voxel.view(b * s, c, x, y, z)
+ voxel = F.interpolate(voxel, size, mode=mode)
+ voxel = voxel.view(b, s, c, *size)
+
+ return voxel
+
+
+class PixelAugmentation(nn.Module):
+ def __init__(self, cfg):
+ super().__init__()
+ # TODO replace with ImageApply([RandomBlurSharpen(), RandomColorJitter(), ...])
+ self.blur_prob = cfg.IMAGE.AUGMENTATION.BLUR_PROB
+ self.sharpen_prob = cfg.IMAGE.AUGMENTATION.SHARPEN_PROB
+ self.blur_window = cfg.IMAGE.AUGMENTATION.BLUR_WINDOW
+ self.blur_std = cfg.IMAGE.AUGMENTATION.BLUR_STD
+ self.sharpen_factor = cfg.IMAGE.AUGMENTATION.SHARPEN_FACTOR
+ assert self.blur_prob + self.sharpen_prob <= 1
+
+ self.color_jitter = transforms.RandomApply(nn.ModuleList([
+ transforms.ColorJitter(
+ cfg.IMAGE.AUGMENTATION.COLOR_JITTER_BRIGHTNESS,
+ cfg.IMAGE.AUGMENTATION.COLOR_JITTER_CONTRAST,
+ cfg.IMAGE.AUGMENTATION.COLOR_JITTER_SATURATION,
+ cfg.IMAGE.AUGMENTATION.COLOR_JITTER_HUE
+ )
+ ]), cfg.IMAGE.AUGMENTATION.COLOR_PROB)
+
+ def forward(self, batch: Dict[str, torch.Tensor]):
+ image = batch['image']
+ for i in range(image.shape[0]):
+ for j in range(image.shape[1]):
+ # random blur
+ rand_value = torch.rand(1)
+ if rand_value < self.blur_prob:
+ std = torch.empty(1).uniform_(self.blur_std[0], self.blur_std[1]).item()
+ image[i, j] = tvf.gaussian_blur(image[i, j], self.blur_window, std)
+ # random sharpen
+ elif rand_value < self.blur_prob + self.sharpen_prob:
+ factor = torch.empty(1).uniform_(self.sharpen_factor[0], self.sharpen_factor[1]).item()
+ image[i, j] = tvf.adjust_sharpness(image[i, j], factor)
+
+ # random color jitter
+ image[i, j] = self.color_jitter(image[i, j])
+
+ batch['image'] = image
+ return batch
+
+
+class RouteAugmentation(nn.Module):
+ def __init__(self, drop=0.025, end_of_route=0.025, small_rotation=0.025, large_rotation=0.025, degrees=8.0,
+ translate=(.1, .1), scale=(.95, 1.05), shear=(.1, .1)):
+ super().__init__()
+ assert drop + end_of_route + small_rotation + large_rotation <= 1
+ self.drop = drop # random dropout of map
+ self.end_of_route = end_of_route # probability of end of route augmentation
+ self.small_rotation = small_rotation # probability of doing small rotation
+ self.large_rotation = large_rotation # probability of doing large rotation (arbitrary orientation)
+ self.small_perturbation = transforms.RandomAffine(degrees, translate, scale, shear) # small rotation
+ self.large_perturbation = transforms.RandomAffine(180, translate, scale, shear) # arbitrary orientation
+
+ def forward(self, batch):
+ if 'route_map' in batch:
+ route_map = batch['route_map']
+
+ # TODO: make augmentation independent of the sequence dimension?
+ for i in range(route_map.shape[0]):
+ rand_value = torch.rand(1)
+ if rand_value < self.drop:
+ route_map[i] = torch.zeros_like(route_map[i])
+ elif rand_value < self.drop + self.end_of_route:
+ height = torch.randint(route_map[i].shape[-2], (1,))
+ route_map[i][:, :, :height] = 0
+ elif rand_value < self.drop + self.end_of_route + self.small_rotation:
+ route_map[i] = self.small_perturbation(route_map[i])
+ elif rand_value < self.drop + self.end_of_route + self.small_rotation + self.large_rotation:
+ route_map[i] = self.large_perturbation(route_map[i])
+
+ batch['route_map'] = route_map
+
+ return batch
diff --git a/muvo/models/transition.py b/muvo/models/transition.py
new file mode 100644
index 0000000..49a78dd
--- /dev/null
+++ b/muvo/models/transition.py
@@ -0,0 +1,191 @@
+import torch
+import torch.nn as nn
+
+
+class RepresentationModel(nn.Module):
+ def __init__(self, in_channels, latent_dim):
+ super().__init__()
+ self.latent_dim = latent_dim
+ self.min_std = 0.1
+
+ self.module = nn.Sequential(
+ nn.Linear(in_channels, in_channels),
+ nn.LeakyReLU(True),
+ nn.Linear(in_channels, 2*self.latent_dim),
+ )
+
+ def forward(self, x):
+ def sigmoid2(tensor: torch.Tensor, min_value: float) -> torch.Tensor:
+ return 2 * torch.sigmoid(tensor / 2) + min_value
+
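+ # Split the network output into mean and pre-activation sigma;
+ # sigmoid2 squashes sigma into (min_std, 2 + min_std) so it stays positive and bounded.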
+ mu_log_sigma = self.module(x)
+ mu, log_sigma = torch.split(mu_log_sigma, self.latent_dim, dim=-1)
+
+ sigma = sigmoid2(log_sigma, self.min_std)
+ return mu, sigma
+
+
+class RSSM(nn.Module):
+ def __init__(self, embedding_dim, action_dim, hidden_state_dim, state_dim, action_latent_dim, receptive_field,
+ use_dropout=False,
+ dropout_probability=0.0):
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.state_dim = state_dim
+ self.action_dim = action_dim
+ self.hidden_state_dim = hidden_state_dim
+ self.action_latent_dim = action_latent_dim
+ self.receptive_field = receptive_field
+ # Sometimes unroll the prior instead of always updating with the posterior
+ # so that the model learns to unroll for more than 1 step in the future when imagining
+ self.use_dropout = use_dropout
+ self.dropout_probability = dropout_probability
+
+ # Map the input of the GRU to a space with easier temporal dynamics
+ self.pre_gru_net = nn.Sequential(
+ nn.Linear(state_dim, hidden_state_dim),
+ nn.LeakyReLU(True),
+ )
+
+ self.recurrent_model = nn.GRUCell(
+ input_size=hidden_state_dim,
+ hidden_size=hidden_state_dim,
+ )
+
+ # Map action to a higher dimensional input
+ self.posterior_action_module = nn.Sequential(
+ nn.Linear(action_dim, self.action_latent_dim),
+ nn.LeakyReLU(True),
+ )
+
+ self.posterior = RepresentationModel(
+ in_channels=hidden_state_dim + embedding_dim + self.action_latent_dim,
+ latent_dim=state_dim,
+ )
+
+ # Map action to a higher dimensional input
+ self.prior_action_module = nn.Sequential(
+ nn.Linear(action_dim, self.action_latent_dim),
+ nn.LeakyReLU(True),
+ )
+ self.prior = RepresentationModel(in_channels=hidden_state_dim + self.action_latent_dim, latent_dim=state_dim)
+ self.active_inference = False
+ if self.active_inference:
+ print('ACTIVE INFERENCE!!')
+
+ def forward(self, input_embedding, action, use_sample=True, policy=None):
+ """
+ Inputs
+ ------
+ input_embedding: torch.Tensor size (B, S, C)
+ action: torch.Tensor size (B, S, 2)
+ use_sample: bool
+ whether to sample from the distributions or to take the mean
+
+ Returns
+ -------
+ output: dict
+ prior: dict
+ hidden_state: torch.Tensor (B, S, C_h)
+ sample: torch.Tensor (B, S, C_s)
+ mu: torch.Tensor (B, S, C_s)
+ sigma: torch.Tensor (B, S, C_s)
+ posterior: dict
+ hidden_state: torch.Tensor (B, S, C_h)
+ sample: torch.Tensor (B, S, C_s)
+ mu: torch.Tensor (B, S, C_s)
+ sigma: torch.Tensor (B, S, C_s)
+ """
+ output = {
+ 'prior': [],
+ 'posterior': [],
+ }
+
+ # Initialisation
+ batch_size, sequence_length, _ = input_embedding.shape
+ h_t = input_embedding.new_zeros((batch_size, self.hidden_state_dim))
+ sample_t = input_embedding.new_zeros((batch_size, self.state_dim))
+ for t in range(sequence_length):
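+ # The transition is conditioned on the action that led to the current step, i.e. action[:, t-1];
+ # there is no previous action at t=0, so a zero action is used.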
+ if t == 0:
+ action_t = torch.zeros_like(action[:, 0])
+ else:
+ action_t = action[:, t-1]
+ output_t = self.observe_step(
+ h_t, sample_t, action_t, input_embedding[:, t], use_sample=use_sample, policy=policy,
+ )
+ # During training, sample from the posterior, except when prior dropout is active.
+ # The posterior is always used for the first frame.
+ use_prior = self.training and self.use_dropout and torch.rand(1).item() < self.dropout_probability and t > 0
+
+ if use_prior:
+ sample_t = output_t['prior']['sample']
+ else:
+ sample_t = output_t['posterior']['sample']
+ h_t = output_t['prior']['hidden_state']
+
+ for key, value in output_t.items():
+ output[key].append(value)
+
+ output = self.stack_list_of_dict_tensor(output, dim=1)
+ return output
+
+ def observe_step(self, h_t, sample_t, action_t, embedding_t, use_sample=True, policy=None):
+ imagine_output = self.imagine_step(h_t, sample_t, action_t, use_sample, policy=policy)
+
+ latent_action_t = self.posterior_action_module(action_t)
+ posterior_mu_t, posterior_sigma_t = self.posterior(
+ torch.cat([imagine_output['hidden_state'], embedding_t, latent_action_t], dim=-1)
+ )
+
+ sample_t = self.sample_from_distribution(posterior_mu_t, posterior_sigma_t, use_sample=use_sample)
+
+ posterior_output = {
+ 'hidden_state': imagine_output['hidden_state'],
+ 'sample': sample_t,
+ 'mu': posterior_mu_t,
+ 'sigma': posterior_sigma_t,
+ }
+
+ output = {
+ 'prior': imagine_output,
+ 'posterior': posterior_output,
+ }
+
+ return output
+
+ def imagine_step(self, h_t, sample_t, action_t, use_sample=True, policy=None):
+ if self.active_inference:
+ # Predict action with policy
+ action_t = policy(torch.cat([h_t, sample_t], dim=-1))
+
+ latent_action_t = self.prior_action_module(action_t)
+
+ input_t = self.pre_gru_net(sample_t)
+ h_t = self.recurrent_model(input_t, h_t)
+ prior_mu_t, prior_sigma_t = self.prior(torch.cat([h_t, latent_action_t], dim=-1))
+ sample_t = self.sample_from_distribution(prior_mu_t, prior_sigma_t, use_sample=use_sample)
+ imagine_output = {
+ 'hidden_state': h_t,
+ 'sample': sample_t,
+ 'mu': prior_mu_t,
+ 'sigma': prior_sigma_t,
+ }
+ return imagine_output
+
+ @staticmethod
+ def sample_from_distribution(mu, sigma, use_sample):
+ sample = mu
+ if use_sample:
+ noise = torch.randn_like(sample)
+ sample = sample + sigma * noise
+ return sample
+
+ @staticmethod
+ def stack_list_of_dict_tensor(output, dim=1):
+ new_output = {}
+ for outer_key, outer_value in output.items():
+ if len(outer_value) > 0:
+ new_output[outer_key] = dict()
+ for inner_key in outer_value[0].keys():
+ new_output[outer_key][inner_key] = torch.stack([x[inner_key] for x in outer_value], dim=dim)
+ return new_output
diff --git a/muvo/models/utils.py b/muvo/models/utils.py
new file mode 100644
index 0000000..5faeb6f
--- /dev/null
+++ b/muvo/models/utils.py
@@ -0,0 +1,13 @@
+import torch
+import torch.nn as nn
+
+
+class Concat(nn.Module):
+ def forward(self, x1, x2):
+ return torch.cat([x1, x2], dim=-1)
+
+
+class PixelNorm(nn.Module):
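+ # Normalises each feature vector to unit RMS across the channel dimension.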
+ def forward(self, x, epsilon=1e-8):
+ return x / torch.sqrt(torch.mean(x**2, dim=1, keepdim=True) + epsilon)
+
diff --git a/muvo/trainer.py b/muvo/trainer.py
new file mode 100644
index 0000000..a1f13e4
--- /dev/null
+++ b/muvo/trainer.py
@@ -0,0 +1,1095 @@
+import os
+
+import numpy as np
+import cv2
+import torch
+import torch.nn.functional as F
+import lightning.pytorch as pl
+from torchmetrics import JaccardIndex
+
+from muvo.config import get_cfg
+from muvo.models.mile import Mile
+from muvo.losses import \
+ SegmentationLoss, KLLoss, RegressionLoss, SpatialRegressionLoss, VoxelLoss, SSIMLoss, SemScalLoss, GeoScalLoss
+from muvo.metrics import SSCMetrics, SSIMMetric, CDMetric, PSNRMetric
+from muvo.models.preprocess import PreProcess
+from muvo.utils.geometry_utils import PointCloud, compute_pcd_transformation
+from constants import BIRDVIEW_COLOURS, VOXEL_COLOURS, VOXEL_LABEL
+
+import matplotlib
+import matplotlib.pyplot as plt
+
+matplotlib.use('Agg')
+
+
+class WorldModelTrainer(pl.LightningModule):
+ def __init__(self, hparams, path_to_conf_file=None, pretrained_path=None):
+ super().__init__()
+ self.save_hyperparameters()
+ self.cfg = get_cfg(cfg_dict=hparams)
+ if path_to_conf_file:
+ self.cfg.merge_from_file(path_to_conf_file)
+ if pretrained_path:
+ self.cfg.PRETRAINED.PATH = pretrained_path
+ # print(self.cfg)
+ self.vis_step = -1
+ self.rf = self.cfg.RECEPTIVE_FIELD
+ self.fh = self.cfg.FUTURE_HORIZON
+
+ self.cml_logger = None
+ self.preprocess = PreProcess(self.cfg)
+
+ # Model
+ self.model = Mile(self.cfg)
+ self.load_pretrained_weights()
+
+ # self.metrics_vals = [dict() for _ in range(len(self.val_dataloader()))]
+ # self.metrics_vals_imagine = [dict() for _ in range(len(self.val_dataloader()))]
+ # self.metrics_tests = [dict() for _ in range(len(self.test_dataloader()))]
+ # self.metrics_tests_imagine = [dict() for _ in range(len(self.test_dataloader()))]
+ # self.metrics_train = dict()
+ self.metrics_vals = [{}, {}, {}]
+ self.metrics_vals_imagine = [{}, {}, {}]
+ self.metrics_tests = [{}, {}, {}]
+ self.metrics_tests_imagine = [{}, {}, {}]
+ self.metrics_train = dict()
+
+ # Losses
+ self.action_loss = RegressionLoss(norm=1)
+ if self.cfg.MODEL.TRANSITION.ENABLED:
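+ # KL divergence between the prior and posterior distributions; KL_BALANCING_ALPHA balances the two terms.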
+ self.probabilistic_loss = KLLoss(alpha=self.cfg.LOSSES.KL_BALANCING_ALPHA)
+
+ if self.cfg.SEMANTIC_SEG.ENABLED:
+ self.segmentation_loss = SegmentationLoss(
+ use_top_k=self.cfg.SEMANTIC_SEG.USE_TOP_K,
+ top_k_ratio=self.cfg.SEMANTIC_SEG.TOP_K_RATIO,
+ use_weights=self.cfg.SEMANTIC_SEG.USE_WEIGHTS,
+ is_bev=True,
+ )
+
+ self.center_loss = SpatialRegressionLoss(norm=2)
+ self.offset_loss = SpatialRegressionLoss(norm=1, ignore_index=self.cfg.INSTANCE_SEG.IGNORE_INDEX)
+
+ for metrics_val, metrics_val_imagine in zip(self.metrics_vals, self.metrics_vals_imagine):
+ metrics_val['iou'] = JaccardIndex(
+ task='multiclass', num_classes=self.cfg.SEMANTIC_SEG.N_CHANNELS, average='none',
+ )
+ metrics_val_imagine['iou'] = JaccardIndex(
+ task='multiclass', num_classes=self.cfg.SEMANTIC_SEG.N_CHANNELS, average='none',
+ )
+
+ for metrics_test, metrics_test_imagine in zip(self.metrics_tests, self.metrics_tests_imagine):
+ metrics_test['iou'] = JaccardIndex(
+ task='multiclass', num_classes=self.cfg.SEMANTIC_SEG.N_CHANNELS, average='none',
+ )
+ metrics_test_imagine['iou'] = JaccardIndex(
+ task='multiclass', num_classes=self.cfg.SEMANTIC_SEG.N_CHANNELS, average='none',
+ )
+
+ # self.metrics_train['iou'] = JaccardIndex(
+ # task='multiclass', num_classes=self.cfg.SEMANTIC_SEG.N_CHANNELS, average='none',
+ # )
+
+ if self.cfg.EVAL.RGB_SUPERVISION:
+ self.rgb_loss = SpatialRegressionLoss(norm=1)
+ if self.cfg.LOSSES.RGB_INSTANCE:
+ self.rgb_instance_loss = SpatialRegressionLoss(norm=1)
+ if self.cfg.LOSSES.SSIM:
+ self.ssim_loss = SSIMLoss(channel=3)
+
+ for metrics_val, metrics_val_imagine in zip(self.metrics_vals, self.metrics_vals_imagine):
+ metrics_val['ssim'] = SSIMMetric(channel=3)
+ metrics_val_imagine['ssim'] = SSIMMetric(channel=3)
+ metrics_val['psnr'] = PSNRMetric(max_pixel_val=1.0)
+ metrics_val_imagine['psnr'] = PSNRMetric(max_pixel_val=1.0)
+ for metrics_test, metrics_test_imagine in zip(self.metrics_tests, self.metrics_tests_imagine):
+ metrics_test['ssim'] = SSIMMetric(channel=3)
+ metrics_test_imagine['ssim'] = SSIMMetric(channel=3)
+ metrics_test['psnr'] = PSNRMetric(max_pixel_val=1.0)
+ metrics_test_imagine['psnr'] = PSNRMetric(max_pixel_val=1.0)
+ # self.metrics_train['ssim'] = SSIMMetric(channel=3)
+ # self.metrics_train['psnr'] = PSNRMetric(max_pixel_val=1.0)
+
+ if self.cfg.LIDAR_RE.ENABLED:
+ self.lidar_re_loss = SpatialRegressionLoss(norm=2)
+ self.lidar_depth_loss = SpatialRegressionLoss(norm=1)
+ # self.lidar_cd_loss = CDLoss()
+ self.pcd = PointCloud(
+ self.cfg.POINTS.CHANNELS,
+ self.cfg.POINTS.HORIZON_RESOLUTION,
+ *self.cfg.POINTS.FOV,
+ self.cfg.POINTS.LIDAR_POSITION
+ )
+
+ for metrics_val, metrics_val_imagine in zip(self.metrics_vals, self.metrics_vals_imagine):
+ metrics_val['cd'] = CDMetric()
+ metrics_val_imagine['cd'] = CDMetric()
+ for metrics_test, metrics_test_imagine in zip(self.metrics_tests, self.metrics_tests_imagine):
+ metrics_test['cd'] = CDMetric()
+ metrics_test_imagine['cd'] = CDMetric()
+ # self.metrics_train['cd'] = CDMetric()
+
+ if self.cfg.LIDAR_SEG.ENABLED:
+ self.lidar_seg_loss = SegmentationLoss(
+ use_top_k=self.cfg.LIDAR_SEG.USE_TOP_K,
+ top_k_ratio=self.cfg.LIDAR_SEG.TOP_K_RATIO,
+ use_weights=self.cfg.LIDAR_SEG.USE_WEIGHTS,
+ is_bev=False,
+ )
+
+ for metrics_val, metrics_val_imagine in zip(self.metrics_vals, self.metrics_vals_imagine):
+ metrics_val['pcd_iou'] = JaccardIndex(
+ task='multiclass', num_classes=self.cfg.LIDAR_SEG.N_CLASSES, average='none',
+ )
+ metrics_val_imagine['pcd_iou'] = JaccardIndex(
+ task='multiclass', num_classes=self.cfg.LIDAR_SEG.N_CLASSES, average='none',
+ )
+
+ for metrics_test, metrics_test_imagine in zip(self.metrics_tests, self.metrics_tests_imagine):
+ metrics_test['pcd_iou'] = JaccardIndex(
+ task='multiclass', num_classes=self.cfg.LIDAR_SEG.N_CLASSES, average='none',
+ )
+ metrics_test_imagine['pcd_iou'] = JaccardIndex(
+ task='multiclass', num_classes=self.cfg.LIDAR_SEG.N_CLASSES, average='none',
+ )
+
+ if self.cfg.SEMANTIC_IMAGE.ENABLED:
+ self.sem_image_loss = SegmentationLoss(
+ use_top_k=self.cfg.SEMANTIC_IMAGE.USE_TOP_K,
+ top_k_ratio=self.cfg.SEMANTIC_IMAGE.TOP_K_RATIO,
+ use_weights=self.cfg.SEMANTIC_IMAGE.USE_WEIGHTS,
+ is_bev=False,
+ )
+
+ for metrics_val, metrics_val_imagine in zip(self.metrics_vals, self.metrics_vals_imagine):
+ metrics_val['image_iou'] = JaccardIndex(
+ task='multiclass', num_classes=self.cfg.SEMANTIC_IMAGE.N_CLASSES, average='none',
+ )
+ metrics_val_imagine['image_iou'] = JaccardIndex(
+ task='multiclass', num_classes=self.cfg.SEMANTIC_IMAGE.N_CLASSES, average='none',
+ )
+
+ for metrics_test, metrics_test_imagine in zip(self.metrics_tests, self.metrics_tests_imagine):
+ metrics_test['image_iou'] = JaccardIndex(
+ task='multiclass', num_classes=self.cfg.SEMANTIC_IMAGE.N_CLASSES, average='none',
+ )
+ metrics_test_imagine['image_iou'] = JaccardIndex(
+ task='multiclass', num_classes=self.cfg.SEMANTIC_IMAGE.N_CLASSES, average='none',
+ )
+
+ if self.cfg.DEPTH.ENABLED:
+ self.depth_image_loss = SpatialRegressionLoss(norm=1)
+
+ if self.cfg.VOXEL_SEG.ENABLED:
+ self.voxel_loss = VoxelLoss(
+ use_top_k=self.cfg.VOXEL_SEG.USE_TOP_K,
+ top_k_ratio=self.cfg.VOXEL_SEG.TOP_K_RATIO,
+ use_weights=self.cfg.VOXEL_SEG.USE_WEIGHTS,
+ )
+ self.sem_scal_loss = SemScalLoss()
+ self.geo_scal_loss = GeoScalLoss()
+ for metrics_val, metrics_val_imagine in zip(self.metrics_vals, self.metrics_vals_imagine):
+ metrics_val['ssc'] = SSCMetrics(self.cfg.VOXEL_SEG.N_CLASSES)
+ metrics_val_imagine['ssc'] = SSCMetrics(self.cfg.VOXEL_SEG.N_CLASSES)
+ for metrics_test, metrics_test_imagine in zip(self.metrics_tests, self.metrics_tests_imagine):
+ metrics_test['ssc'] = SSCMetrics(self.cfg.VOXEL_SEG.N_CLASSES)
+ metrics_test_imagine['ssc'] = SSCMetrics(self.cfg.VOXEL_SEG.N_CLASSES)
+ # self.metrics_train['ssc'] = SSCMetrics(self.cfg.VOXEL_SEG.N_CLASSES)
+
+ def get_cml_logger(self, cml_logger):
+ self.cml_logger = cml_logger
+
+ def load_pretrained_weights(self):
+ if self.cfg.PRETRAINED.PATH:
+ if os.path.isfile(self.cfg.PRETRAINED.PATH):
+ checkpoint = torch.load(self.cfg.PRETRAINED.PATH, map_location='cpu')['state_dict']
+ checkpoint = {key[6:]: value for key, value in checkpoint.items() if key[:5] == 'model'}
+
+ self.model.load_state_dict(checkpoint, strict=True)
+ print(f'Loaded weights from: {self.cfg.PRETRAINED.PATH}')
+ else:
+ raise FileNotFoundError(self.cfg.PRETRAINED.PATH)
+
+ def forward(self, batch, deployment=False):
+ batch = self.preprocess(batch)
+ output, state_dict = self.model.forward(batch, deployment=deployment)
+ return output, state_dict
+
+ def deployment_forward(self, batch, is_dreaming):
+ batch = self.preprocess(batch)
+ output = self.model.deployment_forward(batch, is_dreaming)
+ return output
+
+ def shared_step(self, batch, mode='train', predict_action=False):
+ n_prediction_samples = self.cfg.PREDICTION.N_SAMPLES
+ output_imagines = []
+ losses_imagines = []
+
+ if mode == 'train':
+ # In training, only do reconstruction (no future imagination)
+ output, state_dict = self.forward(batch)
+ losses = self.compute_loss(batch, output)
+ else:
+ batch = self.preprocess(batch)
+ batch_rf = {key: value[:, :self.rf] for key, value in batch.items()}  # first rf frames (receptive field)
+ batch_fh = {key: value[:, self.rf:] for key, value in batch.items()}  # remaining future-horizon frames
+ output, state_dict = self.model.forward(batch_rf, deployment=False)
+ losses = self.compute_loss(batch_rf, output)
+
+ # in evaluation, do imagination (prediction)
+ state_imagine = {'hidden_state': state_dict['posterior']['hidden_state'][:, -1],
+ 'sample': state_dict['posterior']['sample'][:, -1],
+ 'throttle_brake': batch['throttle_brake'][:, self.rf:],
+ 'steering': batch['steering'][:, self.rf:]}
+ for _ in range(n_prediction_samples):
+ output_imagine = self.model.imagine(state_imagine, predict_action=predict_action, future_horizon=self.fh)
+ output_imagines.append(output_imagine)
+ losses_imagines.append(self.compute_loss(batch_fh, output_imagine))
+
+ return losses, output, losses_imagines, output_imagines
+
+ def compute_loss(self, batch, output):
+ losses = dict()
+
+ action_weight = self.cfg.LOSSES.WEIGHT_ACTION
+ if 'throttle_brake' in output.keys():
+ losses['throttle_brake'] = action_weight * self.action_loss(output['throttle_brake'],
+ batch['throttle_brake'])
+ if 'steering' in output.keys():
+ losses['steering'] = action_weight * self.action_loss(output['steering'], batch['steering'])
+
+ if self.cfg.MODEL.TRANSITION.ENABLED:
+ if 'prior' in output.keys() and 'posterior' in output.keys():
+ probabilistic_loss = self.probabilistic_loss(output['prior'], output['posterior'])
+
+ losses['probabilistic'] = self.cfg.LOSSES.WEIGHT_PROBABILISTIC * probabilistic_loss
+
+ # Compute losses at downsampling scales 1, 2 and 4 separately.
+ if self.cfg.SEMANTIC_SEG.ENABLED:
+ for downsampling_factor in [1, 2, 4]:
+ bev_segmentation_loss = self.segmentation_loss(
+ prediction=output[f'bev_segmentation_{downsampling_factor}'],
+ target=batch[f'birdview_label_{downsampling_factor}'],
+ )
+ discount = 1 / downsampling_factor
+ losses[f'bev_segmentation_{downsampling_factor}'] = discount * self.cfg.LOSSES.WEIGHT_SEGMENTATION * \
+ bev_segmentation_loss
+
+ center_loss = self.center_loss(
+ prediction=output[f'bev_instance_center_{downsampling_factor}'],
+ target=batch[f'center_label_{downsampling_factor}']
+ )
+ offset_loss = self.offset_loss(
+ prediction=output[f'bev_instance_offset_{downsampling_factor}'],
+ target=batch[f'offset_label_{downsampling_factor}']
+ )
+
+ center_loss = self.cfg.INSTANCE_SEG.CENTER_LOSS_WEIGHT * center_loss
+ offset_loss = self.cfg.INSTANCE_SEG.OFFSET_LOSS_WEIGHT * offset_loss
+
+ losses[f'bev_center_{downsampling_factor}'] = discount * self.cfg.LOSSES.WEIGHT_INSTANCE * center_loss
+ # Offsets are already discounted in the labels
+ losses[f'bev_offset_{downsampling_factor}'] = self.cfg.LOSSES.WEIGHT_INSTANCE * offset_loss
+
+ if self.cfg.EVAL.RGB_SUPERVISION:
+ for downsampling_factor in [1, 2, 4]:
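+ # Fixed down-weighting factor applied to all RGB reconstruction losses.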
+ rgb_weight = 0.1
+ discount = 1 / downsampling_factor
+ rgb_loss = self.rgb_loss(
+ prediction=output[f'rgb_{downsampling_factor}'],
+ target=batch[f'rgb_label_{downsampling_factor}'],
+ )
+
+ if self.cfg.LOSSES.RGB_INSTANCE:
+ rgb_instance_loss = self.rgb_instance_loss(
+ prediction=output[f'rgb_{downsampling_factor}'],
+ target=batch[f'rgb_label_{downsampling_factor}'],
+ instance_mask=batch[f'image_instance_mask_{downsampling_factor}']
+ )
+ else:
+ rgb_instance_loss = 0
+
+ if self.cfg.LOSSES.SSIM:
+ ssim_loss = 1 - self.ssim_loss(
+ prediction=output[f'rgb_{downsampling_factor}'],
+ target=batch[f'rgb_label_{downsampling_factor}'],
+ )
+ ssim_weight = 0.6
+ losses[f'ssim_{downsampling_factor}'] = rgb_weight * discount * ssim_loss * ssim_weight
+
+ losses[f'rgb_{downsampling_factor}'] = \
+ rgb_weight * discount * (rgb_loss + 0.5 * rgb_instance_loss)
+
+ if self.cfg.LIDAR_RE.ENABLED:
+ for downsampling_factor in [1, 2, 4]:
+ discount = 1 / downsampling_factor
+ lidar_re_loss = self.lidar_re_loss(
+ prediction=output[f'lidar_reconstruction_{downsampling_factor}'][:, :, :3, :, :],
+ target=batch[f'range_view_label_{downsampling_factor}'][:, :, :3, :, :]
+ )
+ lidar_depth_loss = self.lidar_depth_loss(
+ prediction=output[f'lidar_reconstruction_{downsampling_factor}'][:, :, -1:, :, :],
+ target=batch[f'range_view_label_{downsampling_factor}'][:, :, -1:, :, :]
+ )
+ losses[f'lidar_re_{downsampling_factor}'] = lidar_re_loss * discount * self.cfg.LOSSES.WEIGHT_LIDAR_RE
+ losses[
+ f'lidar_depth_{downsampling_factor}'] = lidar_depth_loss * discount * self.cfg.LOSSES.WEIGHT_LIDAR_RE
+
+ if self.cfg.LIDAR_SEG.ENABLED:
+ for downsampling_factor in [1, 2, 4]:
+ discount = 1 / downsampling_factor
+ lidar_seg_loss = self.lidar_seg_loss(
+ prediction=output[f'lidar_segmentation_{downsampling_factor}'],
+ target=batch[f'range_view_seg_label_{downsampling_factor}']
+ )
+ losses[f'lidar_seg_{downsampling_factor}'] = \
+ lidar_seg_loss * discount * self.cfg.LOSSES.WEIGHT_LIDAR_SEG
+
+ if self.cfg.SEMANTIC_IMAGE.ENABLED:
+ for downsampling_factor in [1, 2, 4]:
+ discount = 1 / downsampling_factor
+ sem_image_loss = self.sem_image_loss(
+ prediction=output[f'semantic_image_{downsampling_factor}'],
+ target=batch[f'semantic_image_label_{downsampling_factor}']
+ )
+ losses[f'semantic_image_{downsampling_factor}'] = \
+ sem_image_loss * discount * self.cfg.LOSSES.WEIGHT_SEM_IMAGE
+
+ if self.cfg.DEPTH.ENABLED:
+ for downsampling_factor in [1, 2, 4]:
+ discount = 1 / downsampling_factor
+ depth_image_loss = self.depth_image_loss(
+ prediction=output[f'depth_{downsampling_factor}'],
+ target=batch[f'depth_label_{downsampling_factor}']
+ )
+ losses[f'depth_{downsampling_factor}'] = \
+ depth_image_loss * discount * self.cfg.LOSSES.WEIGHT_DEPTH
+
+ if self.cfg.VOXEL_SEG.ENABLED:
+ for downsampling_factor in [1, 2, 4]:
+ discount = 1 / downsampling_factor
+ voxel_loss = self.voxel_loss(
+ prediction=output[f'voxel_{downsampling_factor}'],
+ target=batch[f'voxel_label_{downsampling_factor}'].type(torch.long)
+ )
+ sem_scal_loss = self.sem_scal_loss(
+ prediction=output[f'voxel_{downsampling_factor}'],
+ target=batch[f'voxel_label_{downsampling_factor}']
+ )
+ geo_scal_loss = self.geo_scal_loss(
+ prediction=output[f'voxel_{downsampling_factor}'],
+ target=batch[f'voxel_label_{downsampling_factor}']
+ )
+ losses[f'voxel_{downsampling_factor}'] = discount * self.cfg.LOSSES.WEIGHT_VOXEL * voxel_loss
+ losses[f'sem_scal_{downsampling_factor}'] = discount * self.cfg.LOSSES.WEIGHT_VOXEL * sem_scal_loss
+ losses[f'geo_scal_{downsampling_factor}'] = discount * self.cfg.LOSSES.WEIGHT_VOXEL * geo_scal_loss
+
+ if self.cfg.MODEL.REWARD.ENABLED:
+ reward_loss = self.action_loss(output['reward'], batch['reward'])
+ losses['reward'] = self.cfg.LOSSES.WEIGHT_REWARD * reward_loss
+ return losses
+
+ def training_step(self, batch, batch_idx):
+ if batch_idx == self.cfg.STEPS and self.cfg.MODEL.TRANSITION.ENABLED:
+ print('!' * 50)
+ print('ACTIVE INFERENCE ACTIVATED')
+ print('!' * 50)
+ self.model.rssm.active_inference = True
+ losses, output, _, _ = self.shared_step(batch, mode='train')
+
+ self.logging_and_visualisation(batch, output, [], losses, None, batch_idx, prefix='train')
+
+ return self.loss_reducing(losses)
+
+ def validation_step(self, batch, batch_idx, dataloader_idx):
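+ # Run the forward pass with the model in train mode but nn.Dropout layers in eval mode, then switch back to eval.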
+ self.train()
+ for module in self.modules():
+ if isinstance(module, torch.nn.Dropout):
+ module.eval()
+ with torch.no_grad():
+ loss, output, loss_imagines, output_imagines = self.shared_step(batch, mode='val', predict_action=False)
+ self.eval()
+
+ batch_rf = {key: value[:, :self.rf] for key, value in batch.items()}  # first rf frames (receptive field)
+ batch_fh = {key: value[:, self.rf:] for key, value in batch.items()}  # remaining future-horizon frames
+ self.add_metrics(self.metrics_vals[dataloader_idx], batch_rf, output)
+ for output_imagine in output_imagines:
+ self.add_metrics(self.metrics_vals_imagine[dataloader_idx], batch_fh, output_imagine)
+
+ self.logging_and_visualisation(batch, output, output_imagines, loss, loss_imagines,
+ batch_idx, prefix=f'val{dataloader_idx}')
+
+ return {f'val{dataloader_idx}_loss': self.loss_reducing(loss),
+ f'val{dataloader_idx}_loss_imagine':
+ sum([self.loss_reducing(loss_imagine) for loss_imagine in loss_imagines]) / len(loss_imagines)}
+
+ def add_metrics(self, metrics, batch, output):
+ if self.cfg.SEMANTIC_SEG.ENABLED:
+ seg_prediction = output['bev_segmentation_1'].detach()
+ seg_prediction = torch.argmax(seg_prediction, dim=2)
+ metrics['iou'](
+ seg_prediction.view(-1).cpu(),
+ batch['birdview_label'].view(-1).cpu()
+ )
+
+ if self.cfg.EVAL.RGB_SUPERVISION:
+ metrics['ssim'].add_batch(
+ prediction=output[f'rgb_1'].detach(),
+ target=batch[f'rgb_label_1'],
+ )
+ metrics['psnr'].add_batch(
+ prediction=output[f'rgb_1'].detach(),
+ target=batch[f'rgb_label_1'],
+ )
+
+ if self.cfg.LIDAR_RE.ENABLED:
+ lidar_target = batch['range_view_label_1']
+ lidar_pred = output['lidar_reconstruction_1'].detach()
+
+ pcd_target = lidar_target.detach().permute(0, 1, 3, 4, 2).flatten(2, 3).flatten(0, 1) \
+ * self.cfg.LIDAR_RE.SCALE
+ pcd_pred = lidar_pred.detach().permute(0, 1, 3, 4, 2).flatten(2, 3).flatten(0, 1) \
+ * self.cfg.LIDAR_RE.SCALE
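+ # Randomly subsample 10,000 points per frame (keeping only the xyz channels) so the Chamfer distance stays tractable.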
+ index = np.random.randint(0, pcd_target.size(-2), 10000)
+ metrics['cd'].add_batch(pcd_pred[:, index, :-1], pcd_target[:, index, :-1])
+
+ # pcd_target = lidar_target.detach().permute(0, 1, 3, 4, 2).flatten(0, 1) \
+ # * self.cfg.LIDAR_RE.SCALE
+ # valid_target = pcd_target[..., -1] > 0
+ # pcd_pred = lidar_pred.detach().permute(0, 1, 3, 4, 2).flatten(0, 1) \
+ # * self.cfg.LIDAR_RE.SCALE
+ # valid_pred = pcd_pred[..., -1] > 0
+ # metrics['cd'].add_batch(pcd_pred[..., :-1], pcd_target[..., :-1], valid_pred, valid_target)
+
+ if self.cfg.LIDAR_SEG.ENABLED:
+ pcd_sem_prediction = output['lidar_segmentation_1'].detach()
+ pcd_sem_prediction = torch.argmax(pcd_sem_prediction, dim=2)
+ metrics['pcd_iou'](
+ pcd_sem_prediction.view(-1).cpu(),
+ batch['range_view_seg_label_1'].view(-1).cpu()
+ )
+
+ if self.cfg.SEMANTIC_IMAGE.ENABLED:
+ image_sem_prediction = output['semantic_image_1'].detach()
+ image_sem_prediction = torch.argmax(image_sem_prediction, dim=2)
+ metrics['image_iou'](
+ image_sem_prediction.view(-1).cpu(),
+ batch['semantic_image_label_1'].reshape(-1).cpu()
+ )
+
+ if self.cfg.VOXEL_SEG.ENABLED:
+ self.compute_ssc_metrics(batch, output, metrics['ssc'])
+
+ def compute_ssc_metrics(self, batch, output, metric):
+ y_true = batch['voxel_label_1']
+ y_pred = output['voxel_1'].detach()
+ b, s, c, x, y, z = y_pred.shape
+ y_pred = y_pred.reshape(b * s, c, x, y, z)
+ y_true = y_true.reshape(b * s, x, y, z)
+ y_pred = torch.argmax(y_pred, dim=1)
+ metric.add_batch(y_pred, y_true)
+
+ def logging_and_visualisation(self, batch, output, output_imagine, loss, loss_imagines, batch_idx, prefix='train'):
+ # Logging
+ self.log('-global_step', torch.tensor(-self.global_step, dtype=torch.float32))
+ for key, value in loss.items():
+ self.log(f'{prefix}_{key}', value)
+ if loss_imagines:
+ for key, value in loss_imagines[0].items():
+ self.log(f'{prefix}_{key}_imagine', value)
+
+ # Visualisation
+ if prefix == 'train':
+ visualisation_criteria = (self.global_step % self.cfg.LOG_VIDEO_INTERVAL == 0) \
+ & (self.global_step != self.vis_step)
+ self.vis_step = self.global_step
+ else:
+ visualisation_criteria = batch_idx == 0
+ if visualisation_criteria:
+ self.visualise(batch, output, output_imagine, batch_idx, prefix=prefix)
+
+ def loss_reducing(self, loss):
+ total_loss = sum([x for x in loss.values()])
+ return total_loss
+
+ def on_validation_epoch_end(self):
+ self.log_metrics(self.metrics_vals, 'val')
+ self.log_metrics(self.metrics_vals_imagine, 'val_imagine')
+
+ def log_metrics(self, metrics_list, metrics_type):
+ class_names = ['Background', 'Road', 'Lane marking', 'Vehicle', 'Pedestrian', 'Green light', 'Yellow light',
+ 'Red light and stop sign']
+ class_names_voxel = list(VOXEL_LABEL.values())
+ for idx, metrics in enumerate(metrics_list):
+ prefix = f'{metrics_type}{idx}'
+ if self.cfg.SEMANTIC_SEG.ENABLED:
+ scores = metrics['iou'].compute()
+ for key, value in zip(class_names, scores):
+ self.logger.experiment.add_scalar(f'{prefix}_bev_iou_' + key, value, global_step=self.global_step)
+ self.logger.experiment.add_scalar(f'{prefix}_bev_mean_iou', torch.mean(scores), global_step=self.global_step)
+ metrics['iou'].reset()
+
+ if self.cfg.EVAL.RGB_SUPERVISION:
+ self.log(f'{prefix}_ssim', metrics['ssim'].get_stat())
+ metrics['ssim'].reset()
+ self.log(f'{prefix}_psnr', metrics['psnr'].get_stat())
+ metrics['psnr'].reset()
+
+ if self.cfg.LIDAR_RE.ENABLED:
+ self.log(f'{prefix}_chamfer_distance', metrics['cd'].get_stat())
+ metrics['cd'].reset()
+
+ if self.cfg.LIDAR_SEG.ENABLED:
+ scores_pcd = metrics['pcd_iou'].compute()
+ for key, value in zip(class_names_voxel, scores_pcd):
+ self.logger.experiment.add_scalar(f'{prefix}_lidar_iou_' + key, value, global_step=self.global_step)
+ self.logger.experiment.add_scalar(f'{prefix}_lidar_mean_iou', torch.mean(scores_pcd), global_step=self.global_step)
+ metrics['pcd_iou'].reset()
+
+ if self.cfg.SEMANTIC_IMAGE.ENABLED:
+ scores_img = metrics['image_iou'].compute()
+ for key, value in zip(class_names_voxel, scores_img):
+ self.logger.experiment.add_scalar(f'{prefix}_camera_iou_' + key, value, global_step=self.global_step)
+ self.logger.experiment.add_scalar(f'{prefix}_camera_mean_iou', torch.mean(scores_img), global_step=self.global_step)
+ metrics['image_iou'].reset()
+
+ if self.cfg.VOXEL_SEG.ENABLED:
+ # class_names_voxel = ['Background', 'Road', 'RoadLines', 'Sidewalk', 'Vehicle',
+ # 'Pedestrian', 'TrafficSign', 'TrafficLight', 'Others']
+
+ stats = metrics['ssc'].get_stats()
+ for i, class_name in enumerate(class_names_voxel):
+ self.log(f'{prefix}_Voxel_{class_name}_SemIoU', stats['iou_ssc'][i])
+ self.log(f'{prefix}_Voxel_mIoU', stats["iou_ssc_mean"])
+ self.log(f'{prefix}_Voxel_IoU', stats["iou"])
+ self.log(f'{prefix}_Voxel_Precision', stats["precision"])
+ self.log(f'{prefix}_Voxel_Recall', stats["recall"])
+ metrics['ssc'].reset()
+
+ def visualise(self, batch, output, output_imagines, batch_idx, prefix='train', writer=None):
+ writer = writer if writer else self.logger.experiment
+ s = list(batch.values())[0].shape[1] # total sequence length
+ rf = list(output.values())[-1].shape[1] # receptive field
+
+ name = f'{prefix}_outputs'
+ if prefix != 'train':
+ name = name + f'_{batch_idx}'
+ # global_step = batch_idx if prefix == 'pred' else self.global_step
+ global_step = self.global_step
+
+ if self.cfg.SEMANTIC_SEG.ENABLED:
+
+ # target = batch['birdview_label'][:, :, 0]
+ # pred = torch.argmax(output['bev_segmentation_1'].detach(), dim=-3)
+
+ # colours = torch.tensor(BIRDVIEW_COLOURS, dtype=torch.uint8, device=pred.device)
+
+ # target = colours[target]
+ # pred = colours[pred]
+
+ # # Move channel to third position
+ # target = target.permute(0, 1, 4, 2, 3)
+ # pred = pred.permute(0, 1, 4, 2, 3)
+
+ # visualisation_video = torch.cat([target, pred], dim=-1).detach()
+
+ # # Rotate for visualisation
+ # visualisation_video = torch.rot90(visualisation_video, k=1, dims=[3, 4])
+
+ # name = f'{prefix}_outputs'
+ # if prefix == 'val':
+ # name = name + f'_{batch_idx}'
+ # self.logger.experiment.add_video(name, visualisation_video, global_step=self.global_step, fps=2)
+
+ target = batch['birdview_label'][:, :, 0].cpu()
+ pred = torch.argmax(output['bev_segmentation_1'].detach().cpu(), dim=-3)
+ bev_imagines = []
+ if output_imagines:
+ # multiple samples of the future
+ for imagine in output_imagines:
+ bev_imagines.append(torch.argmax(imagine['bev_segmentation_1'].detach().cpu(), dim=-3))
+ else:
+ bev_imagines.append(None)
+
+ colours = torch.tensor(BIRDVIEW_COLOURS, dtype=torch.uint8, device=pred.device) / 255.0
+
+ target = colours[target]
+ # pred = colours[pred]
+
+ # Move channel to third position and add white border
+ target = F.pad(target.permute(0, 1, 4, 2, 3), [2, 2, 2, 2], 'constant', 0.8)
+ # pred = F.pad(pred.permute(0, 1, 4, 2, 3), [2, 2, 2, 2], 'constant', 0.8)
+ preds = []
+ # Put the reconstruction and all imagined futures together
+ for i, bev_imagine in enumerate(bev_imagines):
+ bev_receptive = pred if i == 0 else torch.zeros_like(pred)
+ p_i = bev_receptive if bev_imagine is None else torch.cat([bev_receptive, bev_imagine], dim=1)
+ p_i = colours[p_i]
+ p_i = F.pad(p_i.permute(0, 1, 4, 2, 3), [2, 2, 2, 2], 'constant', 0.8)
+ preds.append(p_i)
+
+ bev = torch.cat([*preds[::-1], target], dim=-1).detach()
+ # Rotate for visualisation
+ bev = torch.rot90(bev, k=1, dims=[3, 4])
+
+ b, _, c, h, w = bev.size()
+
+ visualisation_bev = []
+ for step in range(s):
+ if step == rf:
+ # separate the receptive field and the future horizon
+ visualisation_bev.append(torch.ones(b, c, h, int(w / 4), device=pred.device))
+ visualisation_bev.append(bev[:, step])
+ visualisation_bev = torch.cat(visualisation_bev, dim=-1).detach()
+
+ name_ = f'{name}_bev'
+ writer.add_images(name_, visualisation_bev, global_step=global_step)
+
+ if self.cfg.EVAL.RGB_SUPERVISION:
+ # rgb_target = batch['rgb_label_1']
+ # rgb_pred = output['rgb_1'].detach()
+
+ # visualisation_rgb = torch.cat([rgb_pred, rgb_target], dim=-2).detach()
+ # name_ = f'{name}_rgb'
+ # writer.add_video(name_, visualisation_rgb, global_step=global_step, fps=2)
+
+ rgb_target = batch['rgb_label_1'].cpu()
+ rgb_pred = output['rgb_1'].detach().cpu()
+ rgb_imagines = []
+ if output_imagines:
+ for imagine in output_imagines:
+ rgb_imagines.append(imagine['rgb_1'].detach().cpu())
+ else:
+ rgb_imagines.append(None)
+
+ b, _, c, h, w = rgb_target.size()
+
+ rgb_preds = []
+ for i, rgb_imagine in enumerate(rgb_imagines):
+ rgb_receptive = rgb_pred if i == 0 else torch.ones_like(rgb_pred)
+ pred_imagine = rgb_receptive if rgb_imagine is None else torch.cat([rgb_receptive, rgb_imagine], dim=1)
+ rgb_preds.append(F.pad(pred_imagine, [5, 5, 5, 5], 'constant', 0.8))
+
+ rgb_target = F.pad(rgb_target, [5, 5, 5, 5], 'constant', 0.8)
+ # rgb_pred = F.pad(rgb_pred, [5, 5, 5, 5], 'constant', 0.8)
+
+ acc = batch['throttle_brake']
+ steer = batch['steering']
+
+ acc_bar = np.ones((b, s, int(h / 4), w + 10, c)).astype(np.uint8) * 255
+ steer_bar = np.ones((b, s, int(h / 4), w + 10, c)).astype(np.uint8) * 255
+
+ red = np.array([200, 0, 0])[None, None]
+ green = np.array([0, 200, 0])[None, None]
+ blue = np.array([0, 0, 200])[None, None]
+ mid = int(w / 2) + 5
+
+            # visualize acceleration and steering: green for throttle, red for brake, blue for steering.
+ for b_idx in range(b):
+ for step in range(s):
+ if acc[b_idx, step] >= 0:
+ acc_bar[b_idx, step, 5: -5, mid: mid + int(w / 2 * acc[b_idx, step]), :] = green
+ cv2.putText(acc_bar[b_idx, step], f'{acc[b_idx, step, 0]:.5f}', (mid - 220, int(h / 8) + 15),
+ cv2.FONT_HERSHEY_DUPLEX, 1.5, (0, 0, 0), 2, cv2.LINE_AA)
+ else:
+ acc_bar[b_idx, step, 5: -5, mid + int(w / 2 * acc[b_idx, step]): mid, :] = red
+ cv2.putText(acc_bar[b_idx, step], f'{acc[b_idx, step, 0]:.5f}', (mid + 10, int(h / 8) + 15),
+ cv2.FONT_HERSHEY_DUPLEX, 1.5, (0, 0, 0), 2, cv2.LINE_AA)
+ if steer[b_idx, step] >= 0:
+ steer_bar[b_idx, step, 5: -5, mid: mid + int(w / 2 * steer[b_idx, step]), :] = blue
+ cv2.putText(steer_bar[b_idx, step], f'{steer[b_idx, step, 0]:.5f}',
+ (mid - 220, int(h / 8) + 15),
+ cv2.FONT_HERSHEY_DUPLEX, 1.5, (0, 0, 0), 2, cv2.LINE_AA)
+ else:
+ steer_bar[b_idx, step, 5: -5, mid + int(w / 2 * steer[b_idx, step]): mid, :] = blue
+ cv2.putText(steer_bar[b_idx, step], f'{steer[b_idx, step, 0]:.5f}', (mid + 10, int(h / 8) + 15),
+ cv2.FONT_HERSHEY_DUPLEX, 1.5, (0, 0, 0), 2, cv2.LINE_AA)
+ acc_bar = torch.tensor(acc_bar.transpose((0, 1, 4, 2, 3)),
+ dtype=torch.float, device=rgb_pred.device) / 255.0
+ steer_bar = torch.tensor(steer_bar.transpose((0, 1, 4, 2, 3)),
+ dtype=torch.float, device=rgb_pred.device) / 255.0
+
+ rgb = torch.cat([acc_bar, steer_bar, rgb_target, *rgb_preds], dim=-2)
+ visualisation_rgb = []
+ for step in range(s):
+ if step == rf:
+ visualisation_rgb.append(torch.ones(b, c, rgb.size(-2), int(w / 4), device=rgb_pred.device))
+ visualisation_rgb.append(rgb[:, step, ...])
+ visualisation_rgb = torch.cat(visualisation_rgb, dim=-1).detach()
+
+ name_ = f'{name}_rgb'
+ writer.add_images(name_, visualisation_rgb, global_step=global_step)
+
+ ###################################
+ # visualize optical flow of rgb images.
+ flows = []
+ rgb_target_np = (rgb_target.detach().cpu().numpy().transpose(0, 1, 3, 4, 2) * 255).astype(np.uint8)
+ rgb_preds_np = [(rgb_pred_.detach().cpu().numpy().transpose(0, 1, 3, 4, 2) * 255).astype(np.uint8)
+ for rgb_pred_ in rgb_preds]
+ for bs in range(rgb.size(0)):
+ flows.append(list())
+ for step in range(1, rgb.size(1)):
+ img1_target = rgb_target_np[bs, step - 1][5: -5, 5: -5]
+ img2_target = rgb_target_np[bs, step][5: -5, 5: -5]
+                    # use colour coding to represent the flow
+ flow_target = self.get_color_coded_flow(img1_target, img2_target)
+ flow_target = F.pad(flow_target, [5, 5, 5, 5], 'constant', 0.8)
+
+ flow_preds = []
+ for i, rgb_pred_np in enumerate(rgb_preds_np):
+ img1_pred = rgb_pred_np[bs, step - 1][5: -5, 5: -5]
+                        if i != 0 and step == rf:
+                            # at the seam between reconstruction and imagination, the previous frame of an
+                            # imagined sample is only a blank placeholder, so take it from the reconstruction
+                            img1_pred = rgb_preds_np[0][bs, step - 1][5: -5, 5: -5]
+ img2_pred = rgb_pred_np[bs, step][5: -5, 5: -5]
+ flow_pred = self.get_color_coded_flow(img1_pred, img2_pred)
+ flow_pred = F.pad(flow_pred, [5, 5, 5, 5], 'constant', 0.8)
+ flow_preds.append(flow_pred)
+
+ flows[bs].append(torch.cat([flow_target, *flow_preds], dim=1))
+
+ visualisation_flow = torch.stack([torch.cat(flow, dim=-1) for flow in flows], dim=0)
+ name_ = f'{name}_flow'
+ writer.add_images(name_, visualisation_flow, global_step=global_step)
+
+ if self.cfg.LIDAR_RE.ENABLED:
+ lidar_target = batch['range_view_label_1'].cpu()
+ lidar_pred = output['lidar_reconstruction_1'].detach().cpu()
+ # lidar_imagine = output_imagine[0]['lidar_reconstruction_1'].detach()
+ if output_imagines:
+ lidar_imagines = [imagine['lidar_reconstruction_1'].detach().cpu() for imagine in output_imagines]
+ lidar_pred_imagine = torch.cat([lidar_pred, lidar_imagines[0]], dim=1)
+ else:
+ lidar_imagines = [None]
+ lidar_pred_imagine = lidar_pred
+
+ visualisation_lidar = torch.cat(
+ [lidar_target[:, :, -1, :, :], lidar_pred_imagine[:, :, -1, :, :]],
+ dim=-2).detach().unsqueeze(-3)
+ name_ = f'{name}_lidar'
+ writer.add_video(name_, visualisation_lidar, global_step=global_step, fps=2)
+
+            # get the bird's-eye view of the point cloud
+ pcd_image_target, pcd_target, valid_target = self.pcd_xy_image(lidar_target)
+ pcd_image_target = F.pad(pcd_image_target, [2, 2, 2, 2], 'constant', 0.2)
+
+ pcd_image_pred, pcd_pred, valid_pred = self.pcd_xy_image(lidar_pred)
+
+ pcd_image_preds = []
+ pcd_preds = []
+ valid_preds = []
+ for i, lidar_imagine in enumerate(lidar_imagines):
+ pcd_image_receptive = pcd_image_pred if i == 0 else torch.ones_like(pcd_image_pred)
+ if lidar_imagine is None:
+ pcd_image_pred_imagine = pcd_image_receptive
+ pcd_pred_imagine = pcd_pred
+ valid_pred_imagine = valid_pred
+ else:
+ pcd_image_imagine, pcd_imagine, valid_imagine = self.pcd_xy_image(lidar_imagine)
+ pcd_image_pred_imagine = torch.cat([pcd_image_receptive, pcd_image_imagine], dim=1)
+ pcd_pred_imagine = np.concatenate([pcd_pred, pcd_imagine], axis=1)
+ valid_pred_imagine = np.concatenate([valid_pred, valid_imagine], axis=1)
+ pcd_image_preds.append(F.pad(pcd_image_pred_imagine, [2, 2, 2, 2], 'constant', 0.2))
+ pcd_preds.append(pcd_pred_imagine)
+ valid_preds.append(valid_pred_imagine)
+
+ pcd_image = torch.cat([pcd_image_target, *pcd_image_preds], dim=-2)
+ b, _, c, h, w = pcd_image.size()
+
+ visualisation_pcd = []
+ for step in range(s):
+ if step == rf:
+ visualisation_pcd.append(torch.ones(b, c, h, int(w / 4), device=pcd_image.device))
+ visualisation_pcd.append(pcd_image[:, step])
+ visualisation_pcd = torch.cat(visualisation_pcd, dim=-1).detach()
+
+ name_ = f'{name}_pcd_xy'
+ writer.add_images(name_, visualisation_pcd, global_step=global_step)
+
+            # estimate the ego-vehicle trajectory from consecutive point clouds and visualize it
+ visualisation_traj = []
+ for bs in range(pcd_target.shape[0]):
+ path_target = [{'Rot': np.eye(3), 'pos': np.zeros((3, 1))}]
+ traj_target = np.pad(np.zeros((192, 192)), pad_width=2, mode='constant', constant_values=50)
+ traj_target = np.tile(traj_target[..., None], (1, 1, 3))
+ traj_target = self.plot_traj(path_target, traj_target)
+ path_preds = []
+ traj_preds = []
+ for i in range(len(pcd_preds)):
+ path_pred = [{'Rot': np.eye(3), 'pos': np.zeros((3, 1))}]
+ # traj_pred = np.pad(np.zeros((192, 192)), pad_width=2, mode='constant', constant_values=50)
+ # traj_pred = np.tile(traj_pred[..., None], (1, 1, 3))
+ # traj_pred = self.plot_traj(path_pred, traj_pred)
+ path_preds.append(path_pred)
+ traj_preds.append(traj_target.copy())
+ for step in range(1, pcd_target.shape[1]):
+ pcd1 = pcd_target[bs, step - 1][valid_target[bs, step - 1]][:, :3]
+ pcd2 = pcd_target[bs, step][valid_target[bs, step]][:, :3]
+ _, Rt_target = compute_pcd_transformation(pcd1, pcd2, path_target[-1], threshold=5)
+ path_target.append(Rt_target)
+ traj_target = self.plot_traj(path_target, traj_target)
+
+ for j in range(len(pcd_preds)):
+ pcd1 = pcd_preds[j][bs, step - 1][valid_preds[j][bs, step - 1]][:, :3]
+ pcd2 = pcd_preds[j][bs, step][valid_preds[j][bs, step]][:, :3]
+ _, Rt_pred = compute_pcd_transformation(pcd1, pcd2, path_preds[j][-1], threshold=5)
+ path_preds[j].append(Rt_pred)
+ traj_preds[j] = self.plot_traj(path_preds[j], traj_preds[j])
+
+ traj = np.concatenate([traj_target, *traj_preds], axis=1).transpose((2, 0, 1))[None]
+ visualisation_traj.append(torch.tensor(traj, device=lidar_target.device, dtype=torch.float))
+ visualisation_traj = torch.cat(visualisation_traj, dim=0) / 255.0
+ name_ = f'{name}_traj'
+ writer.add_images(name_, visualisation_traj, global_step=global_step)
+
+ if self.cfg.LIDAR_SEG.ENABLED:
+ lidar_seg_target = batch['range_view_seg_label_1'][:, :, 0].cpu()
+ lidar_seg_pred = torch.argmax(output['lidar_segmentation_1'].detach().cpu(), dim=-3)
+ lidar_seg_imagines = []
+ if output_imagines:
+ for imagine in output_imagines:
+ lidar_seg_imagines.append(torch.argmax(imagine['lidar_segmentation_1'].detach().cpu(), dim=-3))
+ else:
+ lidar_seg_imagines.append(None)
+
+ colours = torch.tensor(VOXEL_COLOURS, dtype=torch.uint8, device=lidar_seg_pred.device) / 255.0
+
+ lidar_seg_target = colours[lidar_seg_target]
+ lidar_seg_target = F.pad(lidar_seg_target.permute(0, 1, 4, 2, 3), [3, 3, 3, 3], 'constant', 0.8)
+
+ lidar_seg_preds = []
+
+ for i, lidar_seg_imagine in enumerate(lidar_seg_imagines):
+ lidar_seg_receptive = lidar_seg_pred if i == 0 else torch.zeros_like(lidar_seg_pred)
+ lidar_seg_i = lidar_seg_receptive if lidar_seg_imagine is None else torch.cat([lidar_seg_receptive, lidar_seg_imagine], dim=1)
+ lidar_seg_i = colours[lidar_seg_i]
+ lidar_seg_i = F.pad(lidar_seg_i.permute(0, 1, 4, 2, 3), [3, 3, 3, 3], 'constant', 0.8)
+ lidar_seg_preds.append(lidar_seg_i)
+
+ lidar_seg = torch.cat([lidar_seg_target, torch.ones_like(lidar_seg_target[:, -1:, ...]), *lidar_seg_preds], dim=1).detach()
+ visualisation_lidar_seg = lidar_seg.transpose(1, 2).flatten(2, 3)
+
+ name_ = f'{name}_lidar_seg'
+ writer.add_images(name_, visualisation_lidar_seg, global_step=global_step)
+
+ if self.cfg.SEMANTIC_IMAGE.ENABLED:
+ sem_target = batch['semantic_image_label_1'][:, :, 0].cpu()
+ sem_pred = torch.argmax(output['semantic_image_1'].detach().cpu(), dim=-3)
+ sem_imagines = []
+ if output_imagines:
+ for imagine in output_imagines:
+ sem_imagines.append(torch.argmax(imagine['semantic_image_1'].detach().cpu(), dim=-3))
+ else:
+ sem_imagines.append(None)
+
+ colours = torch.tensor(VOXEL_COLOURS, dtype=torch.uint8, device=sem_pred.device) / 255.0
+
+ sem_target = colours[sem_target]
+ sem_target = F.pad(sem_target.permute(0, 1, 4, 2, 3), [5, 5, 5, 5], 'constant', 0.8)
+
+ sem_preds = []
+
+ for i, sem_imagine in enumerate(sem_imagines):
+ sem_receptive = sem_pred if i == 0 else torch.zeros_like(sem_pred)
+ sem_i = sem_receptive if sem_imagine is None else torch.cat([sem_receptive, sem_imagine], dim=1)
+ sem_i = colours[sem_i]
+ sem_i = F.pad(sem_i.permute(0, 1, 4, 2, 3), [5, 5, 5, 5], 'constant', 0.8)
+ sem_preds.append(sem_i)
+
+ sem_image = torch.cat([sem_target, *sem_preds], dim=-2).detach()
+
+ b, _, c, h, w = sem_image.size()
+
+ visualisation_sem_image = []
+ for step in range(s):
+ if step == rf:
+ visualisation_sem_image.append(torch.ones(b, c, h, int(w / 4), device=sem_pred.device))
+ visualisation_sem_image.append(sem_image[:, step])
+ visualisation_sem_image = torch.cat(visualisation_sem_image, dim=-1).detach()
+
+ name_ = f'{name}_sem_image'
+ writer.add_images(name_, visualisation_sem_image, global_step=global_step)
+
+ if self.cfg.DEPTH.ENABLED:
+ depth_target = batch['depth_label_1'].cpu()
+ depth_pred = output['depth_1'].detach().cpu()
+ if output_imagines:
+ depth_imagine = output_imagines[0]['depth_1'].detach().cpu()
+ depth_pred = torch.cat([depth_pred, depth_imagine], dim=1)
+
+ visualisation_depth = torch.cat([depth_pred, depth_target], dim=-2).detach()
+ name_ = f'{name}_depth'
+ writer.add_video(name_, visualisation_depth, global_step=global_step, fps=2)
+
+ if self.cfg.VOXEL_SEG.ENABLED:
+ voxel_target = batch['voxel_label_1'][0, 0, 0].cpu().numpy()
+ voxel_pred = torch.argmax(output['voxel_1'].detach(), dim=-4).cpu().numpy()[0, 0]
+ colours = np.asarray(VOXEL_COLOURS, dtype=float) / 255.0
+ voxel_color_target = colours[voxel_target]
+ voxel_color_pred = colours[voxel_pred]
+ name_ = f'{name}_voxel'
+ self.write_voxel_figure(voxel_target, voxel_color_target, f'{name_}_target', global_step, writer)
+ self.write_voxel_figure(voxel_pred, voxel_color_pred, f'{name_}_pred', global_step, writer)
+ if output_imagines:
+ voxel_imagine_target = batch['voxel_label_1'][0, self.rf, 0].cpu().numpy()
+ voxel_imagine_pred = torch.argmax(output_imagines[0]['voxel_1'].detach(), dim=-4).cpu().numpy()[0, 0]
+ voxel_color_imagine_target = colours[voxel_imagine_target]
+ voxel_color_imagine_pred = colours[voxel_imagine_pred]
+ self.write_voxel_figure(
+ voxel_imagine_target, voxel_color_imagine_target, f'{name_}_imagine_target', global_step, writer)
+ self.write_voxel_figure(
+ voxel_imagine_pred, voxel_color_imagine_pred, f'{name_}_imagine_pred', global_step, writer)
+
+ # visualize route map
+ if self.cfg.MODEL.ROUTE.ENABLED:
+ route_map = batch['route_map'].cpu()
+ route_map = F.pad(route_map, [2, 2, 2, 2], 'constant', 0.8)
+
+ b, _, c, h, w = route_map.size()
+
+ visualisation_route = []
+ for step in range(s):
+ if step == rf:
+ visualisation_route.append(torch.ones(b, c, h, int(w / 4), device=route_map.device))
+ visualisation_route.append(route_map[:, step])
+ visualisation_route = torch.cat(visualisation_route, dim=-1).detach()
+
+ name_ = f'{name}_input_route_map'
+ writer.add_images(name_, visualisation_route, global_step=global_step)
+
+    # render a single voxel frame
+ def write_voxel_figure(self, voxel, voxel_color, name, global_step, writer):
+ fig = plt.figure(figsize=(10, 10))
+ ax = fig.add_subplot(projection='3d')
+ ax.voxels(voxel, facecolors=voxel_color, shade=False)
+ ax.view_init(elev=60, azim=165)
+ ax.set_axis_off()
+ writer.add_figure(name, fig, global_step=global_step)
+
+ # render trajectory
+ def plot_traj(self, traj, img):
+ x, y, z = traj[-1]['pos']
+ plot_x = int(96 - 5 * y)
+ plot_y = int(96 - 5 * x)
+ x_, y_, _ = traj[-2]['pos'] if len(traj) > 1 else traj[-1]['pos']
+ plot_x_ = int(96 - 5 * y_)
+ plot_y_ = int(96 - 5 * x_)
+ cv2.line(img, (plot_x, plot_y), (plot_x_, plot_y_), (20, 150, 20), 1)
+ cv2.circle(img, (plot_x, plot_y), 2, [150, 20, 20], -2, cv2.LINE_AA)
+ return img
+
+ def pcd_xy_image(self, lidar):
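+        # Project the range-view lidar tensor onto a top-down (xy) occupancy image;
+        # also return the scaled point cloud and a validity mask (range > 0) for later registration.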
+ image_size = np.array([256, 256])
+ lidar_range = 50
+
+ pcd = lidar.cpu().detach().numpy().transpose(0, 1, 3, 4, 2) * self.cfg.LIDAR_RE.SCALE
+ # pcd_target = pcd_target[..., :-1].flatten(1, 2)
+ # pcd_target = pcd_target[pcd_target[..., -1] > 0][..., :-1]
+ # pcd0 = self.pcd.restore_pcd_coor(lidar[:, :, -1].cpu().numpy() * self.cfg.LIDAR_RE.SCALE)
+ pcd_xy = -pcd[..., :2]
+ pcd_xy *= min(image_size) / (2 * lidar_range)
+ pcd_xy += 0.5 * image_size.reshape((1, 1, 1, 1, -1))
+        # only points with range > 0 are valid
+ valid = pcd[..., -1] > 0
+
+ b, s, _, _, _ = pcd.shape
+ pcd_xy_image = np.zeros((b, s, *image_size, 3))
+
+        # project the point cloud onto the xy plane (bird's-eye view)
+ for i in range(b):
+ for j in range(s):
+ hw = pcd_xy[i, j][valid[i, j]]
+ hw = hw[(0 < hw[:, 0]) & (hw[:, 0] < image_size[0]) &
+ (0 < hw[:, 1]) & (hw[:, 1] < image_size[1])]
+ hw = np.fabs(hw)
+ hw = hw.astype(np.int32)
+ pcd_xy_image[i, j][tuple(hw.T)] = (1.0, 1.0, 1.0)
+
+ return torch.tensor(pcd_xy_image.transpose((0, 1, 4, 2, 3)), device=lidar.device), pcd, valid
+
+ def get_color_coded_flow(self, img1, img2):
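+        # Dense Farneback optical flow between two RGB frames, colour-coded as an
+        # HSV image (hue = flow direction, saturation = flow magnitude).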
+ img1 = cv2.cvtColor(img1, cv2.COLOR_RGB2GRAY)
+ img2 = cv2.cvtColor(img2, cv2.COLOR_RGB2GRAY)
+ flow = cv2.calcOpticalFlowFarneback(img1, img2, None, 0.5, 3, 15, 3, 5, 1.2, 0)
+
+ hsv = np.zeros((*flow.shape[:2], 3), dtype=np.uint8)
+ hsv[..., 2] = 255
+ mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
+ hsv[..., 0] = ang * (180 / np.pi / 2)
+ hsv[..., 1] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
+ color_coded_flow = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)
+ return torch.tensor(color_coded_flow.transpose(2, 0, 1), dtype=torch.float) / 255.0
+
+ def configure_optimizers(self):
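+        # Build AdamW parameter groups (no weight decay for norm layers and biases),
+        # optionally freeze layers outside the configured train list, and attach an LR scheduler.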
+        # freeze the layers that are not in the train list
+ def frozen_params(model, no_frozen_list=[]):
+ for name, param in model.named_parameters():
+ if not any(name.startswith(layer) for layer in no_frozen_list):
+ param.requires_grad = False
+
+ # Do not decay batch norm parameters and biases
+ # https://discuss.pytorch.org/t/weight-decay-in-the-optimizers-is-a-bad-idea-especially-with-batchnorm/16994/2
+ def add_weight_decay(model, weight_decay=0.01, skip_list=[]):
+ no_decay = []
+ decay = []
+ train_list = []
+ frozen_list = []
+ for name, param in model.named_parameters():
+ if not param.requires_grad:
+ frozen_list.append(name)
+ continue
+ train_list.append(name)
+ if len(param.shape) == 1 or any(x in name for x in skip_list):
+ no_decay.append(param)
+ else:
+ decay.append(param)
+ print(f'train_layers: {train_list}\nfrozen_layers: {frozen_list}')
+ return [
+ {'params': no_decay, 'weight_decay': 0.},
+ {'params': decay, 'weight_decay': weight_decay},
+ ]
+
+ if self.cfg.OPTIMIZER.FROZEN.ENABLED:
+ frozen_params(self.model, self.cfg.OPTIMIZER.FROZEN.TRAIN_LIST)
+
+ parameters = add_weight_decay(
+ self.model,
+ self.cfg.OPTIMIZER.WEIGHT_DECAY,
+ skip_list=['relative_position_bias_table'],
+ )
+ weight_decay = 0.
+ optimizer = torch.optim.AdamW(parameters, lr=self.cfg.OPTIMIZER.LR, weight_decay=weight_decay)
+
+ # scheduler
+ if self.cfg.SCHEDULER.NAME == 'none':
+ lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda lr: 1)
+ elif self.cfg.SCHEDULER.NAME == 'OneCycleLR':
+ lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
+ optimizer,
+ max_lr=self.cfg.OPTIMIZER.LR,
+ total_steps=self.cfg.STEPS,
+ pct_start=self.cfg.SCHEDULER.PCT_START,
+            )
+        else:
+            raise ValueError(f'Unknown scheduler: {self.cfg.SCHEDULER.NAME}')
+
+ return [optimizer], [{'scheduler': lr_scheduler, 'interval': 'step'}]
+
+ def on_test_epoch_end(self):
+ self.log_metrics(self.metrics_tests, 'test')
+ self.log_metrics(self.metrics_tests_imagine, 'test_imagine')
+
+ def test_step(self, batch, batch_idx, dataloader_idx):
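+        # Forward pass in train mode (with dropout layers switched to eval and gradients disabled),
+        # then switch back to eval mode before computing metrics and visualisations.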
+ self.train()
+ for module in self.modules():
+ if isinstance(module, torch.nn.Dropout):
+ module.eval()
+ with torch.no_grad():
+ loss, output, loss_imagines, output_imagines = self.shared_step(batch, mode='test', predict_action=False)
+ self.eval()
+
+        batch_rf = {key: value[:, :self.rf] for key, value in batch.items()}  # receptive field (observed part)
+        batch_fh = {key: value[:, self.rf:] for key, value in batch.items()}  # future horizon (imagined part)
+ self.add_metrics(self.metrics_tests[dataloader_idx], batch_rf, output)
+ for output_imagine in output_imagines:
+ self.add_metrics(self.metrics_tests_imagine[dataloader_idx], batch_fh, output_imagine)
+
+ self.visualise(batch, output, output_imagines, batch_idx, prefix=f'pred{dataloader_idx}')
+ return output, output_imagines
diff --git a/muvo/utils/carla_utils.py b/muvo/utils/carla_utils.py
new file mode 100644
index 0000000..3eae646
--- /dev/null
+++ b/muvo/utils/carla_utils.py
@@ -0,0 +1,23 @@
+import numpy as np
+
+
+def get_vector3(vec3):
+ return np.array([vec3.x, vec3.y, vec3.z])
+
+
+def get_wheel_base(vehicle):
+ physics_control = vehicle.get_physics_control()
+ wheel_base = np.linalg.norm(
+ get_vector3(physics_control.wheels[0].position - physics_control.wheels[2].position)) / 100
+ return wheel_base
+
+
+def convert_steer_to_curvature(steer, wheel_base):
+ return -np.tan(steer) / wheel_base
+
+
+def gps_dict_to_numpy_array(gps_dict):
+ return np.array(
+ [gps_dict['lat'], gps_dict['lon'], gps_dict['z']],
+ dtype=np.float32
+ )
diff --git a/muvo/utils/geometry_utils.py b/muvo/utils/geometry_utils.py
new file mode 100644
index 0000000..6b72a03
--- /dev/null
+++ b/muvo/utils/geometry_utils.py
@@ -0,0 +1,357 @@
+import numpy as np
+import torch
+import open3d as o3d
+import cv2
+from muvo.data.dataset_utils import preprocess_gps
+
+
+def bev_params_to_intrinsics(size, scale, offsetx):
+ """
+ size: number of pixels (width, height)
+ scale: pixel size (in meters)
+ offsetx: offset in x direction (direction of car travel)
+ """
+ intrinsics_bev = np.array([
+ [1/scale, 0, size[0]/2 + offsetx],
+ [0, -1/scale, size[1]/2],
+ [0, 0, 1]
+ ], dtype=np.float32)
+ return intrinsics_bev
+
+
+def intrinsics_inverse(intrinsics):
+ fx = intrinsics[..., 0, 0]
+ fy = intrinsics[..., 1, 1]
+ cx = intrinsics[..., 0, 2]
+ cy = intrinsics[..., 1, 2]
+ one = torch.ones_like(fx)
+ zero = torch.zeros_like(fx)
+ intrinsics_inv = torch.stack((
+ torch.stack((1/fx, zero, -cx/fx), -1),
+ torch.stack((zero, 1/fy, -cy/fy), -1),
+ torch.stack((zero, zero, one), -1),
+ ), -2)
+ return intrinsics_inv
+
+
+def get_out_of_view_mask(cfg):
+ """ Returns a mask of everything that is not visible from the image given a certain bird's-eye view grid."""
+ fov = cfg.IMAGE.FOV
+ w = cfg.IMAGE.SIZE[1]
+ resolution = cfg.BEV.RESOLUTION
+
+ f = w / (2 * np.tan(fov * np.pi / 360.0))
+ c_u = w / 2 - cfg.IMAGE.CROP[0] # Adjust center point due to cropping
+
+ bev_left = -np.round((cfg.BEV.SIZE[0] // 2) * resolution, decimals=1)
+ bev_right = np.round((cfg.BEV.SIZE[0] // 2) * resolution, decimals=1)
+ bev_bottom = 0.01
+ # The camera is not exactly at the bottom of the bev image, so need to offset it.
+ camera_offset = (cfg.BEV.SIZE[1] / 2 + cfg.BEV.OFFSET_FORWARD) * resolution + cfg.IMAGE.CAMERA_POSITION[0]
+ bev_top = np.round(cfg.BEV.SIZE[1] * resolution - camera_offset, decimals=1)
+
+ x, z = np.arange(bev_left, bev_right, resolution), np.arange(bev_bottom, bev_top, resolution)
+ ucoords = x / z[:, None] * f + c_u
+
+ # Return all points which lie within the camera bounds
+ new_w = cfg.IMAGE.CROP[2] - cfg.IMAGE.CROP[0]
+ mask = (ucoords >= 0) & (ucoords < new_w)
+ mask = ~mask[::-1]
+    mask_behind_ego_vehicle = np.ones((int(camera_offset / resolution), mask.shape[1]), dtype=bool)
+ return np.vstack([mask, mask_behind_ego_vehicle])
+
+
+def calculate_geometry(image_fov, height, width, forward, right, up, pitch, yaw, roll):
+ """Intrinsics and extrinsics for a single camera.
+ See https://github.com/bradyz/carla_utils_fork/blob/dynamic-scene/carla_utils/leaderboard/camera.py
+ and https://github.com/bradyz/carla_utils_fork/blob/dynamic-scene/carla_utils/recording/sensors/camera.py
+ """
+ f = width / (2 * np.tan(image_fov * np.pi / 360.0))
+ cx = width / 2
+ cy = height / 2
+ intrinsics = np.float32([[f, 0, cx], [0, f, cy], [0, 0, 1]])
+ extrinsics = get_extrinsics(forward, right, up, pitch, yaw, roll)
+ return intrinsics, extrinsics
+
+
+def get_extrinsics(forward, right, up, pitch, yaw, roll):
+ # After multiplying the image coordinates by in the inverse intrinsics,
+ # the resulting coordinates are defined with the axes (right, down, forward)
+ assert pitch == yaw == roll == 0.0
+
+ # After multiplying by the extrinsics, we want the axis to be (forward, left, up), and centered in the
+ # inertial center of the ego-vehicle.
+ mat = np.float32([
+ [0, 0, 1, forward],
+ [-1, 0, 0, -right],
+ [0, -1, 0, up],
+ [0, 0, 0, 1],
+ ])
+
+ return mat
+
+
+def lidar_to_histogram_features(lidar, cfg, crop=256):
+ """
+ Convert LiDAR point cloud into 2-bin histogram over 256x256 grid
+ """
+
+ # fit the center of the histogram the same as the bev.
+ offset = np.asarray(cfg.VOXEL.EV_POSITION) * cfg.VOXEL.RESOLUTION # ego position relative to min boundary.
+ pixels_per_meter = cfg.POINTS.HISTOGRAM.RESOLUTION
+ hist_max_per_pixel = cfg.POINTS.HISTOGRAM.HIST_MAX
+ x_range = cfg.POINTS.HISTOGRAM.X_RANGE
+ y_range = cfg.POINTS.HISTOGRAM.Y_RANGE
+ z_range = cfg.POINTS.HISTOGRAM.Z_RANGE
+
+ # 256 x 256 grid
+ xbins = np.linspace(
+ -offset[0],
+ -offset[0] + x_range / pixels_per_meter,
+ x_range + 1
+ )
+ ybins = np.linspace(
+ -offset[1],
+ -offset[1] + y_range / pixels_per_meter,
+ y_range + 1,
+ )
+ zbins = np.linspace(
+ -offset[2],
+ -offset[2] + z_range / pixels_per_meter,
+ z_range + 1
+ )
+
+ def splat_points(point_cloud, bins1, bins2):
+ hist = np.histogramdd(point_cloud, bins=(bins1, bins2))[0]
+ hist[hist > hist_max_per_pixel] = hist_max_per_pixel
+ overhead_splat = hist / hist_max_per_pixel
+ # return overhead_splat[::-1, ::-1]
+ return overhead_splat
+
+ # xy plane
+ below = lidar[lidar[..., 2] <= 0][..., :2]
+ middle = lidar[(0 < lidar[..., 2]) & (lidar[..., 2] <= 2.5)][..., :2]
+ above = lidar[lidar[..., 2] > 2.5][..., :2]
+ below_features = splat_points(below, xbins, ybins)
+ middle_features = splat_points(middle, xbins, ybins)
+ above_features = splat_points(above, xbins, ybins)
+ total_features_xy = below_features + middle_features + above_features
+ features_xy = np.stack([below_features, middle_features, above_features, total_features_xy], axis=-1)
+ features_xy = np.transpose(features_xy, (2, 0, 1)).astype(np.float32)
+
+ # xz plane
+ left = lidar[lidar[..., 1] >= 1.5][..., ::2]
+ center = lidar[(-1.5 < lidar[..., 1]) & (lidar[..., 1] < 1.5)][..., ::2]
+ right = lidar[lidar[..., 1] <= -1.5][..., ::2]
+ left_features = splat_points(left, xbins, zbins)
+ center_features = splat_points(center, xbins, zbins)
+ right_features = splat_points(right, xbins, zbins)
+ total_features_xz = left_features + center_features + right_features
+ features_xz = np.stack([left_features, center_features, right_features, total_features_xz], axis=-1)
+ features_xz = np.transpose(features_xz, (2, 0, 1)).astype(np.float32)
+
+ # yz plane
+ behind = lidar[lidar[..., 0] < -2.5][..., 1:]
+ mid = lidar[(-2.5 <= lidar[..., 0]) & (lidar[..., 0] <= 10)][..., 1:]
+ front = lidar[lidar[..., 0] > 10][..., 1:]
+ behind_features = splat_points(behind, ybins, zbins)
+ mid_features = splat_points(mid, ybins, zbins)
+ front_features = splat_points(front, ybins, zbins)
+ total_features_yz = behind_features + mid_features + front_features
+ features_yz = np.stack([behind_features, mid_features, front_features, total_features_yz], axis=-1)
+ features_yz = np.transpose(features_yz, (2, 0, 1)).astype(np.float32)
+ return features_xy, features_xz, features_yz
+
+
+class PointCloud(object):
+ def __init__(self, H=64, W=1024, fov_down=-30, fov_up=10, lidar_position=(1, 0, 2)):
+ self.fov_up = fov_up / 180.0 * np.pi # in rad
+ self.fov_down = fov_down / 180.0 * np.pi
+ self.fov = self.fov_up - self.fov_down
+ self.H = H
+ self.W = W
+ self.lidar_position = np.asarray(lidar_position)
+
+ def do_range_projection(self, points, semantics):
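+        # Spherical projection of the point cloud into an H x W range image that stores
+        # per-pixel depth, xyz coordinates and semantic labels.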
+        # restore point coordinates to the original CARLA lidar frame.
+ points_carla = points * np.array([1, -1, 1])
+ points_carla -= self.lidar_position
+
+ depth = np.linalg.norm(points_carla, 2, axis=1)
+
+ x = points_carla[:, 0]
+        y = -points_carla[:, 1]  # CARLA coordinates are left-handed.
+ z = points_carla[:, 2]
+
+ yaw = np.arctan2(y, x)
+ pitch = np.arcsin(z / depth)
+
+ proj_w = 0.5 * (1.0 - yaw / np.pi)
+ proj_h = 1.0 - (pitch + abs(self.fov_down)) / self.fov
+ proj_w *= self.W
+ proj_h *= self.H
+
+ proj_w = np.floor(proj_w)
+ proj_w = np.minimum(self.W - 1, proj_w)
+ proj_w = np.maximum(0, proj_w).astype(np.int32)
+
+ proj_h = np.floor(proj_h)
+ proj_h = np.minimum(self.H - 1, proj_h)
+ proj_h = np.maximum(0, proj_h).astype(np.int32)
+
+        # After sorting by depth from largest to smallest, closer points overwrite more distant points in the same pixel.
+ order = np.argsort(depth)[::-1]
+ depth = depth[order]
+ proj_w = proj_w[order]
+ proj_h = proj_h[order]
+ points = points[order]
+ semantics = semantics[order]
+
+ range_depth = np.full((self.H, self.W), -1, dtype=np.float32)
+ range_xyz = np.full((self.H, self.W, 3), 0, dtype=np.float32)
+ range_sem = np.full((self.H, self.W), 0, dtype=np.uint8)
+
+ # points += self.lidar_position
+ # points[:, 1] *= -1
+
+ range_depth[proj_h, proj_w] = depth
+ range_xyz[proj_h, proj_w] = points
+ range_sem[proj_h, proj_w] = semantics
+ return range_depth, range_xyz, range_sem
+
+    # re-project the range-view point cloud back to the original coordinates.
+ def restore_pcd_coor(self, range_depth):
+ h, w = np.arange(0, self.H), np.arange(0, self.W)
+ proj_w, proj_h = np.meshgrid(w, h)
+ # valid = range_depth > 0
+ proj_w = proj_w.astype(float)[None, None, ...]
+ proj_h = proj_h.astype(float)[None, None, ...]
+ depth = range_depth
+
+ proj_w /= self.W
+ proj_h /= self.H
+ pitch = (1.0 - proj_h) * self.fov - abs(self.fov_down)
+ yaw = (1.0 - proj_w / 0.5) * np.pi
+
+ z = depth * np.sin(pitch)
+ depth_ = depth * np.cos(pitch)
+ x = depth_ * np.cos(yaw)
+ y = depth_ * np.sin(yaw)
+
+ points = np.concatenate([x[..., None], -y[..., None], z[..., None]], axis=-1)
+ points += self.lidar_position.reshape((1, 1, 1, 1, -1))
+ points *= np.array([1, -1, 1]).reshape((1, 1, 1, 1, -1))
+ return np.concatenate([points, depth[..., None]], axis=-1)
+
+
+# use open3d to calculate the transformation between two point clouds.
+def compute_pcd_transformation(pcd1, pcd2, Rt, threshold=0.02):
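+    # Register pcd2 onto pcd1 with point-to-point ICP and accumulate the resulting rotation
+    # and translation into the running pose Rt ({'Rot': 3x3, 'pos': 3x1}).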
+ if len(pcd1) > 0 and len(pcd2) > 0:
+ source = o3d.geometry.PointCloud()
+ source.points = o3d.utility.Vector3dVector(pcd2)
+ target = o3d.geometry.PointCloud()
+ target.points = o3d.utility.Vector3dVector(pcd1)
+ reg_p2p = o3d.pipelines.registration.registration_icp(
+ source, target, threshold, np.eye(4),
+ o3d.pipelines.registration.TransformationEstimationPointToPoint(),
+ o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=2000))
+ transformation = reg_p2p.transformation
+ else:
+ transformation = np.eye(4)
+
+ R = transformation[:3, :3]
+ t = transformation[:3, -1:]
+ Rot = R @ Rt['Rot']
+ pos = Rt['pos'] + Rt['Rot'] @ t
+
+ return transformation, {'Rot': Rot, 'pos': pos}
+
+#
+# def find_motion_optimized(img1, img2):
+#     # Convert to grayscale
+#     gray1 = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
+#     gray2 = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
+#
+#     # Find keypoints and descriptors with ORB
+#     orb = cv2.ORB_create()
+#     kp1, des1 = orb.detectAndCompute(gray1, None)
+#     kp2, des2 = orb.detectAndCompute(gray2, None)
+#
+#     # Match descriptors with a brute-force matcher
+#     bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)
+#     matches = bf.match(des1, des2)
+#
+#     # Sort by match distance
+#     matches = sorted(matches, key=lambda x: x.distance)
+#
+#     # Select the best matches
+#     good_matches = matches[:int(len(matches) * 0.15)]  # keep the top 15% of matches
+#
+#     # Get the coordinates of the matched points
+#     src_pts = np.float32([kp1[m.queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
+#     dst_pts = np.float32([kp2[m.trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)
+#
+#     # Estimate the fundamental matrix robustly with RANSAC
+#     F, mask = cv2.findFundamentalMat(src_pts, dst_pts, cv2.FM_RANSAC, 3, 0.99)
+#
+#     # Keep only the inlier points
+#     src_pts = src_pts[mask.ravel() == 1]
+#     dst_pts = dst_pts[mask.ravel() == 1]
+#
+#     # Recover the camera rotation and translation
+#     _, R, t, _ = cv2.recoverPose(F, src_pts, dst_pts)
+#
+#     pts1, pts2 = src_pts, dst_pts
+#     P1 = np.hstack((np.eye(3), np.zeros((3, 1))))  # projection matrix of the first camera
+#     P2 = np.hstack((R, t))  # projection matrix of the second camera
+#
+#     # Obtain 3D points by linear triangulation
+#     points_3D = cv2.triangulatePoints(P1, P2, pts1.T, pts2.T).T
+#     points_3D /= points_3D[:, 3, None]  # normalise homogeneous coordinates
+#
+#     # Step 3: Bundle Adjustment
+#     # assumes an optimize_with_bundle_adjustment function has been defined
+#     camera_params = np.hstack((cv2.Rodrigues(R)[0], t))  # convert the rotation matrix to a rotation vector
+#     optimized_3D_points, optimized_camera_params = optimize_with_bundle_adjustment(
+#         points_3D[:, :3],
+#         np.array([np.zeros(6), camera_params]),  # initial camera parameters
+#         np.array([pts1, pts2])  # observed 2D points
+#     )
+#
+#     optimized_rot_vector = optimized_camera_params[1, :3]
+#     optimized_t = optimized_camera_params[1, 3:]
+#
+#     # Convert the rotation vector back to a rotation matrix with cv2.Rodrigues()
+#     optimized_R, _ = cv2.Rodrigues(optimized_rot_vector)
+#
+#     return optimized_R, optimized_t
+#
+#
+# class BundleAdjustmentProblem(cv2.optim.SimpleBundleAdjuster):
+#     def __init__(self, _3D_points, camera_params, _2D_points):
+#         super(BundleAdjustmentProblem, self).__init__()
+#
+#         # Add the 3D points
+#         for point in _3D_points:
+#             self.addPoint(point)
+#
+#         # Add the camera parameters and 2D points
+#         for cam_param, points_2D in zip(camera_params, _2D_points):
+#             rvec, tvec = cam_param[:3], cam_param[3:]
+#             for point_2D in points_2D:
+#                 self.addCamera(rvec, tvec, point_2D)
+#
+#         # Set the camera intrinsics (e.g. the focal length)
+#         self.setFocalLength(800)  # assume a focal length of 800
+#
+#
+# # Optimise the 3D points and camera parameters with bundle adjustment
+# def optimize_with_bundle_adjustment(_3D_points, camera_params, _2D_points):
+#     ba_problem = BundleAdjustmentProblem(_3D_points, camera_params, _2D_points)
+#     ba_problem.run(100)  # maximum number of iterations
+#     # Get the optimised 3D points and camera parameters
+# # 获取优化后的3D点和相机参数
+# optimized_3D_points = np.array([ba_problem.getPoint(i) for i in range(len(_3D_points))])
+# optimized_camera_params = np.array([ba_problem.getCameraParams(i) for i in range(len(camera_params))])
+#
+# return optimized_3D_points, optimized_camera_params
diff --git a/muvo/utils/instance_utils.py b/muvo/utils/instance_utils.py
new file mode 100644
index 0000000..cb8908c
--- /dev/null
+++ b/muvo/utils/instance_utils.py
@@ -0,0 +1,35 @@
+import torch
+
+
+def convert_instance_mask_to_center_and_offset_label(instance_label, ignore_index=255, sigma=3):
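+    # Build a Gaussian centerness heatmap around each instance centroid and per-pixel
+    # (x, y) offsets towards that centroid; pixels outside any instance keep ignore_index.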
+ instance_label = instance_label.squeeze(2)
+ batch_size, seq_len, h, w = instance_label.shape
+ center_label = torch.zeros(batch_size, seq_len, 1, h, w, device=instance_label.device)
+ offset_label = ignore_index * torch.ones(batch_size, seq_len, 2, h, w, device=instance_label.device)
+ # x is vertical displacement, y is horizontal displacement
+ x, y = torch.meshgrid(
+ torch.arange(h, dtype=torch.float, device=instance_label.device),
+        torch.arange(w, dtype=torch.float, device=instance_label.device),
+        indexing='ij',
+    )
+
+ # Ignore id 0 which is the background
+ for b in range(batch_size):
+ num_instances = instance_label[b].max()
+ for instance_id in range(1, num_instances+1):
+ for t in range(seq_len):
+ instance_mask = (instance_label[b, t] == instance_id)
+ if instance_mask.sum() == 0:
+ # this instance is not in this frame
+ continue
+
+ xc = x[instance_mask].mean().round().long()
+ yc = y[instance_mask].mean().round().long()
+
+ off_x = xc - x
+ off_y = yc - y
+ g = torch.exp(-(off_x ** 2 + off_y ** 2) / sigma ** 2)
+ center_label[b, t, 0] = torch.maximum(center_label[b, t, 0], g)
+ offset_label[b, t, 0, instance_mask] = off_x[instance_mask]
+ offset_label[b, t, 1, instance_mask] = off_y[instance_mask]
+
+ return center_label, offset_label
diff --git a/muvo/utils/network_utils.py b/muvo/utils/network_utils.py
new file mode 100644
index 0000000..18deb4b
--- /dev/null
+++ b/muvo/utils/network_utils.py
@@ -0,0 +1,144 @@
+import torch
+import torch.nn as nn
+import torchvision
+
+
+def set_bn_momentum(model, momentum=0.1):
+ for m in model.modules():
+ if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
+ m.momentum = momentum
+
+
+def preprocess_batch(batch, device, unsqueeze=False):
+ for key, value in batch.items():
+ if isinstance(value, torch.Tensor):
+ batch[key] = value.to(device)
+ if unsqueeze:
+ batch[key] = batch[key].unsqueeze(0)
+ else:
+ preprocess_batch(value, device, unsqueeze=unsqueeze)
+
+
+def squeeze_batch(batch):
+ for key, value in batch.items():
+ if isinstance(value, torch.Tensor):
+ batch[key] = value.squeeze(0)
+ else:
+ squeeze_batch(value)
+
+
+def remove_past(x, receptive_field):
+ """ Removes past tensors. The past is indicated by the receptive field. Creates a copy."""
+ if isinstance(x, torch.Tensor):
+ return x[:, (receptive_field-1):].contiguous()
+
+ output = {}
+ for key, value in x.items():
+ output[key] = remove_past(value, receptive_field)
+ return output
+
+
+def remove_last(x):
+ """ Removes last timestep. Creates a copy."""
+ if isinstance(x, torch.Tensor):
+ return x[:, :-1].contiguous()
+
+ output = {}
+ for key, value in x.items():
+ output[key] = remove_last(value)
+ return output
+
+
+def pack_sequence_dim(x):
+ """ Does not create a copy."""
+ if isinstance(x, torch.Tensor):
+ b, s = x.shape[:2]
+ return x.view(b * s, *x.shape[2:])
+
+ if isinstance(x, list):
+ return [pack_sequence_dim(elt) for elt in x]
+
+ output = {}
+ for key, value in x.items():
+ output[key] = pack_sequence_dim(value)
+ return output
+
+
+def unpack_sequence_dim(x, b, s):
+ """ Does not create a copy."""
+ if isinstance(x, torch.Tensor):
+ return x.view(b, s, *x.shape[1:])
+
+ if isinstance(x, list):
+ return [unpack_sequence_dim(elt, b, s) for elt in x]
+
+ output = {}
+ for key, value in x.items():
+ output[key] = unpack_sequence_dim(value, b, s)
+ return output
+
+
+def select_time_indices(x, time_indices):
+ """
+ Selects a particular time index for each element in the batch. Creates a copy.
+
+ Parameters
+ ----------
+ x: dict of tensors shape (batch_size, sequence_length, ...)
+ time_indices: torch.int64 shape (batch_size)
+ """
+ if isinstance(x, torch.Tensor):
+ b = x.shape[0]
+ return x[torch.arange(b), time_indices]
+
+ if isinstance(x, list):
+ return [select_time_indices(elt, time_indices) for elt in x]
+
+ output = {}
+ for key, value in x.items():
+ output[key] = select_time_indices(value, time_indices)
+ return output
+
+
+def calculate_birds_eye_view_parameters(x_bounds, y_bounds, z_bounds):
+ """
+ Parameters
+ ----------
+ x_bounds: Forward direction in the ego-car.
+ y_bounds: Sides
+ z_bounds: Height
+ Returns
+ -------
+ bev_resolution: Bird's-eye view bev_resolution
+ bev_start_position Bird's-eye view first element
+ bev_dimension Bird's-eye view tensor spatial dimension
+ """
+ bev_resolution = torch.tensor([row[2] for row in [x_bounds, y_bounds, z_bounds]])
+ bev_start_position = torch.tensor([row[0] + row[2] / 2.0 for row in [x_bounds, y_bounds, z_bounds]])
+ bev_dimension = torch.tensor([(row[1] - row[0]) / row[2] for row in [x_bounds, y_bounds, z_bounds]],
+ dtype=torch.long)
+
+ return bev_resolution, bev_start_position, bev_dimension
+
+
+class NormalizeInverse(torchvision.transforms.Normalize):
+ # https://discuss.pytorch.org/t/simple-way-to-inverse-transform-normalization/4821/8
+ def __init__(self, mean, std):
+ mean = torch.as_tensor(mean)
+ std = torch.as_tensor(std)
+ std_inv = 1 / (std + 1e-7)
+ mean_inv = -mean * std_inv
+ super().__init__(mean=mean_inv, std=std_inv)
+
+ def __call__(self, tensor):
+ return super().__call__(tensor.clone())
+
+
+def freeze_network(network):
+ for p in network.parameters():
+ p.requires_grad = False
+
+
+def unfreeze_network(network):
+ for p in network.parameters():
+ p.requires_grad = True
diff --git a/muvo/visualisation.py b/muvo/visualisation.py
new file mode 100644
index 0000000..8288fdb
--- /dev/null
+++ b/muvo/visualisation.py
@@ -0,0 +1,341 @@
+import matplotlib.pylab
+from PIL import Image, ImageDraw, ImageFont
+import numpy as np
+import torch
+import torchvision.transforms.functional as tvf
+
+from constants import EGO_VEHICLE_DIMENSION, BIRDVIEW_COLOURS
+
+
+DEFAULT_COLORMAP = matplotlib.pylab.cm.jet
+HEATMAP_PALETTE = (
+ torch.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0]]).permute(1, 0).view(1, 3, 4, 1)
+)
+
+
+def prepare_final_display_image(img_rgb, route_map, birdview_label, birdview_pred, render_dict, is_predicting=False,
+ future_sec=None):
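+    # Compose the final display frame: camera RGB (with the route map overlaid), ground-truth
+    # BEV and predicted BEV side by side, plus action gauges and text legends.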
+ if is_predicting:
+ pred_colour = [79, 171, 198]
+ else:
+ pred_colour = [0, 0, 0]
+
+ rgb_height, rgb_width = img_rgb.shape[:2]
+ birdview_height, birdview_width = birdview_pred.shape[:2]
+
+ final_display_image_margin_height = 110
+ final_display_image_margin_width = 0
+ final_height = max(rgb_height, birdview_height) + final_display_image_margin_height
+ final_width = rgb_width + 2 * birdview_width + final_display_image_margin_width
+ final_display_image = 255 * np.ones([final_height, final_width, 3], dtype=np.uint8)
+ final_display_image[:rgb_height, :rgb_width] = img_rgb
+
+ # add route map
+ route_map_height, route_map_width = route_map.shape[:2]
+ margin = 10
+ route_map_width_slice = slice(margin, route_map_height + margin)
+ route_map_height_slice = slice(rgb_width - route_map_width - margin, rgb_width - margin)
+ final_display_image[route_map_width_slice, route_map_height_slice] = \
+ (0.3 * final_display_image[route_map_width_slice, route_map_height_slice]
+ + 0.7 * route_map
+ ).astype(np.uint8)
+
+ # Bev prediction
+ final_display_image[:birdview_height, rgb_width:(rgb_width + birdview_width)] = birdview_label
+ final_display_image[:birdview_height, (rgb_width + birdview_width):(rgb_width + 2 * birdview_width)] = birdview_pred
+
+ # Action gauges
+ final_display_image = add_action_gauges(
+ final_display_image, render_dict, height=birdview_height + 45, width=rgb_width, pred_colour=pred_colour
+ )
+
+ # Legend
+ final_display_image = add_legend(final_display_image, f'RGB input (time t)',
+ (0, rgb_height + 5), colour=[0, 0, 0], size=24)
+ final_display_image = add_legend(final_display_image, f'Ground truth BEV (time t)',
+ (rgb_width, birdview_height + 5), colour=[0, 0, 0], size=24)
+ label = 'Pred. BEV (time t)'
+ if future_sec is not None:
+ label = f'Pred. BEV (time t + {future_sec:.1f}s)'
+ final_display_image = add_legend(final_display_image, label,
+ (rgb_width + birdview_width, birdview_height + 5), colour=pred_colour, size=24)
+ if is_predicting:
+ final_display_image = add_legend(final_display_image, 'IMAGINING',
+ (rgb_width + birdview_width + 5, birdview_height - 30), colour=pred_colour,
+ size=24)
+ return final_display_image
+
+
+def upsample_bev(x, size=(320, 320)):
+ _, h, w = x.shape
+ x = tvf.resize(
+ x.unsqueeze(0), size, interpolation=tvf.InterpolationMode.NEAREST,
+ )
+ return x[0]
+
+
+def convert_bev_to_image(bev, cfg, upsample_factor=2):
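+    # Colourize the BEV class map with BIRDVIEW_COLOURS, draw the ego-vehicle footprint
+    # and add a black image contour.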
+ bev = BIRDVIEW_COLOURS[bev]
+ bev_pixel_per_m = upsample_factor*int(1 / cfg.BEV.RESOLUTION)
+ ego_vehicle_bottom_offset_pixel = int(cfg.BEV.SIZE[0] / 2 + cfg.BEV.OFFSET_FORWARD)
+ bev = add_ego_vehicle(
+ bev,
+ pixel_per_m=bev_pixel_per_m,
+ ego_vehicle_bottom_offset_pixel=ego_vehicle_bottom_offset_pixel,
+ )
+ bev = make_contour(bev, colour=[0, 0, 0])
+ return bev
+
+
+def add_ego_vehicle(img, pixel_per_m=5, ego_vehicle_bottom_offset_pixel=32):
+ h, w = img.shape[:2]
+ # Assume vehicle is symmetrical in the x and y axis.
+ ego_vehicle_dimension_pixel = [int((x/2)*pixel_per_m) for x in EGO_VEHICLE_DIMENSION]
+
+ bottom_coordinate = h - ego_vehicle_bottom_offset_pixel - ego_vehicle_dimension_pixel[0]
+ top_coordinate = h - ego_vehicle_bottom_offset_pixel + ego_vehicle_dimension_pixel[0] + 1
+ left_coordinate = w//2 - ego_vehicle_dimension_pixel[1]
+ right_coordinate = w//2 + ego_vehicle_dimension_pixel[1] + 1
+
+ copy_img = img.copy()
+ copy_img[bottom_coordinate:top_coordinate, left_coordinate:right_coordinate] = [0, 0, 0]
+ return copy_img
+
+
+def make_contour(img, colour=[0, 0, 0], double_line=False):
+ h, w = img.shape[:2]
+ out = img.copy()
+ # Vertical lines
+ out[np.arange(h), np.repeat(0, h)] = colour
+ out[np.arange(h), np.repeat(w - 1, h)] = colour
+
+ # Horizontal lines
+ out[np.repeat(0, w), np.arange(w)] = colour
+ out[np.repeat(h - 1, w), np.arange(w)] = colour
+
+ if double_line:
+ out[np.arange(h), np.repeat(1, h)] = colour
+ out[np.arange(h), np.repeat(w - 2, h)] = colour
+
+ # Horizontal lines
+ out[np.repeat(1, w), np.arange(w)] = colour
+ out[np.repeat(h - 2, w), np.arange(w)] = colour
+ return out
+
+
+def merge_sparse_image_to_image_torch(
+ base_image: torch.Tensor,
+ sparse_image: torch.Tensor,
+ transparency: float = 0.4,
+) -> torch.Tensor:
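+    # Alpha-blend the non-zero pixels of sparse_image on top of base_image (weight 1 - transparency).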
+ assert base_image.shape == sparse_image.shape
+ canvas = base_image.clone()
+ mask = (sparse_image > 0).any(dim=0, keepdim=True).expand(*base_image.shape)
+
+ canvas[mask] = ((1 - transparency) * sparse_image[mask] + transparency * base_image[mask]).to(torch.uint8)
+ return canvas
+
+
+def add_legend(img, text='hello', position=(0, 0), colour=[255, 255, 255], size=14):
+ font_path = 'DejaVuSans.ttf'
+ font = ImageFont.truetype(font_path, size)
+
+ pil_img = Image.fromarray(img)
+ draw = ImageDraw.Draw(pil_img)
+ draw.text(position, text, tuple(colour), font=font)
+ return np.array(pil_img)
+
+
+def add_action_gauges(img, render_dict, height, width, pred_colour):
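+    # Draw horizontal gauge bars for the predicted acceleration and steering on the display image.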
+ def plot_gauge(img, label, value, gauge_height, color=(79, 171, 198), max_value=None):
+ bar_height = 15
+ bar_width = 150
+ centering_offset = 40
+ width_offset = bar_width + width + 200 + centering_offset
+ cursor = value
+ if max_value is not None:
+ cursor /= max_value
+ if cursor > 0:
+ start = 0
+ end = int(cursor * bar_width)
+ else:
+ start = int(cursor * bar_width)
+ end = 0
+
+ # fill
+ img[gauge_height:gauge_height + bar_height, width_offset + start:width_offset + end] = color
+ # contour
+ height_slice = slice(gauge_height, gauge_height + bar_height)
+ width_slice = slice(width_offset - bar_width, width_offset + bar_width)
+ img[height_slice, width_slice] = make_contour(img[height_slice, width_slice], colour=[0, 0, 0])
+
+ # Middle gauge
+ img[gauge_height - 2:gauge_height + bar_height + 2, width_offset:width_offset + 1] = (0, 0, 0)
+ # Add labels
+ img = add_legend(img, f'{label}:', (width + centering_offset - 35, gauge_height - bar_height // 2), pred_colour,
+ size=24)
+ img = add_legend(img, f'{value:.2f}', (width_offset + bar_width + 10, gauge_height - bar_height // 2),
+ pred_colour,
+ size=24)
+ return img
+
+ acceleration = render_dict['throttle_brake'].item()
+ steering = render_dict['steering'].item()
+
+ img = plot_gauge(img, 'Pred. acceleration', acceleration, gauge_height=height + 10, color=(224, 102, 102))
+ img = plot_gauge(img, 'Pred. steering', steering, gauge_height=height + 40, color=(255, 127, 80))
+ return img
+
+
+def heatmap_image(
+ image: np.ndarray, cmap: matplotlib.colors.LinearSegmentedColormap = DEFAULT_COLORMAP, autoscale: bool = True
+) -> np.ndarray:
+ """Colorize an 1 or 2 channel image with a colourmap."""
+ if not issubclass(image.dtype.type, np.floating):
+ raise ValueError(f"Expected a ndarray of float type, but got dtype {image.dtype}")
+ if not (image.ndim == 2 or (image.ndim == 3 and image.shape[0] in [1, 2])):
+ raise ValueError(f"Expected a ndarray of shape [H, W] or [1, H, W] or [2, H, W], but got shape {image.shape}")
+ heatmap_np = apply_colour_map(image.copy(), cmap=cmap, autoscale=autoscale)
+ heatmap_np = np.uint8(heatmap_np * 255)
+ return heatmap_np
+
+
+def apply_colour_map(
+ image: np.ndarray, cmap: matplotlib.colors.LinearSegmentedColormap = DEFAULT_COLORMAP, autoscale: bool = False
+) -> np.ndarray:
+ """
+    Applies a colour map to the given 1-, 2- or 3-channel numpy image (CxHxW); a 2-channel image is treated as flow.
+    Returns an HxWx3 numpy image.
+ """
+ if image.ndim == 2 or (image.ndim == 3 and image.shape[0] == 1):
+ if image.ndim == 3:
+ image = image[0]
+ # grayscale scalar image
+ if autoscale:
+ image = _normalise(image)
+ return cmap(image)[:, :, :3]
+ if image.shape[0] == 2:
+ # 2 dimensional UV
+ return flow_to_image(image, autoscale=autoscale)
+ if image.shape[0] == 3:
+ # normalise rgb channels
+ if autoscale:
+ image = _normalise(image)
+ return np.transpose(image, axes=[1, 2, 0])
+ raise Exception('Image must be 1, 2 or 3 channel to convert to colour_map (CxHxW)')
+
+
+def _normalise(image: np.ndarray) -> np.ndarray:
+ lower = np.min(image)
+ delta = np.max(image) - lower
+ if delta == 0:
+ delta = 1
+ image = (image.astype(np.float32) - lower) / delta
+ return image
+
+
+def flow_to_image(flow: np.ndarray, autoscale: bool = False) -> np.ndarray:
+ """
+    Applies a colour map to a flow field given as a 2-channel array of shape 2xHxW. Returns an HxWx3 numpy image.
+ Code adapted from: https://github.com/liruoteng/FlowNet/blob/master/models/flownet/scripts/flowlib.py
+ """
+ u = flow[0, :, :]
+ v = flow[1, :, :]
+
+ # Convert to polar coordinates
+ rad = np.sqrt(u ** 2 + v ** 2)
+ maxrad = np.max(rad)
+
+ # Normalise flow maps
+ if autoscale:
+ u /= maxrad + np.finfo(float).eps
+ v /= maxrad + np.finfo(float).eps
+
+ # visualise flow with cmap
+ return np.uint8(compute_color(u, v) * 255)
+
+
+def compute_color(u: np.ndarray, v: np.ndarray) -> np.ndarray:
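+    # Map flow components (u, v) to RGB using a colour wheel: the flow angle selects the hue,
+    # small magnitudes fade towards white, and NaNs are rendered as black.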
+ assert u.shape == v.shape
+ [h, w] = u.shape
+ img = np.zeros([h, w, 3])
+ nan_mask = np.isnan(u) | np.isnan(v)
+ u[nan_mask] = 0
+ v[nan_mask] = 0
+
+ colorwheel = make_color_wheel()
+ ncols = np.size(colorwheel, 0)
+
+ rad = np.sqrt(u ** 2 + v ** 2)
+ a = np.arctan2(-v, -u) / np.pi
+ f_k = (a + 1) / 2 * (ncols - 1) + 1
+ k_0 = np.floor(f_k).astype(int)
+ k_1 = k_0 + 1
+ k_1[k_1 == ncols + 1] = 1
+ f = f_k - k_0
+
+ for i in range(0, np.size(colorwheel, 1)):
+ tmp = colorwheel[:, i]
+ col0 = tmp[k_0 - 1] / 255
+ col1 = tmp[k_1 - 1] / 255
+ col = (1 - f) * col0 + f * col1
+
+ idx = rad <= 1
+ col[idx] = 1 - rad[idx] * (1 - col[idx])
+ notidx = np.logical_not(idx)
+
+ col[notidx] *= 0.75
+ img[:, :, i] = col * (1 - nan_mask)
+
+ return img
+
+
+def make_color_wheel() -> np.ndarray:
+ """
+ Create colour wheel.
+ Code adapted from https://github.com/liruoteng/FlowNet/blob/master/models/flownet/scripts/flowlib.py
+ """
+ red_yellow = 15
+ yellow_green = 6
+ green_cyan = 4
+ cyan_blue = 11
+ blue_magenta = 13
+ magenta_red = 6
+
+ ncols = red_yellow + yellow_green + green_cyan + cyan_blue + blue_magenta + magenta_red
+ colorwheel = np.zeros([ncols, 3])
+
+ col = 0
+
+ # red_yellow
+ colorwheel[0:red_yellow, 0] = 255
+ colorwheel[0:red_yellow, 1] = np.transpose(np.floor(255 * np.arange(0, red_yellow) / red_yellow))
+ col += red_yellow
+
+ # yellow_green
+ colorwheel[col: col + yellow_green, 0] = 255 - np.transpose(
+ np.floor(255 * np.arange(0, yellow_green) / yellow_green)
+ )
+ colorwheel[col: col + yellow_green, 1] = 255
+ col += yellow_green
+
+ # green_cyan
+ colorwheel[col: col + green_cyan, 1] = 255
+ colorwheel[col: col + green_cyan, 2] = np.transpose(np.floor(255 * np.arange(0, green_cyan) / green_cyan))
+ col += green_cyan
+
+ # cyan_blue
+ colorwheel[col: col + cyan_blue, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, cyan_blue) / cyan_blue))
+ colorwheel[col: col + cyan_blue, 2] = 255
+ col += cyan_blue
+
+ # blue_magenta
+ colorwheel[col: col + blue_magenta, 2] = 255
+ colorwheel[col: col + blue_magenta, 0] = np.transpose(np.floor(255 * np.arange(0, blue_magenta) / blue_magenta))
+    col += blue_magenta
+
+ # magenta_red
+ colorwheel[col: col + magenta_red, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, magenta_red) / magenta_red))
+ colorwheel[col: col + magenta_red, 0] = 255
+
+ return colorwheel
diff --git a/prediction.py b/prediction.py
new file mode 100644
index 0000000..27e89bc
--- /dev/null
+++ b/prediction.py
@@ -0,0 +1,123 @@
+import git
+import os
+import socket
+import time
+from weakref import proxy
+
+import torch
+import lightning.pytorch as pl
+from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
+from lightning.pytorch.callbacks import ModelSummary, LearningRateMonitor
+
+from muvo.config import get_parser, get_cfg
+from muvo.data.dataset import DataModule
+from muvo.trainer import WorldModelTrainer
+
+from clearml import Task, Dataset, Model
+
+
+class SaveGitDiffHashCallback(pl.Callback):
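+    # Store the current git commit hash and working-tree diff (plus the world size)
+    # inside every saved checkpoint.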
+ def setup(self, trainer, pl_model, stage):
+ repo = git.Repo()
+ trainer.git_hash = repo.head.object.hexsha
+ trainer.git_diff = repo.git.diff(repo.head.commit.tree)
+
+ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+ checkpoint['world_size'] = trainer.world_size
+ checkpoint['git_hash'] = trainer.git_hash
+ checkpoint['git_diff'] = trainer.git_diff
+
+
+class MyModelCheckpoint(ModelCheckpoint):
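+    # Checkpoint callback that saves the checkpoint under its file name in the working
+    # directory and drops unpicklable hyper-parameters if saving fails.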
+ def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
+ filename = filepath.split('/')[-1]
+ _checkpoint = trainer._checkpoint_connector.dump_checkpoint(self.save_weights_only)
+ try:
+ torch.save(_checkpoint, filename)
+ except AttributeError as err:
+ key = "hyper_parameters"
+ _checkpoint.pop(key, None)
+ print(f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}")
+ torch.save(_checkpoint, filename)
+
+ self._last_global_step_saved = trainer.global_step
+
+ # notify loggers
+ if trainer.is_global_zero:
+ for logger in trainer.loggers:
+ logger.after_save_checkpoint(proxy(self))
+
+
+def main():
+ args = get_parser().parse_args()
+ cfg = get_cfg(args)
+
+ # task = Task.init(project_name=cfg.CML_PROJECT, task_name=cfg.CML_TASK, task_type=cfg.CML_TYPE, tags=cfg.TAG)
+ # task.connect(cfg)
+ # cml_logger = task.get_logger()
+ #
+ # dataset_root = Dataset.get(dataset_project=cfg.CML_PROJECT,
+ # dataset_name=cfg.CML_DATASET,
+ # ).get_local_copy()
+
+ # data = DataModule(cfg, dataset_root=dataset_root)
+ data = DataModule(cfg)
+
+ input_model = Model(model_id='').get_local_copy() if cfg.PRETRAINED.CML_MODEL else None
+ # input_model = cfg.PRETRAINED.PATH
+ model = WorldModelTrainer(cfg.convert_to_dict(), pretrained_path=input_model)
+ # model = WorldModelTrainer.load_from_checkpoint(checkpoint_path=input_model)
+ # model.get_cml_logger(cml_logger)
+
+ save_dir = os.path.join(
+ cfg.LOG_DIR, time.strftime('%d%B%Yat%H:%M:%S%Z') + '_' + socket.gethostname() + '_' + cfg.TAG
+ )
+ logger = pl.loggers.TensorBoardLogger(save_dir=save_dir)
+
+ callbacks = [
+ ModelSummary(),
+ SaveGitDiffHashCallback(),
+ LearningRateMonitor(),
+ MyModelCheckpoint(
+ save_dir, every_n_train_steps=cfg.VAL_CHECK_INTERVAL,
+ ),
+ ]
+
+ if cfg.LIMIT_VAL_BATCHES in [0, 1]:
+ limit_val_batches = float(cfg.LIMIT_VAL_BATCHES)
+ else:
+ limit_val_batches = cfg.LIMIT_VAL_BATCHES
+
+ replace_sampler_ddp = not cfg.SAMPLER.ENABLED
+
+ trainer = pl.Trainer(
+ # devices=cfg.GPUS,
+ accelerator='auto',
+ # strategy='ddp',
+ precision=cfg.PRECISION,
+ # sync_batchnorm=True,
+ max_epochs=None,
+ max_steps=cfg.STEPS,
+ callbacks=callbacks,
+ logger=logger,
+ log_every_n_steps=cfg.LOGGING_INTERVAL,
+ val_check_interval=cfg.VAL_CHECK_INTERVAL * cfg.OPTIMIZER.ACCUMULATE_GRAD_BATCHES,
+ check_val_every_n_epoch=None,
+ # limit_val_batches=limit_val_batches,
+ limit_val_batches=3,
+ # use_distributed_sampler=replace_sampler_ddp,
+ accumulate_grad_batches=cfg.OPTIMIZER.ACCUMULATE_GRAD_BATCHES,
+ num_sanity_val_steps=0,
+ profiler='simple',
+ )
+
+ # trainer.fit(model, datamodule=data)
+ trainer.test(model, dataloaders=data.test_dataloader())
+
+
+if __name__ == '__main__':
+ main()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..9422663
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,124 @@
+aiohttp==3.8.5
+aiosignal==1.3.1
+annotated-types==0.5.0
+anyio==3.7.1
+arrow==1.2.3
+async-timeout==4.0.3
+attrs==23.1.0
+backoff==2.2.1
+beautifulsoup4==4.12.2
+blessed==1.20.0
+carla==0.9.14
+certifi==2023.7.22
+charset-normalizer==3.2.0
+clearml==1.13.1
+click==8.1.7
+cmake==3.27.5
+contourpy==1.1.1
+croniter==1.4.1
+cycler==0.11.0
+dateutils==0.6.12
+deepdiff==6.5.0
+exceptiongroup==1.1.3
+fastapi==0.103.1
+filelock==3.12.4
+fonttools==4.42.1
+frozenlist==1.4.0
+fsspec==2023.9.2
+furl==2.1.3
+fvcore==0.1.5.post20221221
+gitdb==4.0.10
+GitPython==3.1.37
+h11==0.14.0
+huggingface-hub==0.17.3
+idna==3.4
+importlib-resources==6.1.0
+inquirer==3.1.3
+iopath==0.1.10
+itsdangerous==2.1.2
+Jinja2==3.1.2
+joblib==1.3.2
+jsonschema==4.19.1
+jsonschema-specifications==2023.7.1
+kiwisolver==1.4.5
+lightning==2.0.9
+lightning-cloud==0.5.38
+lightning-utilities==0.9.0
+lit==17.0.1
+markdown-it-py==3.0.0
+MarkupSafe==2.1.3
+matplotlib==3.7.3
+mdurl==0.1.2
+mpmath==1.3.0
+multidict==6.0.4
+networkx==3.1
+numpy==1.24.4
+nvidia-cublas-cu11==11.10.3.66
+nvidia-cuda-cupti-cu11==11.7.101
+nvidia-cuda-nvrtc-cu11==11.7.99
+nvidia-cuda-runtime-cu11==11.7.99
+nvidia-cudnn-cu11==8.5.0.96
+nvidia-cufft-cu11==10.9.0.58
+nvidia-curand-cu11==10.2.10.91
+nvidia-cusolver-cu11==11.4.0.1
+nvidia-cusparse-cu11==11.7.4.91
+nvidia-nccl-cu11==2.14.3
+nvidia-nvtx-cu11==11.7.91
+opencv-python==4.8.1.78
+ordered-set==4.1.0
+orderedmultidict==1.0.1
+packaging==23.1
+pandas==2.0.3
+pathlib2==2.3.7.post1
+Pillow==10.0.1
+pkgutil_resolve_name==1.3.10
+portalocker==2.8.2
+psutil==5.9.5
+pydantic==2.1.1
+pydantic_core==2.4.0
+Pygments==2.16.1
+PyJWT==2.4.0
+pyparsing==3.1.1
+python-dateutil==2.8.2
+python-editor==1.0.4
+python-multipart==0.0.6
+pytorch-lightning==2.0.9
+pytz==2023.3.post1
+PyYAML==6.0.1
+readchar==4.0.5
+referencing==0.30.2
+requests==2.31.0
+rich==13.5.3
+rpds-py==0.10.3
+safetensors==0.3.3
+scikit-learn==1.3.1
+scipy==1.10.1
+six==1.16.0
+smmap==5.0.1
+sniffio==1.3.0
+soupsieve==2.5
+starlette==0.27.0
+starsessions==1.3.0
+sympy==1.12
+tabulate==0.9.0
+termcolor==2.3.0
+threadpoolctl==3.2.0
+timm==0.9.7
+torch==2.0.1
+torchaudio==2.0.2
+torchmetrics==1.2.0
+torchvision==0.15.2
+tqdm==4.66.1
+traitlets==5.10.1
+triton==2.0.0
+typing_extensions==4.8.0
+tzdata==2023.3
+urllib3==2.0.5
+uvicorn==0.23.2
+wcwidth==0.2.6
+websocket-client==1.6.3
+websockets==11.0.3
+yacs==0.1.8
+yarl==1.9.2
+zipp==3.17.0
diff --git a/rl_birdview/models/distributions.py b/rl_birdview/models/distributions.py
new file mode 100644
index 0000000..566bbcb
--- /dev/null
+++ b/rl_birdview/models/distributions.py
@@ -0,0 +1,280 @@
+"""Adapted from https://github.com/zhejz/carla-roach CC-BY-NC 4.0 license."""
+
+from typing import Optional, Tuple
+import torch as th
+import torch.nn as nn
+from torch.distributions import Beta, Normal
+from torch.nn import functional as F
+import numpy as np
+
+
+def sum_independent_dims(tensor: th.Tensor) -> th.Tensor:
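+    # Sum log-probabilities over the action dimensions, keeping the batch dimension when present.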
+ if len(tensor.shape) > 1:
+ tensor = tensor.sum(dim=1)
+ else:
+ tensor = tensor.sum()
+ return tensor
+
+
+class DiagGaussianDistribution():
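+    # Diagonal Gaussian action distribution with either a state-dependent or a single learned
+    # log standard deviation; exploration priors are defined for acceleration and steering.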
+ def __init__(self, action_dim: int, dist_init=None, action_dependent_std=False):
+ self.distribution = None
+ self.action_dim = action_dim
+ self.dist_init = dist_init
+ self.action_dependent_std = action_dependent_std
+
+ self.low = None
+ self.high = None
+ self.log_std_max = 2
+ self.log_std_min = -20
+
+ # [mu, log_std], [0, 1]
+ self.acc_exploration_dist = {
+ 'go': th.FloatTensor([0.66, -3]),
+ 'stop': th.FloatTensor([-0.66, -3])
+ }
+ self.steer_exploration_dist = {
+ 'turn': th.FloatTensor([0.0, -1]),
+ 'straight': th.FloatTensor([3.0, 3.0])
+ }
+
+ if th.cuda.is_available():
+ self.device = 'cuda'
+ else:
+ self.device = 'cpu'
+
+ def proba_distribution_net(self, latent_dim: int) -> Tuple[nn.Module, nn.Parameter]:
+ mean_actions = nn.Linear(latent_dim, self.action_dim)
+ if self.action_dependent_std:
+ log_std = nn.Linear(latent_dim, self.action_dim)
+ else:
+ log_std = nn.Parameter(-2.0*th.ones(self.action_dim), requires_grad=True)
+
+ if self.dist_init is not None:
+ # log_std.weight.data.fill_(0.01)
+ # mean_actions.weight.data.fill_(0.01)
+ # acc/steer
+ mean_actions.bias.data[0] = self.dist_init[0][0]
+ mean_actions.bias.data[1] = self.dist_init[1][0]
+ if self.action_dependent_std:
+ log_std.bias.data[0] = self.dist_init[0][1]
+ log_std.bias.data[1] = self.dist_init[1][1]
+ else:
+ init_tensor = th.FloatTensor([self.dist_init[0][1], self.dist_init[1][1]])
+ log_std = nn.Parameter(init_tensor, requires_grad=True)
+
+ return mean_actions, log_std
+
+ def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor) -> "DiagGaussianDistribution":
+ if self.action_dependent_std:
+ log_std = th.clamp(log_std, self.log_std_min, self.log_std_max)
+ action_std = th.ones_like(mean_actions) * log_std.exp()
+ self.distribution = Normal(mean_actions, action_std)
+ return self
+
+ def log_prob(self, actions: th.Tensor) -> th.Tensor:
+ log_prob = self.distribution.log_prob(actions)
+ return sum_independent_dims(log_prob)
+
+ def entropy_loss(self) -> th.Tensor:
+ entropy_loss = -1.0 * self.distribution.entropy()
+ return th.mean(entropy_loss)
+
+ def exploration_loss(self, exploration_suggests) -> th.Tensor:
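+ # KL regulariser that pulls the policy towards a hand-crafted exploration target
+ # whenever the environment suggests 'stop'/'go' (acc) or 'turn'/'straight' (steer).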
+ # [('stop'/'go'/None, 'turn'/'straight'/None)]
+ # (batch_size, action_dim)
+ mu = self.distribution.loc.detach().clone()
+ sigma = self.distribution.scale.detach().clone()
+
+ for i, (acc_suggest, steer_suggest) in enumerate(exploration_suggests):
+ if acc_suggest != '':
+ mu[i, 0] = self.acc_exploration_dist[acc_suggest][0]
+ sigma[i, 0] = self.acc_exploration_dist[acc_suggest][1]
+ if steer_suggest != '':
+ mu[i, 1] = self.steer_exploration_dist[steer_suggest][0]
+ sigma[i, 1] = self.steer_exploration_dist[steer_suggest][1]
+
+ dist_ent = Normal(mu, sigma)
+
+ exploration_loss = th.distributions.kl_divergence(dist_ent, self.distribution)
+ return th.mean(exploration_loss)
+
+ def sample(self) -> th.Tensor:
+ return self.distribution.rsample()
+
+ def mode(self) -> th.Tensor:
+ return self.distribution.mean
+
+ def get_actions(self, deterministic: bool = False) -> th.Tensor:
+ if deterministic:
+ return self.mode()
+ return self.sample()
+
+
+class SquashedDiagGaussianDistribution():
+ def __init__(self, action_dim: int, log_std_init: float = 0.0, action_dependent_std=False):
+ self.distribution = None
+
+ self.action_dim = action_dim
+ self.log_std_init = log_std_init
+ self.eps = 1e-7
+ self.action_dependent_std = action_dependent_std
+
+ self.low = -1.0
+ self.high = 1.0
+ self.log_std_max = 2
+ self.log_std_min = -20
+
+ self.gaussian_actions = None
+
+ def proba_distribution_net(self, latent_dim: int):
+ mean_actions = nn.Linear(latent_dim, self.action_dim)
+ if self.action_dependent_std:
+ log_std = nn.Linear(latent_dim, self.action_dim)
+ else:
+ log_std = nn.Parameter(th.ones(self.action_dim) * self.log_std_init, requires_grad=True)
+ return mean_actions, log_std
+
+ def proba_distribution(self, mean_actions: th.Tensor, log_std: th.Tensor):
+ if self.action_dependent_std:
+ log_std = th.clamp(log_std, self.log_std_min, self.log_std_max)
+ action_std = th.ones_like(mean_actions) * log_std.exp()
+ self.distribution = Normal(mean_actions, action_std)
+ return self
+
+ def log_prob(self, actions: th.Tensor, gaussian_actions: Optional[th.Tensor] = None) -> th.Tensor:
+ # Inverse tanh
+ if gaussian_actions is None:
+ gaussian_actions = th.clamp(actions, min=-1.0 + self.eps, max=1.0 - self.eps)
+ gaussian_actions = 0.5 * (gaussian_actions.log1p() - (-gaussian_actions).log1p())
+
+ # Log likelihood for a Gaussian distribution
+ log_prob = self.distribution.log_prob(gaussian_actions)
+ log_prob = sum_independent_dims(log_prob)
+
+ # sb3 correction
+ # log_prob -= th.sum(th.log(1 - actions ** 2 + self.eps), dim=1)
+ # spinning-up correction
+ log_prob -= (2*(np.log(2) - gaussian_actions - F.softplus(-2*gaussian_actions))).sum(axis=1)
+ return log_prob
+
+ def entropy(self) -> Optional[th.Tensor]:
+ return None
+
+ def sample(self) -> th.Tensor:
+ return th.tanh(self.distribution.rsample())
+
+ def mode(self) -> th.Tensor:
+ return th.tanh(self.distribution.mean)
+
+ def get_actions(self, deterministic: bool = False) -> th.Tensor:
+ if deterministic:
+ return self.mode()
+ return self.sample()
+
+
+class BetaDistribution():
+ def __init__(self, action_dim=2, dist_init=None):
+ assert action_dim == 2
+
+ self.distribution = None
+ self.action_dim = action_dim
+ self.dist_init = dist_init
+ self.low = 0.0
+ self.high = 1.0
+
+ # [beta, alpha], [0, 1]
+ self.acc_exploration_dist = {
+ # [1, 2.5]
+ # [1.5, 1.0]
+ 'go': th.FloatTensor([1.0, 2.5]),
+ 'stop': th.FloatTensor([1.5, 1.0])
+ }
+ self.steer_exploration_dist = {
+ 'turn': th.FloatTensor([1.0, 1.0]),
+ 'straight': th.FloatTensor([3.0, 3.0])
+ }
+
+ if th.cuda.is_available():
+ self.device = 'cuda'
+ else:
+ self.device = 'cpu'
+
+ def proba_distribution_net(self, latent_dim: int) -> Tuple[nn.Module, nn.Module]:
+
+ linear_alpha = nn.Linear(latent_dim, self.action_dim)
+ linear_beta = nn.Linear(latent_dim, self.action_dim)
+
+ if self.dist_init is not None:
+ # linear_alpha.weight.data.fill_(0.01)
+ # linear_beta.weight.data.fill_(0.01)
+ # acc
+ linear_alpha.bias.data[0] = self.dist_init[0][1]
+ linear_beta.bias.data[0] = self.dist_init[0][0]
+ # steer
+ linear_alpha.bias.data[1] = self.dist_init[1][1]
+ linear_beta.bias.data[1] = self.dist_init[1][0]
+
+ alpha = nn.Sequential(linear_alpha, nn.Softplus())
+ beta = nn.Sequential(linear_beta, nn.Softplus())
+ return alpha, beta
+
+ def proba_distribution(self, alpha, beta):
+ self.distribution = Beta(alpha, beta)
+ return self
+
+ def log_prob(self, actions: th.Tensor) -> th.Tensor:
+ log_prob = self.distribution.log_prob(actions)
+ return sum_independent_dims(log_prob)
+
+ def entropy_loss(self) -> th.Tensor:
+ entropy_loss = -1.0 * self.distribution.entropy()
+ return th.mean(entropy_loss)
+
+ def exploration_loss(self, exploration_suggests) -> th.Tensor:
+ # [('stop'/'go'/None, 'turn'/'straight'/None)]
+ # (batch_size, action_dim)
+ alpha = self.distribution.concentration1.detach().clone()
+ beta = self.distribution.concentration0.detach().clone()
+
+ for i, (acc_suggest, steer_suggest) in enumerate(exploration_suggests):
+ if acc_suggest != '':
+ beta[i, 0] = self.acc_exploration_dist[acc_suggest][0]
+ alpha[i, 0] = self.acc_exploration_dist[acc_suggest][1]
+ if steer_suggest != '':
+ beta[i, 1] = self.steer_exploration_dist[steer_suggest][0]
+ alpha[i, 1] = self.steer_exploration_dist[steer_suggest][1]
+
+ dist_ent = Beta(alpha, beta)
+
+ exploration_loss = th.distributions.kl_divergence(self.distribution, dist_ent)
+ return th.mean(exploration_loss)
+
+ def sample(self) -> th.Tensor:
+ # Reparametrization trick to pass gradients
+ return self.distribution.rsample()
+
+ def mode(self) -> th.Tensor:
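+ # Analytic mode of Beta(alpha, beta): (alpha-1)/(alpha+beta-2) when both are > 1,
+ # pushed to the 0/1 boundary when only one concentration exceeds 1, and the mean as
+ # fallback when both are <= 1. Index 1 (steer) defaults to 0.5, i.e. centred steering.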
+ alpha = self.distribution.concentration1
+ beta = self.distribution.concentration0
+ x = th.zeros_like(alpha)
+ x[:, 1] += 0.5
+ mask1 = (alpha > 1) & (beta > 1)
+ x[mask1] = (alpha[mask1]-1)/(alpha[mask1]+beta[mask1]-2)
+
+ mask2 = (alpha <= 1) & (beta > 1)
+ x[mask2] = 0.0
+
+ mask3 = (alpha > 1) & (beta <= 1)
+ x[mask3] = 1.0
+
+ # mean
+ mask4 = (alpha <= 1) & (beta <= 1)
+ x[mask4] = self.distribution.mean[mask4]
+
+ return x
+
+ def get_actions(self, deterministic: bool = False) -> th.Tensor:
+ if deterministic:
+ return self.mode()
+ return self.sample()
diff --git a/rl_birdview/models/ppo.py b/rl_birdview/models/ppo.py
new file mode 100644
index 0000000..18c9ca7
--- /dev/null
+++ b/rl_birdview/models/ppo.py
@@ -0,0 +1,279 @@
+"""Adapted from https://github.com/zhejz/carla-roach CC-BY-NC 4.0 license."""
+
+import time
+import torch as th
+import numpy as np
+from collections import deque
+from torch.nn import functional as F
+
+from stable_baselines3.common.vec_env import VecEnv
+from stable_baselines3.common.callbacks import BaseCallback
+from stable_baselines3.common.utils import explained_variance
+
+from .ppo_buffer import PpoBuffer
+
+
+class PPO():
+ def __init__(self, policy, env,
+ learning_rate: float = 1e-5,
+ n_steps_total: int = 8192,
+ batch_size: int = 256,
+ n_epochs: int = 20,
+ gamma: float = 0.99,
+ gae_lambda: float = 0.9,
+ clip_range: float = 0.2,
+ clip_range_vf: float = None,
+ ent_coef: float = 0.05,
+ explore_coef: float = 0.05,
+ vf_coef: float = 0.5,
+ max_grad_norm: float = 0.5,
+ target_kl: float = 0.01,
+ update_adv=False,
+ lr_schedule_step=None,
+ start_num_timesteps: int = 0):
+
+ self.policy = policy
+ self.env = env
+ self.learning_rate = learning_rate
+ self.n_steps_total = n_steps_total
+ self.n_steps = n_steps_total//env.num_envs
+ self.batch_size = batch_size
+ self.n_epochs = n_epochs
+ self.gamma = gamma
+ self.gae_lambda = gae_lambda
+ self.clip_range = clip_range
+ self.clip_range_vf = clip_range_vf
+ self.ent_coef = ent_coef
+ self.explore_coef = explore_coef
+ self.vf_coef = vf_coef
+ self.max_grad_norm = max_grad_norm
+ self.target_kl = target_kl
+ self.update_adv = update_adv
+ self.lr_schedule_step = lr_schedule_step
+ self.start_num_timesteps = start_num_timesteps
+ self.num_timesteps = start_num_timesteps
+
+ self._last_obs = None
+ self._last_dones = None
+ self.ep_stat_buffer = None
+
+ self.buffer = PpoBuffer(self.n_steps, self.env.observation_space, self.env.action_space,
+ gamma=self.gamma, gae_lambda=self.gae_lambda, n_envs=self.env.num_envs)
+ self.policy = self.policy.to(self.policy.device)
+
+ model_parameters = filter(lambda p: p.requires_grad, self.policy.parameters())
+ total_params = sum([np.prod(p.size()) for p in model_parameters])
+ print(f'trainable parameters: {total_params/1000000:.2f}M')
+
+ def collect_rollouts(self, env: VecEnv, callback: BaseCallback,
+ rollout_buffer: PpoBuffer, n_rollout_steps: int) -> bool:
+ assert self._last_obs is not None, "No previous observation was provided"
+ n_steps = 0
+ rollout_buffer.reset()
+
+ self.action_statistics = []
+ self.mu_statistics = []
+ self.sigma_statistics = []
+
+ while n_steps < n_rollout_steps:
+ actions, values, log_probs, mu, sigma, _ = self.policy.forward(self._last_obs)
+ self.action_statistics.append(actions)
+ self.mu_statistics.append(mu)
+ self.sigma_statistics.append(sigma)
+
+ new_obs, rewards, dones, infos = env.step(actions)
+
+ if callback.on_step() is False:
+ return False
+
+ # update_info_buffer
+ for i in np.where(dones)[0]:
+ self.ep_stat_buffer.append(infos[i]['episode_stat'])
+
+ n_steps += 1
+ self.num_timesteps += env.num_envs
+
+ rollout_buffer.add(self._last_obs, actions, rewards, self._last_dones, values, log_probs, mu, sigma, infos)
+ self._last_obs = new_obs
+ self._last_dones = dones
+
+ last_values = self.policy.forward_value(self._last_obs)
+ rollout_buffer.compute_returns_and_advantage(last_values, dones=self._last_dones)
+
+ return True
+
+ def train(self):
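+ # One PPO update: for n_epochs, stream shuffled minibatches from the rollout buffer and
+ # minimise clipped-surrogate policy loss + vf_coef * value loss + ent_coef * entropy loss
+ # + explore_coef * exploration loss, early-stopping an epoch when the approximate KL
+ # to the old policy grows too large.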
+ for param_group in self.policy.optimizer.param_groups:
+ param_group["lr"] = self.learning_rate
+
+ entropy_losses, exploration_losses, pg_losses, value_losses, losses = [], [], [], [], []
+ clip_fractions = []
+ approx_kl_divs = []
+
+ # train for gradient_steps epochs
+ epoch = 0
+ data_len = int(self.buffer.buffer_size * self.buffer.n_envs / self.batch_size)
+ for epoch in range(self.n_epochs):
+ approx_kl_divs = []
+ # Do a complete pass on the rollout buffer
+ self.buffer.start_caching(self.batch_size)
+ # while self.buffer.sample_queue.qsize() < 3:
+ # time.sleep(0.01)
+ for i in range(data_len):
+
+ if self.buffer.sample_queue.empty():
+ while self.buffer.sample_queue.empty():
+ # print(f'buffer_empty: {self.buffer.sample_queue.qsize()}')
+ time.sleep(0.01)
+ rollout_data = self.buffer.sample_queue.get()
+
+ values, log_prob, entropy_loss, exploration_loss, distribution = self.policy.evaluate_actions(
+ rollout_data.observations, rollout_data.actions, rollout_data.exploration_suggests,
+ detach_values=False)
+ # Normalize advantage
+ advantages = rollout_data.advantages
+ # advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
+
+ # ratio between old and new policy, should be one at the first iteration
+ ratio = th.exp(log_prob - rollout_data.old_log_prob)
+
+ # clipped surrogate loss
+ policy_loss_1 = advantages * ratio
+ policy_loss_2 = advantages * th.clamp(ratio, 1 - self.clip_range, 1 + self.clip_range)
+ policy_loss = -th.min(policy_loss_1, policy_loss_2).mean()
+
+ # Logging
+ clip_fraction = th.mean((th.abs(ratio - 1) > self.clip_range).float()).item()
+ clip_fractions.append(clip_fraction)
+
+ if self.clip_range_vf is None:
+ # No clipping
+ values_pred = values
+ else:
+ # Clip the different between old and new value
+ # NOTE: this depends on the reward scaling
+ values_pred = rollout_data.old_values + th.clamp(values - rollout_data.old_values,
+ -self.clip_range_vf, self.clip_range_vf)
+ # Value loss using the TD(gae_lambda) target
+ value_loss = F.mse_loss(rollout_data.returns, values_pred)
+
+ loss = policy_loss + self.vf_coef * value_loss \
+ + self.ent_coef * entropy_loss + self.explore_coef * exploration_loss
+
+ losses.append(loss.item())
+ pg_losses.append(policy_loss.item())
+ value_losses.append(value_loss.item())
+ entropy_losses.append(entropy_loss.item())
+ exploration_losses.append(exploration_loss.item())
+
+ # Optimization step
+ self.policy.optimizer.zero_grad()
+ loss.backward()
+ # Clip grad norm
+ th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
+ self.policy.optimizer.step()
+
+ with th.no_grad():
+ old_distribution = self.policy.action_dist.proba_distribution(
+ rollout_data.old_mu, rollout_data.old_sigma)
+ kl_div = th.distributions.kl_divergence(old_distribution.distribution, distribution)
+
+ approx_kl_divs.append(kl_div.mean().item())
+
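+ # Early-stop the epoch when the approximate KL to the old policy exceeds 1.5 * target_kl;
+ # after lr_schedule_step such stops the learning rate is halved and the counter reset.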
+ if self.target_kl is not None and np.mean(approx_kl_divs) > 1.5 * self.target_kl:
+ if self.lr_schedule_step is not None:
+ self.kl_early_stop += 1
+ if self.kl_early_stop >= self.lr_schedule_step:
+ self.learning_rate *= 0.5
+ self.kl_early_stop = 0
+ break
+
+ # update advantages
+ if self.update_adv:
+ self.buffer.update_values(self.policy)
+ last_values = self.policy.forward_value(self._last_obs)
+ self.buffer.compute_returns_and_advantage(last_values, dones=self._last_dones)
+
+ explained_var = explained_variance(self.buffer.values.flatten(), self.buffer.returns.flatten())  # (y_pred, y_true)
+
+ # Logs
+ self.train_debug = {
+ "train/entropy_loss": np.mean(entropy_losses),
+ "train/exploration_loss": np.mean(exploration_losses),
+ "train/policy_gradient_loss": np.mean(pg_losses),
+ "train/value_loss": np.mean(value_losses),
+ "train/last_epoch_kl": np.mean(approx_kl_divs),
+ "train/clip_fraction": np.mean(clip_fractions),
+ "train/loss": np.mean(losses),
+ "train/explained_variance": explained_var,
+ "train/clip_range": self.clip_range,
+ "train/train_epoch": epoch,
+ "train/learning_rate": self.learning_rate
+ }
+
+ def learn(self, total_timesteps, callback=None, seed=2021):
+ # reset env seed
+ self.env.action_space.seed(seed)
+ self.env.observation_space.seed(seed)
+ self.env.seed(seed)
+
+ self.start_time = time.time()
+
+ self.kl_early_stop = 0
+ self.t_train_values = 0.0
+
+ self.ep_stat_buffer = deque(maxlen=100)
+ self._last_obs = self.env.reset()
+ self._last_dones = np.zeros((self.env.num_envs,), dtype=bool)  # np.bool was removed in numpy >= 1.24
+
+ callback.init_callback(self)
+
+ callback.on_training_start(locals(), globals())
+
+ while self.num_timesteps < total_timesteps:
+ callback.on_rollout_start()
+ t0 = time.time()
+ self.policy = self.policy.train()
+ continue_training = self.collect_rollouts(self.env, callback, self.buffer, self.n_steps)
+ self.t_rollout = time.time() - t0
+ callback.on_rollout_end()
+
+ if continue_training is False:
+ break
+
+ t0 = time.time()
+ self.train()
+ self.t_train = time.time() - t0
+ callback.on_training_end()
+
+ return self
+
+ def _get_init_kwargs(self):
+ init_kwargs = dict(
+ learning_rate=self.learning_rate,
+ n_steps_total=self.n_steps_total,
+ batch_size=self.batch_size,
+ n_epochs=self.n_epochs,
+ gamma=self.gamma,
+ gae_lambda=self.gae_lambda,
+ clip_range=self.clip_range,
+ clip_range_vf=self.clip_range_vf,
+ ent_coef=self.ent_coef,
+ explore_coef=self.explore_coef,
+ vf_coef=self.vf_coef,
+ max_grad_norm=self.max_grad_norm,
+ target_kl=self.target_kl,
+ update_adv=self.update_adv,
+ lr_schedule_step=self.lr_schedule_step,
+ start_num_timesteps=self.num_timesteps,
+ )
+ return init_kwargs
+
+ def save(self, path: str) -> None:
+ th.save({'policy_state_dict': self.policy.state_dict(),
+ 'policy_init_kwargs': self.policy.get_init_kwargs(),
+ 'train_init_kwargs': self._get_init_kwargs()},
+ path)
+
+ def get_env(self):
+ return self.env
diff --git a/rl_birdview/models/ppo_buffer.py b/rl_birdview/models/ppo_buffer.py
new file mode 100644
index 0000000..448e14f
--- /dev/null
+++ b/rl_birdview/models/ppo_buffer.py
@@ -0,0 +1,263 @@
+"""Adapted from https://github.com/zhejz/carla-roach CC-BY-NC 4.0 license."""
+
+from gym import spaces
+import numpy as np
+from typing import Optional, Generator, NamedTuple, Dict, List
+import torch as th
+from stable_baselines3.common.vec_env.base_vec_env import tile_images
+import cv2
+import time
+from threading import Thread
+import queue
+
+COLORS = [
+ [46, 52, 54],
+ [136, 138, 133],
+ [255, 0, 255],
+ [0, 255, 255],
+ [0, 0, 255],
+ [255, 0, 0],
+ [255, 255, 0],
+ [255, 255, 255]
+]
+
+
+class PpoBufferSamples(NamedTuple):
+ observations: Dict[str, th.Tensor]
+ actions: th.Tensor
+ old_values: th.Tensor
+ old_log_prob: th.Tensor
+ old_mu: th.Tensor
+ old_sigma: th.Tensor
+ advantages: th.Tensor
+ returns: th.Tensor
+ exploration_suggests: List[tuple]
+
+
+class PpoBuffer():
+ def __init__(self, buffer_size: int, observation_space: spaces.Space, action_space: spaces.Space,
+ gae_lambda: float = 1, gamma: float = 0.99, n_envs: int = 1):
+
+ self.buffer_size = buffer_size
+ self.observation_space = observation_space
+ self.action_space = action_space
+ self.gae_lambda = gae_lambda
+ self.gamma = gamma
+ self.n_envs = n_envs
+ self.reset()
+
+ self.pos = 0
+ self.full = False
+ if th.cuda.is_available():
+ self.device = 'cuda'
+ else:
+ self.device = 'cpu'
+
+ self.sample_queue = queue.Queue()
+
+ def reset(self) -> None:
+ self.observations = {}
+ for k, s in self.observation_space.spaces.items():
+ self.observations[k] = np.zeros((self.buffer_size, self.n_envs,)+s.shape, dtype=s.dtype)
+ # int(np.prod(self.action_space.shape))
+ self.actions = np.zeros((self.buffer_size, self.n_envs)+self.action_space.shape, dtype=np.float32)
+ self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32)
+ self.mus = np.zeros((self.buffer_size, self.n_envs)+self.action_space.shape, dtype=np.float32)
+ self.sigmas = np.zeros((self.buffer_size, self.n_envs)+self.action_space.shape, dtype=np.float32)
+ self.exploration_suggests = np.zeros((self.buffer_size, self.n_envs), dtype=[('acc', 'U10'), ('steer', 'U10')])
+
+ self.reward_debugs = [[] for i in range(self.n_envs)]
+ self.terminal_debugs = [[] for i in range(self.n_envs)]
+
+ self.pos = 0
+ self.full = False
+
+ def compute_returns_and_advantage(self, last_value: th.Tensor, dones: np.ndarray) -> None:
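+ # Generalized Advantage Estimation (GAE):
+ #   delta_t = r_t + gamma * V_{t+1} * (1 - done_{t+1}) - V_t
+ #   A_t     = delta_t + gamma * lambda * (1 - done_{t+1}) * A_{t+1}
+ # returns are then A_t + V_t (sb3 convention, see below).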
+ last_gae_lam = 0
+ for step in reversed(range(self.buffer_size)):
+ if step == self.buffer_size - 1:
+ next_non_terminal = 1.0 - dones
+ next_value = last_value
+ # spinning up return calculation
+ # self.returns[step] = self.rewards[step] + self.gamma * last_value * next_non_terminal
+ else:
+ next_non_terminal = 1.0 - self.dones[step + 1]
+ next_value = self.values[step + 1]
+ # spinning up return calculation
+ # self.returns[step] = self.rewards[step] + self.gamma * self.returns[step+1] * next_non_terminal
+ delta = self.rewards[step] + self.gamma * next_value * next_non_terminal - self.values[step]
+ last_gae_lam = delta + self.gamma * self.gae_lambda * next_non_terminal * last_gae_lam
+ self.advantages[step] = last_gae_lam
+
+ # sb3 return
+ self.returns = self.advantages + self.values
+
+ def add(self,
+ obs_dict: Dict[str, np.ndarray],
+ action: np.ndarray,
+ reward: np.ndarray,
+ done: np.ndarray,
+ value: np.ndarray,
+ log_prob: np.ndarray,
+ mu: np.ndarray,
+ sigma: np.ndarray,
+ infos) -> None:
+
+ for k, v in obs_dict.items():
+ self.observations[k][self.pos] = v
+ self.actions[self.pos] = action
+ self.rewards[self.pos] = reward
+ self.dones[self.pos] = done
+ self.values[self.pos] = value
+ self.log_probs[self.pos] = log_prob
+ self.mus[self.pos] = mu
+ self.sigmas[self.pos] = sigma
+
+ for i in range(self.n_envs):
+ self.reward_debugs[i].append(infos[i]['reward_debug']['debug_texts'])
+ self.terminal_debugs[i].append(infos[i]['terminal_debug']['debug_texts'])
+
+ n_steps = infos[i]['terminal_debug']['exploration_suggest']['n_steps']
+ if n_steps > 0:
+ n_start = max(0, self.pos-n_steps)
+ self.exploration_suggests[n_start:self.pos, i] = \
+ infos[i]['terminal_debug']['exploration_suggest']['suggest']
+
+ self.pos += 1
+ if self.pos == self.buffer_size:
+ self.full = True
+
+ def update_values(self, policy):
+ for i in range(self.buffer_size):
+ obs_dict = {}
+ for k in self.observations.keys():
+ obs_dict[k] = self.observations[k][i]
+ values = policy.forward_value(obs_dict)
+ self.values[i] = values
+
+ def get(self, batch_size: Optional[int] = None) -> Generator[PpoBufferSamples, None, None]:
+ assert self.full, 'rollout buffer must be full before sampling'
+ indices = np.random.permutation(self.buffer_size * self.n_envs)
+ # Prepare the data
+ for tensor in ['actions', 'values', 'log_probs', 'advantages', 'returns',
+ 'mus', 'sigmas', 'exploration_suggests']:
+ self.__dict__['flat_'+tensor] = self.flatten(self.__dict__[tensor])
+ self.flat_observations = {}
+ for k in self.observations.keys():
+ self.flat_observations[k] = self.flatten(self.observations[k])
+
+ # spinning up: the next two lines implement the advantage normalization trick
+ adv_mean = np.mean(self.advantages)
+ adv_std = np.std(self.advantages) + np.finfo(np.float32).eps
+ self.advantages = (self.advantages - adv_mean) / adv_std
+
+ # Return everything, don't create minibatches
+ if batch_size is None:
+ batch_size = self.buffer_size * self.n_envs
+
+ start_idx = 0
+ while start_idx < self.buffer_size * self.n_envs:
+ yield self._get_samples(indices[start_idx:start_idx + batch_size])
+ start_idx += batch_size
+
+ def _get_samples(self, batch_inds: np.ndarray) -> PpoBufferSamples:
+ def to_torch(x):
+ return th.as_tensor(x).to(self.device)
+ # return th.from_numpy(x.astype(np.float32)).to(self.device)
+
+ obs_dict = {}
+ for k in self.observations.keys():
+ obs_dict[k] = to_torch(self.flat_observations[k][batch_inds])
+
+ data = (self.flat_actions[batch_inds],
+ self.flat_values[batch_inds],
+ self.flat_log_probs[batch_inds],
+ self.flat_mus[batch_inds],
+ self.flat_sigmas[batch_inds],
+ self.flat_advantages[batch_inds],
+ self.flat_returns[batch_inds]
+ )
+
+ data_torch = (obs_dict,) + tuple(map(to_torch, data)) + (self.flat_exploration_suggests[batch_inds],)
+ return PpoBufferSamples(*data_torch)
+
+ @staticmethod
+ def flatten(arr: np.ndarray) -> np.ndarray:
+ shape = arr.shape
+ # if len(shape) < 3:
+ # return arr.swapaxes(0, 1).reshape(shape[0] * shape[1])
+ # else:
+ return arr.reshape(shape[0] * shape[1], *shape[2:])
+
+ def render(self):
+ assert self.full, 'rollout buffer must be full before rendering'
+ list_render = []
+
+ _, _, c, h, w = self.observations['birdview'].shape
+ vis_idx = np.array([0, 1, 2, 6, 10, 14])
+
+ for i in range(self.buffer_size):
+ im_envs = []
+ for j in range(self.n_envs):
+
+ masks = self.observations['birdview'][i, j, vis_idx, :, :] > 100
+
+ im_birdview = np.zeros([h, w, 3], dtype=np.uint8)
+ for idx_c in range(len(vis_idx)):
+ im_birdview[masks[idx_c]] = COLORS[idx_c]
+
+ im = np.zeros([h, w*2, 3], dtype=np.uint8)
+ im[:h, :w] = im_birdview
+
+ action_str = np.array2string(self.actions[i, j], precision=1, separator=',', suppress_small=True)
+ state_str = np.array2string(self.observations['state'][i, j],
+ precision=1, separator=',', suppress_small=True)
+
+ reward = self.rewards[i, j]
+ ret = self.returns[i, j]
+ advantage = self.advantages[i, j]
+ done = int(self.dones[i, j])
+ value = self.values[i, j]
+ log_prob = self.log_probs[i, j]
+
+ txt_1 = f'v:{value:5.2f} p:{log_prob:5.2f} a{action_str}'
+ im = cv2.putText(im, txt_1, (2, 12), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255, 255, 255), 1)
+ txt_2 = f'{done} {state_str}'
+ im = cv2.putText(im, txt_2, (2, 24), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255, 255, 255), 1)
+ txt_3 = f'rw:{reward:5.2f} rt:{ret:5.2f} a:{advantage:5.2f}'
+ im = cv2.putText(im, txt_3, (2, 36), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255, 255, 255), 1)
+
+ for i_txt, txt in enumerate(self.reward_debugs[j][i] + self.terminal_debugs[j][i]):
+ im = cv2.putText(im, txt, (w, (i_txt+1)*15), cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255, 255, 255), 1)
+
+ im_envs.append(im)
+
+ big_im = tile_images(im_envs)
+ list_render.append(big_im)
+
+ return list_render
+
+ def start_caching(self, batch_size):
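+ # Prefetch minibatches on a background thread; cache_to_cuda keeps at most two
+ # ready-to-use, device-resident batches in sample_queue for train() to consume.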
+ thread1 = Thread(target=self.cache_to_cuda, args=(batch_size,))
+ thread1.start()
+
+ def cache_to_cuda(self, batch_size):
+ self.sample_queue.queue.clear()
+
+ for rollout_data in self.get(batch_size):
+ while self.sample_queue.qsize() >= 2:
+ time.sleep(0.01)
+ self.sample_queue.put(rollout_data)
+
+ def size(self) -> int:
+ """
+ :return: (int) The current size of the buffer
+ """
+ if self.full:
+ return self.buffer_size
+ return self.pos
diff --git a/rl_birdview/models/ppo_policy.py b/rl_birdview/models/ppo_policy.py
new file mode 100644
index 0000000..29ce6e6
--- /dev/null
+++ b/rl_birdview/models/ppo_policy.py
@@ -0,0 +1,244 @@
+"""Adapted from https://github.com/zhejz/carla-roach CC-BY-NC 4.0 license."""
+
+from typing import Union, Dict, Tuple, Any
+from functools import partial
+import gym
+import torch as th
+import torch.nn as nn
+import numpy as np
+
+from carla_gym.utils.config_utils import load_entry_point
+
+
+class PpoPolicy(nn.Module):
+
+ def __init__(self,
+ observation_space: gym.spaces.Space,
+ action_space: gym.spaces.Space,
+ policy_head_arch=[256, 256],
+ value_head_arch=[256, 256],
+ features_extractor_entry_point=None,
+ features_extractor_kwargs={},
+ distribution_entry_point=None,
+ distribution_kwargs={}):
+
+ super(PpoPolicy, self).__init__()
+ self.observation_space = observation_space
+ self.action_space = action_space
+ self.features_extractor_entry_point = features_extractor_entry_point
+ self.features_extractor_kwargs = features_extractor_kwargs
+ self.distribution_entry_point = distribution_entry_point
+ self.distribution_kwargs = distribution_kwargs
+
+ if th.cuda.is_available():
+ self.device = 'cuda'
+ else:
+ self.device = 'cpu'
+
+ self.optimizer_class = th.optim.Adam
+ self.optimizer_kwargs = {'eps': 1e-5}
+
+ features_extractor_class = load_entry_point(features_extractor_entry_point)
+ self.features_extractor = features_extractor_class(observation_space, **features_extractor_kwargs)
+
+ distribution_class = load_entry_point(distribution_entry_point)
+ self.action_dist = distribution_class(int(np.prod(action_space.shape)), **distribution_kwargs)
+
+ if 'StateDependentNoiseDistribution' in distribution_entry_point:
+ self.use_sde = True
+ self.sde_sample_freq = 4
+ else:
+ self.use_sde = False
+ self.sde_sample_freq = None
+
+ # best_so_far
+ # self.net_arch = [dict(pi=[256, 128, 64], vf=[128, 64])]
+ self.policy_head_arch = list(policy_head_arch)
+ self.value_head_arch = list(value_head_arch)
+ self.activation_fn = nn.ReLU
+ self.ortho_init = False
+
+ self._build()
+
+ def reset_noise(self, n_envs: int = 1) -> None:
+ assert self.use_sde, 'reset_noise() is only available when using gSDE'
+ self.action_dist.sample_weights(self.dist_sigma, batch_size=n_envs)
+
+ def _build(self) -> None:
+ last_layer_dim_pi = self.features_extractor.features_dim
+ policy_net = []
+ for layer_size in self.policy_head_arch:
+ policy_net.append(nn.Linear(last_layer_dim_pi, layer_size))
+ policy_net.append(self.activation_fn())
+ last_layer_dim_pi = layer_size
+
+ self.policy_head = nn.Sequential(*policy_net).to(self.device)
+ # mu->alpha/mean, sigma->beta/log_std (nn.Module, nn.Parameter)
+ self.dist_mu, self.dist_sigma = self.action_dist.proba_distribution_net(last_layer_dim_pi)
+
+ last_layer_dim_vf = self.features_extractor.features_dim
+ value_net = []
+ for layer_size in self.value_head_arch:
+ value_net.append(nn.Linear(last_layer_dim_vf, layer_size))
+ value_net.append(self.activation_fn())
+ last_layer_dim_vf = layer_size
+
+ value_net.append(nn.Linear(last_layer_dim_vf, 1))
+ self.value_head = nn.Sequential(*value_net).to(self.device)
+
+ # Init weights: use orthogonal initialization
+ # with small initial weight for the output
+ if self.ortho_init:
+ # TODO: check for features_extractor
+ # Values from stable-baselines.
+ # feature_extractor/mlp values are
+ # originally from openai/baselines (default gains/init_scales).
+ module_gains = {
+ # self.features_extractor: np.sqrt(2),
+ self.policy_head: np.sqrt(2),
+ self.value_head: np.sqrt(2)
+ # self.action_net: 0.01,
+ }
+ for module, gain in module_gains.items():
+ module.apply(partial(self.init_weights, gain=gain))
+
+ self.optimizer = self.optimizer_class(self.parameters(), **self.optimizer_kwargs)
+
+ def _get_features(self, birdview: th.Tensor, state: th.Tensor) -> th.Tensor:
+ """
+ :param birdview: th.Tensor (num_envs, frame_stack*channel, height, width)
+ :param state: th.Tensor (num_envs, state_dim)
+ """
+ birdview = birdview.float() / 255.0
+ features = self.features_extractor(birdview, state)
+ return features
+
+ def _get_action_dist_from_features(self, features: th.Tensor):
+ latent_pi = self.policy_head(features)
+ mu = self.dist_mu(latent_pi)
+ if isinstance(self.dist_sigma, nn.Parameter):
+ sigma = self.dist_sigma
+ else:
+ sigma = self.dist_sigma(latent_pi)
+ return self.action_dist.proba_distribution(mu, sigma), mu.detach().cpu().numpy(), sigma.detach().cpu().numpy()
+
+ def evaluate_actions(self, obs_dict: Dict[str, th.Tensor], actions: th.Tensor, exploration_suggests,
+ detach_values=False):
+ features = self._get_features(**obs_dict)
+
+ if detach_values:
+ detached_features = features.detach()
+ values = self.value_head(detached_features)
+ else:
+ values = self.value_head(features)
+
+ distribution, mu, sigma = self._get_action_dist_from_features(features)
+ actions = self.scale_action(actions)
+ log_prob = distribution.log_prob(actions)
+ return values.flatten(), log_prob, distribution.entropy_loss(), \
+ distribution.exploration_loss(exploration_suggests), distribution.distribution
+
+ def evaluate_values(self, obs_dict: Dict[str, th.Tensor]):
+ features = self._get_features(**obs_dict)
+ values = self.value_head(features)
+ distribution, mu, sigma = self._get_action_dist_from_features(features)
+ return values.flatten(), distribution.distribution
+
+ def forward(self, obs_dict: Dict[str, np.ndarray], deterministic: bool = False, clip_action: bool = False):
+ '''
+ Used in collect_rollouts(); actions are only clipped to the action space when clip_action=True.
+ '''
+ with th.no_grad():
+ obs_tensor_dict = dict([(k, th.as_tensor(v).to(self.device)) for k, v in obs_dict.items()])
+ features = self._get_features(**obs_tensor_dict)
+ values = self.value_head(features)
+ distribution, mu, sigma = self._get_action_dist_from_features(features)
+ actions = distribution.get_actions(deterministic=deterministic)
+ log_prob = distribution.log_prob(actions)
+
+ actions = actions.cpu().numpy()
+ actions = self.unscale_action(actions)
+ if clip_action:
+ actions = np.clip(actions, self.action_space.low, self.action_space.high)
+ values = values.cpu().numpy().flatten()
+ log_prob = log_prob.cpu().numpy()
+ features = features.cpu().numpy()
+ return actions, values, log_prob, mu, sigma, features
+
+ def forward_value(self, obs_dict: Dict[str, np.ndarray]) -> np.ndarray:
+ with th.no_grad():
+ obs_tensor_dict = dict([(k, th.as_tensor(v).to(self.device)) for k, v in obs_dict.items()])
+ features = self._get_features(**obs_tensor_dict)
+ values = self.value_head(features)
+ values = values.cpu().numpy().flatten()
+ return values
+
+ def forward_policy(self, obs_dict: Dict[str, np.ndarray]) -> np.ndarray:
+ with th.no_grad():
+ obs_tensor_dict = dict([(k, th.as_tensor(v).to(self.device)) for k, v in obs_dict.items()])
+ features = self._get_features(**obs_tensor_dict)
+ distribution, mu, sigma = self._get_action_dist_from_features(features)
+ return mu, sigma
+
+ def scale_action(self, action: th.Tensor, eps=1e-7) -> th.Tensor:
+ # input action \in [a_low, a_high]
+ # output action \in [d_low+eps, d_high-eps]
+ d_low, d_high = self.action_dist.low, self.action_dist.high # scalar
+
+ if d_low is not None and d_high is not None:
+ a_low = th.as_tensor(self.action_space.low.astype(np.float32)).to(action.device)
+ a_high = th.as_tensor(self.action_space.high.astype(np.float32)).to(action.device)
+ action = (action-a_low)/(a_high-a_low) * (d_high-d_low) + d_low
+ action = th.clamp(action, d_low+eps, d_high-eps)
+ return action
+
+ def unscale_action(self, action: np.ndarray, eps=0.0) -> np.ndarray:
+ # input action \in [d_low, d_high]
+ # output action \in [a_low+eps, a_high-eps]
+ d_low, d_high = self.action_dist.low, self.action_dist.high # scalar
+
+ if d_low is not None and d_high is not None:
+ # batch_size = action.shape[0]
+ a_low, a_high = self.action_space.low, self.action_space.high
+ # same shape as action [batch_size, action_dim]
+ # a_high = np.tile(self.action_space.high, [batch_size, 1])
+ action = (action-d_low)/(d_high-d_low) * (a_high-a_low) + a_low
+ # action = np.clip(action, a_low+eps, a_high-eps)
+ return action
+
+ def get_init_kwargs(self) -> Dict[str, Any]:
+ init_kwargs = dict(
+ observation_space=self.observation_space,
+ action_space=self.action_space,
+ policy_head_arch=self.policy_head_arch,
+ value_head_arch=self.value_head_arch,
+ features_extractor_entry_point=self.features_extractor_entry_point,
+ features_extractor_kwargs=self.features_extractor_kwargs,
+ distribution_entry_point=self.distribution_entry_point,
+ distribution_kwargs=self.distribution_kwargs,
+ )
+ return init_kwargs
+
+ @classmethod
+ def load(cls, path):
+ if th.cuda.is_available():
+ device = 'cuda'
+ else:
+ device = 'cpu'
+ saved_variables = th.load(path, map_location=device)
+ # Create policy object
+ model = cls(**saved_variables['policy_init_kwargs'])
+ # Load weights
+ model.load_state_dict(saved_variables['policy_state_dict'])
+ model.to(device)
+ return model, saved_variables['train_init_kwargs']
+
+ @staticmethod
+ def init_weights(module: nn.Module, gain: float = 1) -> None:
+ """
+ Orthogonal initialization (used in PPO and A2C)
+ """
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
+ nn.init.orthogonal_(module.weight, gain=gain)
+ if module.bias is not None:
+ module.bias.data.fill_(0.0)
diff --git a/rl_birdview/models/torch_layers.py b/rl_birdview/models/torch_layers.py
new file mode 100644
index 0000000..4bb7575
--- /dev/null
+++ b/rl_birdview/models/torch_layers.py
@@ -0,0 +1,145 @@
+"""Adapted from https://github.com/zhejz/carla-roach CC-BY-NC 4.0 license."""
+
+"""Policies: abstract base class and concrete implementations."""
+
+import torch as th
+import torch.nn as nn
+import numpy as np
+
+from . import torch_util as tu
+
+
+class XtMaCNN(nn.Module):
+ '''
+ Inspired by https://github.com/xtma/pytorch_car_caring
+ '''
+
+ def __init__(self, observation_space, n_input_frames=1, features_dim=256, states_neurons=[256]):
+ super().__init__()
+ self.features_dim = features_dim
+
+ n_input_channels = n_input_frames * observation_space['birdview'].shape[0]
+
+ self.cnn = nn.Sequential(
+ nn.Conv2d(n_input_channels, 8, kernel_size=5, stride=2),
+ nn.ReLU(),
+ nn.Conv2d(8, 16, kernel_size=5, stride=2),
+ nn.ReLU(),
+ nn.Conv2d(16, 32, kernel_size=5, stride=2),
+ nn.ReLU(),
+ nn.Conv2d(32, 64, kernel_size=3, stride=2),
+ nn.ReLU(),
+ nn.Conv2d(64, 128, kernel_size=3, stride=2),
+ nn.ReLU(),
+ nn.Conv2d(128, 256, kernel_size=3, stride=1),
+ nn.ReLU(),
+ nn.Flatten(),
+ )
+ # Flattened CNN output size. Hard-coded here; the commented expression computes it
+ # with a single dummy forward pass instead.
+ with th.no_grad():
+ n_flatten = 1024  # self.cnn(th.as_tensor(observation_space['birdview'].sample()[None]).float()).shape[1]
+
+ self.linear = nn.Sequential(nn.Linear(n_flatten+states_neurons[-1], 512), nn.ReLU(),
+ nn.Linear(512, features_dim), nn.ReLU())
+
+ states_neurons = [observation_space['state'].shape[0]] + states_neurons
+ self.state_linear = []
+ for i in range(len(states_neurons)-1):
+ self.state_linear.append(nn.Linear(states_neurons[i], states_neurons[i+1]))
+ self.state_linear.append(nn.ReLU())
+ self.state_linear = nn.Sequential(*self.state_linear)
+
+ self.apply(self._weights_init)
+
+ @staticmethod
+ def _weights_init(m):
+ if isinstance(m, nn.Conv2d):
+ nn.init.xavier_uniform_(m.weight, gain=nn.init.calculate_gain('relu'))
+ nn.init.constant_(m.bias, 0.1)
+
+ def forward(self, birdview, state):
+ x = self.cnn(birdview)
+ latent_state = self.state_linear(state)
+
+ # latent_state = state.repeat(1, state.shape[1]*256)
+
+ x = th.cat((x, latent_state), dim=1)
+ x = self.linear(x)
+ return x
+
+
+class StateEncoder(nn.Module):
+ def __init__(self, observation_space, features_dim=256, states_neurons=[256]):
+ super().__init__()
+ self.features_dim = features_dim
+
+ n_flatten = 256 * 6 * 6 # 9216
+ self.linear = nn.Sequential(nn.Linear(n_flatten+states_neurons[-1], 512), nn.ReLU(),
+ nn.Linear(512, features_dim), nn.ReLU())
+
+ states_neurons = [observation_space['state'].shape[0]] + states_neurons
+ self.state_linear = []
+ for i in range(len(states_neurons)-1):
+ self.state_linear.append(nn.Linear(states_neurons[i], states_neurons[i+1]))
+ self.state_linear.append(nn.ReLU())
+ self.state_linear = nn.Sequential(*self.state_linear)
+
+ def forward(self, birdview_state, state):
+ batch_size = state.shape[0]
+ birdview_state = birdview_state.view(batch_size, -1)
+ latent_state = self.state_linear(state)
+
+ # latent_state = state.repeat(1, state.shape[1]*256)
+
+ x = th.cat((birdview_state, latent_state), dim=1)
+ x = self.linear(x)
+ return x
+
+
+
+class ImpalaCNN(nn.Module):
+ def __init__(self, observation_space, chans=(16, 32, 32, 64, 64), states_neurons=[256],
+ features_dim=256, nblock=2, batch_norm=False, final_relu=True):
+ # (16, 32, 32)
+ super().__init__()
+ self.features_dim = features_dim
+ self.final_relu = final_relu
+
+ # image encoder
+ curshape = observation_space['birdview'].shape
+ s = 1 / np.sqrt(len(chans)) # per stack scale
+ self.stacks = nn.ModuleList()
+ for outchan in chans:
+ stack = tu.CnnDownStack(curshape[0], nblock=nblock, outchan=outchan, scale=s, batch_norm=batch_norm)
+ self.stacks.append(stack)
+ curshape = stack.output_shape(curshape)
+
+ # dense after concatenate
+ n_image_latent = tu.intprod(curshape)
+ self.dense = tu.NormedLinear(n_image_latent+states_neurons[-1], features_dim, scale=1.4)
+
+ # state encoder
+ states_neurons = [observation_space['state'].shape[0]] + states_neurons
+ self.state_linear = []
+ for i in range(len(states_neurons)-1):
+ self.state_linear.append(tu.NormedLinear(states_neurons[i], states_neurons[i+1]))
+ self.state_linear.append(nn.ReLU())
+ self.state_linear = nn.Sequential(*self.state_linear)
+
+ def forward(self, birdview, state):
+ # birdview: [b, c, h, w]
+ # x = x.to(dtype=th.float32) / self.scale_ob
+
+ for layer in self.stacks:
+ birdview = layer(birdview)
+
+ x = th.flatten(birdview, 1)
+ x = th.relu(x)
+
+ latent_state = self.state_linear(state)
+
+ x = th.cat((x, latent_state), dim=1)
+ x = self.dense(x)
+ if self.final_relu:
+ x = th.relu(x)
+ return x
diff --git a/rl_birdview/models/torch_util.py b/rl_birdview/models/torch_util.py
new file mode 100644
index 0000000..d311caf
--- /dev/null
+++ b/rl_birdview/models/torch_util.py
@@ -0,0 +1,106 @@
+"""Adapted from https://github.com/zhejz/carla-roach CC-BY-NC 4.0 license."""
+
+import torch as th
+from torch import nn
+import math
+import torch.nn.functional as F
+
+
+def NormedLinear(*args, scale=1.0, dtype=th.float32, **kwargs):
+ """
+ nn.Linear but with normalized fan-in init
+ """
+ out = nn.Linear(*args, **kwargs)
+ out.weight.data *= scale / out.weight.norm(dim=1, p=2, keepdim=True)
+ if kwargs.get("bias", True):
+ out.bias.data *= 0
+ return out
+
+
+def NormedConv2d(*args, scale=1, **kwargs):
+ """
+ nn.Conv2d but with normalized fan-in init
+ """
+ out = nn.Conv2d(*args, **kwargs)
+ out.weight.data *= scale / out.weight.norm(dim=(1, 2, 3), p=2, keepdim=True)
+ if kwargs.get("bias", True):
+ out.bias.data *= 0
+ return out
+
+
+def intprod(xs):
+ """
+ Product of a sequence of integers
+ """
+ out = 1
+ for x in xs:
+ out *= x
+ return out
+
+
+class CnnBasicBlock(nn.Module):
+ """
+ Residual basic block (without batchnorm), as in ImpalaCNN
+ Preserves channel number and shape
+ """
+
+ def __init__(self, inchan, scale=1, batch_norm=False):
+ super().__init__()
+ self.inchan = inchan
+ self.batch_norm = batch_norm
+ s = math.sqrt(scale)
+ self.conv0 = NormedConv2d(self.inchan, self.inchan, 3, padding=1, scale=s)
+ self.conv1 = NormedConv2d(self.inchan, self.inchan, 3, padding=1, scale=s)
+ if self.batch_norm:
+ self.bn0 = nn.BatchNorm2d(self.inchan)
+ self.bn1 = nn.BatchNorm2d(self.inchan)
+
+ def residual(self, x):
+ # inplace should be False for the first relu, so that it does not change the input,
+ # which will be used for skip connection.
+ # getattr is for backwards compatibility with loaded models
+ if getattr(self, "batch_norm", False):
+ x = self.bn0(x)
+ x = F.relu(x, inplace=False)
+ x = self.conv0(x)
+ if getattr(self, "batch_norm", False):
+ x = self.bn1(x)
+ x = F.relu(x, inplace=True)
+ x = self.conv1(x)
+ return x
+
+ def forward(self, x):
+ return x + self.residual(x)
+
+
+class CnnDownStack(nn.Module):
+ """
+ Downsampling stack from Impala CNN
+ """
+
+ def __init__(self, inchan, nblock, outchan, scale=1, pool=True, **kwargs):
+ super().__init__()
+ self.inchan = inchan
+ self.outchan = outchan
+ self.pool = pool
+ self.firstconv = NormedConv2d(inchan, outchan, 3, padding=1)
+ s = scale / math.sqrt(nblock)
+ self.blocks = nn.ModuleList(
+ [CnnBasicBlock(outchan, scale=s, **kwargs) for _ in range(nblock)]
+ )
+
+ def forward(self, x):
+ x = self.firstconv(x)
+ if getattr(self, "pool", True):
+ x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
+ for block in self.blocks:
+ x = block(x)
+ return x
+
+ def output_shape(self, inshape):
+ c, h, w = inshape
+ assert c == self.inchan
+ if getattr(self, "pool", True):
+ return (self.outchan, (h + 1) // 2, (w + 1) // 2)
+ else:
+ return (self.outchan, h, w)
\ No newline at end of file
diff --git a/rl_birdview/rl_birdview_agent.py b/rl_birdview/rl_birdview_agent.py
new file mode 100644
index 0000000..bcdf2d2
--- /dev/null
+++ b/rl_birdview/rl_birdview_agent.py
@@ -0,0 +1,122 @@
+"""Adapted from https://github.com/zhejz/carla-roach CC-BY-NC 4.0 license."""
+
+import logging
+import numpy as np
+from omegaconf import OmegaConf
+import wandb
+import copy
+
+from carla_gym.utils.config_utils import load_entry_point
+
+
+class RlBirdviewAgent():
+ def __init__(self, path_to_conf_file='config_agent.yaml'):
+ self._logger = logging.getLogger(__name__)
+ self._render_dict = None
+ self.supervision_dict = None
+ self.setup(path_to_conf_file)
+
+ def setup(self, path_to_conf_file):
+ cfg = OmegaConf.load(path_to_conf_file)
+
+ # load checkpoint from wandb
+ if cfg.wb_run_path is not None:
+ api = wandb.Api()
+ run = api.run(cfg.wb_run_path)
+ all_ckpts = [f for f in run.files() if 'ckpt' in f.name]
+
+ if cfg.wb_ckpt_step is None:
+ f = max(all_ckpts, key=lambda x: int(x.name.split('_')[1].split('.')[0]))
+ self._logger.info(f'Resume checkpoint latest {f.name}')
+ else:
+ wb_ckpt_step = int(cfg.wb_ckpt_step)
+ f = min(all_ckpts, key=lambda x: abs(int(x.name.split('_')[1].split('.')[0]) - wb_ckpt_step))
+ self._logger.info(f'Resume checkpoint closest to step {wb_ckpt_step}: {f.name}')
+
+ f.download(replace=True)
+ run.file('config_agent.yaml').download(replace=True)
+ cfg = OmegaConf.load('config_agent.yaml')
+ self._ckpt = f.name
+ else:
+ self._ckpt = None
+
+ cfg = OmegaConf.to_container(cfg)
+
+ self._obs_configs = cfg['obs_configs']
+ self._train_cfg = cfg['training']
+
+ # prepare policy
+ self._policy_class = load_entry_point(cfg['policy']['entry_point'])
+ self._policy_kwargs = cfg['policy']['kwargs']
+ if self._ckpt is None:
+ self._policy = None
+ else:
+ self._logger.info(f'Loading wandb checkpoint: {self._ckpt}')
+ self._policy, self._train_cfg['kwargs'] = self._policy_class.load(self._ckpt)
+ self._policy = self._policy.eval()
+
+ self._wrapper_class = load_entry_point(cfg['env_wrapper']['entry_point'])
+ self._wrapper_kwargs = cfg['env_wrapper']['kwargs']
+
+ def run_step(self, input_data, timestamp):
+ input_data = copy.deepcopy(input_data)
+
+ policy_input = self._wrapper_class.process_obs(input_data, self._wrapper_kwargs['input_states'], train=False)
+
+ actions, values, log_probs, mu, sigma, features = self._policy.forward(
+ policy_input, deterministic=True, clip_action=True)
+ control = self._wrapper_class.process_act(actions, self._wrapper_kwargs['acc_as_action'], train=False)
+ self.supervision_dict = {
+ 'action': np.array([control.throttle, control.steer, control.brake], dtype=np.float32),
+ 'value': values[0],
+ 'action_mu': mu[0],
+ 'action_sigma': sigma[0],
+ 'features': features[0],
+ 'speed': input_data['speed']['forward_speed']
+ }
+ self.supervision_dict = copy.deepcopy(self.supervision_dict)
+
+ self._render_dict = {
+ 'timestamp': timestamp,
+ 'obs': policy_input,
+ 'im_render': input_data['birdview']['rendered'],
+ 'action': actions,
+ 'action_value': values[0],
+ 'action_log_probs': log_probs[0],
+ 'action_mu': mu[0],
+ 'action_sigma': sigma[0]
+ }
+ self._render_dict = copy.deepcopy(self._render_dict)
+
+ return control
+
+ def reset(self, log_file_path):
+ # logger
+ self._logger.handlers = []
+ self._logger.propagate = False
+ self._logger.setLevel(logging.DEBUG)
+ fh = logging.FileHandler(log_file_path, mode='w')
+ fh.setLevel(logging.DEBUG)
+ self._logger.addHandler(fh)
+
+ def learn(self, env, total_timesteps, callback, seed):
+ if self._policy is None:
+ self._policy = self._policy_class(env.observation_space, env.action_space, **self._policy_kwargs)
+
+ # init ppo model
+ model_class = load_entry_point(self._train_cfg['entry_point'])
+ model = model_class(self._policy, env, **self._train_cfg['kwargs'])
+ model.learn(total_timesteps, callback=callback, seed=seed)
+
+ def render(self, reward_debug, terminal_debug):
+ '''
+ test render, used in benchmark.py
+ '''
+ self._render_dict['reward_debug'] = reward_debug
+ self._render_dict['terminal_debug'] = terminal_debug
+
+ return self._wrapper_class.im_render(self._render_dict)
+
+ @property
+ def obs_configs(self):
+ return self._obs_configs
diff --git a/rl_birdview/utils/rl_birdview_wrapper.py b/rl_birdview/utils/rl_birdview_wrapper.py
new file mode 100644
index 0000000..81a20a0
--- /dev/null
+++ b/rl_birdview/utils/rl_birdview_wrapper.py
@@ -0,0 +1,215 @@
+"""Adapted from https://github.com/zhejz/carla-roach CC-BY-NC 4.0 license."""
+
+import gym
+import numpy as np
+import cv2
+import carla
+
+eval_num_zombie_vehicles = {
+ 'Town01': 120,
+ 'Town02': 70,
+ 'Town03': 70,
+ 'Town04': 150,
+ 'Town05': 120,
+ 'Town06': 120
+}
+eval_num_zombie_walkers = {
+ 'Town01': 120,
+ 'Town02': 70,
+ 'Town03': 70,
+ 'Town04': 80,
+ 'Town05': 120,
+ 'Town06': 80
+}
+
+class RlBirdviewWrapper(gym.Wrapper):
+ def __init__(self, env, input_states=[], acc_as_action=False):
+ assert len(env._obs_configs) == 1
+ self._ev_id = list(env._obs_configs.keys())[0]
+ self._input_states = input_states
+ self._acc_as_action = acc_as_action
+ self._render_dict = {}
+
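+ # Collapse the multi-agent observation into a single-agent dict: a low-dimensional
+ # 'state' vector built from the selected input_states plus the 'birdview' mask stack.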
+ state_spaces = []
+ if 'speed' in self._input_states:
+ state_spaces.append(env.observation_space[self._ev_id]['speed']['speed_xy'])
+ if 'speed_limit' in self._input_states:
+ state_spaces.append(env.observation_space[self._ev_id]['control']['speed_limit'])
+ if 'control' in self._input_states:
+ state_spaces.append(env.observation_space[self._ev_id]['control']['throttle'])
+ state_spaces.append(env.observation_space[self._ev_id]['control']['steer'])
+ state_spaces.append(env.observation_space[self._ev_id]['control']['brake'])
+ state_spaces.append(env.observation_space[self._ev_id]['control']['gear'])
+ if 'acc_xy' in self._input_states:
+ state_spaces.append(env.observation_space[self._ev_id]['velocity']['acc_xy'])
+ if 'vel_xy' in self._input_states:
+ state_spaces.append(env.observation_space[self._ev_id]['velocity']['vel_xy'])
+ if 'vel_ang_z' in self._input_states:
+ state_spaces.append(env.observation_space[self._ev_id]['velocity']['vel_ang_z'])
+
+ state_low = np.concatenate([s.low for s in state_spaces])
+ state_high = np.concatenate([s.high for s in state_spaces])
+
+ env.observation_space = gym.spaces.Dict(
+ {'state': gym.spaces.Box(low=state_low, high=state_high, dtype=np.float32),
+ 'birdview': env.observation_space[self._ev_id]['birdview']['masks']})
+
+ if self._acc_as_action:
+ # act: acc(throttle/brake), steer
+ env.action_space = gym.spaces.Box(low=np.array([-1, -1]), high=np.array([1, 1]), dtype=np.float32)
+ else:
+ # act: throttle, steer, brake
+ env.action_space = gym.spaces.Box(low=np.array([0, -1, 0]), high=np.array([1, 1, 1]), dtype=np.float32)
+
+ super(RlBirdviewWrapper, self).__init__(env)
+
+ self.eval_mode = False
+
+ def reset(self):
+ self.env.set_task_idx(np.random.choice(self.env.num_tasks))
+ if self.eval_mode:
+ self.env._task['num_zombie_vehicles'] = eval_num_zombie_vehicles[self.env._carla_map]
+ self.env._task['num_zombie_walkers'] = eval_num_zombie_walkers[self.env._carla_map]
+ for ev_id in self.env._ev_handler._terminal_configs:
+ self.env._ev_handler._terminal_configs[ev_id]['kwargs']['eval_mode'] = True
+ else:
+ for ev_id in self.env._ev_handler._terminal_configs:
+ self.env._ev_handler._terminal_configs[ev_id]['kwargs']['eval_mode'] = False
+
+ obs_ma = self.env.reset()
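+ # Warm-up ticks (as in carla-roach): engage first gear manually, then release it,
+ # so the CARLA vehicle responds to throttle from the first policy step.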
+ action_ma = {self._ev_id: carla.VehicleControl(manual_gear_shift=True, gear=1)}
+ obs_ma, _, _, _ = self.env.step(action_ma)
+ action_ma = {self._ev_id: carla.VehicleControl(manual_gear_shift=False)}
+ obs_ma, _, _, _ = self.env.step(action_ma)
+
+ snap_shot = self.env._world.get_snapshot()
+ self.env._timestamp = {
+ 'step': 0,
+ 'frame': 0,
+ 'relative_wall_time': 0.0,
+ 'wall_time': snap_shot.timestamp.platform_timestamp,
+ 'relative_simulation_time': 0.0,
+ 'simulation_time': snap_shot.timestamp.elapsed_seconds,
+ 'start_frame': snap_shot.timestamp.frame,
+ 'start_wall_time': snap_shot.timestamp.platform_timestamp,
+ 'start_simulation_time': snap_shot.timestamp.elapsed_seconds
+ }
+
+ obs = self.process_obs(obs_ma[self._ev_id], self._input_states)
+
+ self._render_dict['prev_obs'] = obs
+ self._render_dict['prev_im_render'] = obs_ma[self._ev_id]['birdview']['rendered']
+ return obs
+
+ def step(self, action):
+ action_ma = {self._ev_id: self.process_act(action, self._acc_as_action)}
+
+ obs_ma, reward_ma, done_ma, info_ma = self.env.step(action_ma)
+
+ obs = self.process_obs(obs_ma[self._ev_id], self._input_states)
+ reward = reward_ma[self._ev_id]
+ done = done_ma[self._ev_id]
+ info = info_ma[self._ev_id]
+
+ self._render_dict = {
+ 'timestamp': self.env.timestamp,
+ 'obs': self._render_dict['prev_obs'],
+ 'prev_obs': obs,
+ 'im_render': self._render_dict['prev_im_render'],
+ 'prev_im_render': obs_ma[self._ev_id]['birdview']['rendered'],
+ 'action': action,
+ 'reward_debug': info['reward_debug'],
+ 'terminal_debug': info['terminal_debug']
+ }
+ return obs, reward, done, info
+
+ def render(self, mode='human'):
+ '''
+ train render: used in train_rl.py
+ '''
+ self._render_dict['action_value'] = self.action_value
+ self._render_dict['action_log_probs'] = self.action_log_probs
+ self._render_dict['action_mu'] = self.action_mu
+ self._render_dict['action_sigma'] = self.action_sigma
+ return self.im_render(self._render_dict)
+
+ @staticmethod
+ def im_render(render_dict):
+ im_birdview = render_dict['im_render']
+ h, w, c = im_birdview.shape
+ im = np.zeros([h, w*2, c], dtype=np.uint8)
+ im[:h, :w] = im_birdview
+
+ action_str = np.array2string(render_dict['action'], precision=2, separator=',', suppress_small=True)
+ mu_str = np.array2string(render_dict['action_mu'], precision=2, separator=',', suppress_small=True)
+ sigma_str = np.array2string(render_dict['action_sigma'], precision=2, separator=',', suppress_small=True)
+ state_str = np.array2string(render_dict['obs']['state'], precision=2, separator=',', suppress_small=True)
+
+ txt_t = f'step:{render_dict["timestamp"]["step"]:5}, frame:{render_dict["timestamp"]["frame"]:5}'
+ im = cv2.putText(im, txt_t, (3, 12), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255), 1)
+ txt_1 = f'a{action_str} v:{render_dict["action_value"]:5.2f} p:{render_dict["action_log_probs"]:5.2f}'
+ im = cv2.putText(im, txt_1, (3, 24), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255), 1)
+ txt_2 = f's{state_str}'
+ im = cv2.putText(im, txt_2, (3, 36), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255), 1)
+
+ txt_3 = f'a{mu_str} b{sigma_str}'
+ im = cv2.putText(im, txt_3, (w, 12), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255), 1)
+ for i, txt in enumerate(render_dict['reward_debug']['debug_texts'] +
+ render_dict['terminal_debug']['debug_texts']):
+ im = cv2.putText(im, txt, (w, (i+2)*12), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255), 1)
+ return im
+
+ @staticmethod
+ def process_obs(obs, input_states, train=True):
+
+ state_list = []
+ if 'speed' in input_states:
+ state_list.append(obs['speed']['speed_xy'])
+ if 'speed_limit' in input_states:
+ state_list.append(obs['control']['speed_limit'])
+ if 'control' in input_states:
+ state_list.append(obs['control']['throttle'])
+ state_list.append(obs['control']['steer'])
+ state_list.append(obs['control']['brake'])
+ state_list.append(obs['control']['gear']/5.0)
+ if 'acc_xy' in input_states:
+ state_list.append(obs['velocity']['acc_xy'])
+ if 'vel_xy' in input_states:
+ state_list.append(obs['velocity']['vel_xy'])
+ if 'vel_ang_z' in input_states:
+ state_list.append(obs['velocity']['vel_ang_z'])
+
+ state = np.concatenate(state_list)
+
+ birdview = obs['birdview']['masks']
+
+ if not train:
+ birdview = np.expand_dims(birdview, 0)
+ state = np.expand_dims(state, 0)
+
+ obs_dict = {
+ 'state': state.astype(np.float32),
+ 'birdview': birdview
+ }
+ return obs_dict
+
+ @staticmethod
+ def process_act(action, acc_as_action, train=True):
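+ # Map the policy action to carla.VehicleControl. With acc_as_action the action is
+ # (acc, steer): positive acc becomes throttle, negative acc becomes brake.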
+ if not train:
+ action = action[0]
+ if acc_as_action:
+ acc, steer = action.astype(np.float64)
+ if acc >= 0.0:
+ throttle = acc
+ brake = 0.0
+ else:
+ throttle = 0.0
+ brake = np.abs(acc)
+ else:
+ throttle, steer, brake = action.astype(np.float64)
+
+ throttle = np.clip(throttle, 0, 1)
+ steer = np.clip(steer, -1, 1)
+ brake = np.clip(brake, 0, 1)
+ control = carla.VehicleControl(throttle=throttle, steer=steer, brake=brake)
+ return control
diff --git a/rl_birdview/utils/wandb_callback.py b/rl_birdview/utils/wandb_callback.py
new file mode 100644
index 0000000..6bc081a
--- /dev/null
+++ b/rl_birdview/utils/wandb_callback.py
@@ -0,0 +1,213 @@
+"""Adapted from https://github.com/zhejz/carla-roach CC-BY-NC 4.0 license."""
+
+import numpy as np
+import time
+from pathlib import Path
+import wandb
+from stable_baselines3.common.callbacks import BaseCallback
+from gym.wrappers.monitoring.video_recorder import ImageEncoder
+from omegaconf import OmegaConf
+
+
+class WandbCallback(BaseCallback):
+ def __init__(self, cfg, vec_env):
+ super(WandbCallback, self).__init__(verbose=1)
+
+ # save_dir = Path.cwd()
+ # self._save_dir = save_dir
+ self._video_path = Path('video')
+ self._video_path.mkdir(parents=True, exist_ok=True)
+ self._ckpt_dir = Path('ckpt')
+ self._ckpt_dir.mkdir(parents=True, exist_ok=True)
+
+ # wandb.init(project=cfg.wb_project, dir=save_dir, name=cfg.wb_runname)
+ wandb.init(project=cfg.wb_project, name=cfg.wb_name, notes=cfg.wb_notes, tags=cfg.wb_tags)
+ wandb.config.update(OmegaConf.to_container(cfg))
+
+ wandb.save('./config_agent.yaml')
+ wandb.save('.hydra/*')
+
+ self.vec_env = vec_env
+
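+ # Evaluate the policy and render the rollout buffer roughly every 1e5 environment steps.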
+ self._eval_step = int(1e5)
+ self._buffer_step = int(1e5)
+
+ def _init_callback(self):
+ self.n_epoch = 0
+ self._last_time_buffer = self.model.num_timesteps
+ self._last_time_eval = self.model.num_timesteps
+
+ def _on_step(self) -> bool:
+ return True
+
+ def _on_training_start(self) -> None:
+ pass
+
+ def _on_rollout_start(self):
+ # self.model._last_obs = self.model.env.reset()
+ pass
+
+ def _on_training_end(self) -> None:
+ print(f'n_epoch: {self.n_epoch}, num_timesteps: {self.model.num_timesteps}')
+ # save time
+ time_elapsed = time.time() - self.model.start_time
+ wandb.log({
+ 'time/n_epoch': self.n_epoch,
+ 'time/sec_per_epoch': time_elapsed / (self.n_epoch+1),
+ 'time/fps': (self.model.num_timesteps-self.model.start_num_timesteps) / time_elapsed,
+ 'time/train': self.model.t_train,
+ 'time/train_values': self.model.t_train_values,
+ 'time/rollout': self.model.t_rollout
+ }, step=self.model.num_timesteps)
+ wandb.log(self.model.train_debug, step=self.model.num_timesteps)
+
+ # evaluate and save checkpoint
+ if (self.model.num_timesteps - self._last_time_eval) >= self._eval_step:
+ self._last_time_eval = self.model.num_timesteps
+ # evaluate
+ eval_video_path = (self._video_path / f'eval_{self.model.num_timesteps}.mp4').as_posix()
+ avg_ep_stat, ep_events = self.evaluate_policy(self.vec_env, self.model.policy, eval_video_path)
+ # log to wandb
+ wandb.log({f'video/{self.model.num_timesteps}': wandb.Video(eval_video_path)},
+ step=self.model.num_timesteps)
+ wandb.log(avg_ep_stat, step=self.model.num_timesteps)
+ # save events
+ # eval_json_path = (video_path / f'event_{self.model.num_timesteps}.json').as_posix()
+ # with open(eval_json_path, 'w') as fd:
+ # json.dump(ep_events, fd, indent=4, sort_keys=False)
+ # wandb.save(eval_json_path)
+
+ ckpt_path = (self._ckpt_dir / f'ckpt_{self.model.num_timesteps}.pth').as_posix()
+ self.model.save(ckpt_path)
+ wandb.save(f'./{ckpt_path}')
+ self.n_epoch += 1
+
+ # CONFIGHACK: curriculum
+ # num_zombies = {}
+ # for i in range(self.vec_env.num_envs):
+ # env_all_tasks = self.vec_env.get_attr('all_tasks',indices=i)[0]
+ # num_zombies[f'train/n_veh/{i}'] = env_all_tasks[0]['num_zombie_vehicles']
+ # num_zombies[f'train/n_ped/{i}'] = env_all_tasks[0]['num_zombie_walkers']
+ # if wandb.config['curriculum']:
+ # if avg_ep_stat['eval/route_completed_in_km'] > 1.0:
+ # # and avg_ep_stat['eval/red_light']>0:
+ # for env_task in env_all_tasks:
+ # env_task['num_zombie_vehicles'] += 10
+ # env_task['num_zombie_walkers'] += 10
+ # self.vec_env.set_attr('all_tasks', env_all_tasks, indices=i)
+
+ # wandb.log(num_zombies, step=self.model.num_timesteps)
+
+ def _on_rollout_end(self):
+ wandb.log({'time/rollout': self.model.t_rollout}, step=self.model.num_timesteps)
+
+ # save rollout statistics
+ avg_ep_stat = self.get_avg_ep_stat(self.model.ep_stat_buffer, prefix='rollout/')
+ wandb.log(avg_ep_stat, step=self.model.num_timesteps)
+
+ # action, mu, sigma histogram
+ action_statistics = np.array(self.model.action_statistics)
+ mu_statistics = np.array(self.model.mu_statistics)
+ sigma_statistics = np.array(self.model.sigma_statistics)
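+ # Flatten the per-step, per-env action distribution parameters into (N, n_action) for histogram logging.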
+ n_action = action_statistics.shape[-1]
+ action_statistics = action_statistics.reshape(-1, n_action)
+ mu_statistics = mu_statistics.reshape(-1, n_action)
+ sigma_statistics = sigma_statistics.reshape(-1, n_action)
+
+ for i in range(n_action):
+ # path_str = (self._save_dir/f'action{i}.csv').as_posix()
+ # np.savetxt(path_str, action_statistics[:, i], delimiter=',')
+ # wandb.save(path_str)
+ wandb.log({f'action[{i}]': wandb.Histogram(action_statistics[:, i])}, step=self.model.num_timesteps)
+ wandb.log({f'alpha[{i}]': wandb.Histogram(mu_statistics[:, i])}, step=self.model.num_timesteps)
+ wandb.log({f'beta[{i}]': wandb.Histogram(sigma_statistics[:, i])}, step=self.model.num_timesteps)
+
+ # render buffer
+ if (self.model.num_timesteps - self._last_time_buffer) >= self._buffer_step:
+ self._last_time_buffer = self.model.num_timesteps
+ buffer_video_path = (self._video_path / f'buffer_{self.model.num_timesteps}.mp4').as_posix()
+
+ list_buffer_im = self.model.buffer.render()
+ encoder = ImageEncoder(buffer_video_path, list_buffer_im[0].shape, 30, 30)
+ for im in list_buffer_im:
+ encoder.capture_frame(im)
+ encoder.close()
+ encoder = None
+
+ wandb.log({f'buffer/{self.model.num_timesteps}': wandb.Video(buffer_video_path)},
+ step=self.model.num_timesteps)
+
+ @staticmethod
+ def evaluate_policy(env, policy, video_path, min_eval_steps=3000):
+ policy = policy.eval()
+ t0 = time.time()
+ for i in range(env.num_envs):
+ env.set_attr('eval_mode', True, indices=i)
+ obs = env.reset()
+
+ list_render = []
+ ep_stat_buffer = []
+ ep_events = {}
+ for i in range(env.num_envs):
+ ep_events[f'venv_{i}'] = []
+
+ n_step = 0
+ n_timeout = 0
+ env_done = np.array([False]*env.num_envs)
+ # while n_step < min_eval_steps:
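+ # Run until the minimum number of steps is reached and every env has finished at least one episode.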
+ while n_step < min_eval_steps or not np.all(env_done):
+ actions, values, log_probs, mu, sigma, _ = policy.forward(obs, deterministic=True, clip_action=True)
+ obs, reward, done, info = env.step(actions)
+
+ for i in range(env.num_envs):
+ env.set_attr('action_value', values[i], indices=i)
+ env.set_attr('action_log_probs', log_probs[i], indices=i)
+ env.set_attr('action_mu', mu[i], indices=i)
+ env.set_attr('action_sigma', sigma[i], indices=i)
+
+ list_render.append(env.render(mode='rgb_array'))
+
+ n_step += 1
+ env_done |= done
+
+ for i in np.where(done)[0]:
+ ep_stat_buffer.append(info[i]['episode_stat'])
+ ep_events[f'venv_{i}'].append(info[i]['episode_event'])
+ n_timeout += int(info[i]['timeout'])
+
+ # conda install x264=='1!152.20180717' ffmpeg=4.0.2 -c conda-forge
+ encoder = ImageEncoder(video_path, list_render[0].shape, 30, 30)
+ for im in list_render:
+ encoder.capture_frame(im)
+ encoder.close()
+
+ avg_ep_stat = WandbCallback.get_avg_ep_stat(ep_stat_buffer, prefix='eval/')
+ avg_ep_stat['eval/eval_timeout'] = n_timeout
+
+ duration = time.time() - t0
+ avg_ep_stat['time/t_eval'] = duration
+ avg_ep_stat['time/fps_eval'] = n_step * env.num_envs / duration
+
+ for i in range(env.num_envs):
+ env.set_attr('eval_mode', False, indices=i)
+ obs = env.reset()
+ return avg_ep_stat, ep_events
+
+ @staticmethod
+ def get_avg_ep_stat(ep_stat_buffer, prefix=''):
+ avg_ep_stat = {}
+ if len(ep_stat_buffer) > 0:
+ for ep_info in ep_stat_buffer:
+ for k, v in ep_info.items():
+ k_avg = f'{prefix}{k}'
+ if k_avg in avg_ep_stat:
+ avg_ep_stat[k_avg] += v
+ else:
+ avg_ep_stat[k_avg] = v
+
+ n_episodes = float(len(ep_stat_buffer))
+ for k in avg_ep_stat.keys():
+ avg_ep_stat[k] /= n_episodes
+ avg_ep_stat[f'{prefix}n_episodes'] = n_episodes
+
+ return avg_ep_stat
diff --git a/run/data_collect.sh b/run/data_collect.sh
new file mode 100644
index 0000000..4ac8e23
--- /dev/null
+++ b/run/data_collect.sh
@@ -0,0 +1,38 @@
+#!/bin/bash
+# Adapted from https://github.com/zhejz/carla-roach/ CC-BY-NC 4.0 license.
+
+if [[ $# -ne 1 ]] ; then
+# echo 'Please specify the CARLA executable path, the folder to save the dataset, and the CARLA port.'
+ echo 'Please specify the CARLA port.'
+ exit 1
+fi
+
+#CARLA_PATH=$1
+#DATASET_ROOT=$2
+PORT=$1
+
+data_collect () {
+ python -u data_collect.py --config-name data_collect
+}
+
+source ~/miniconda3/etc/profile.d/conda.sh
+conda activate carla
+
+# Remove checkpoint files
+rm -f outputs/port_${PORT}_checkpoint.txt
+rm -f outputs/port_${PORT}_wb_run_id.txt
+rm -f outputs/port_${PORT}_ep_stat_buffer_*.json
+
+
+# Resume benchmark in case carla crashed.
+RED=$'\e[0;31m'
+NC=$'\e[0m'
+PYTHON_RETURN=1
+until [ "$PYTHON_RETURN" -eq 0 ]; do
+ data_collect
+ PYTHON_RETURN=$?
+ echo "${RED} PYTHON_RETURN=${PYTHON_RETURN}!!! Start Over!!!${NC}" >&2
+ sleep 2
+done
+
+echo "Bash script done."
diff --git a/sim_run.py b/sim_run.py
new file mode 100644
index 0000000..66ec115
--- /dev/null
+++ b/sim_run.py
@@ -0,0 +1,120 @@
+import os
+import socket
+import time
+from tqdm import tqdm
+
+import torch
+from torch.utils.tensorboard.writer import SummaryWriter
+# import lightning.pytorch as pl
+import numpy as np
+
+from muvo.config import get_parser, get_cfg
+from muvo.data.dataset import DataModule
+from muvo.trainer import WorldModelTrainer
+from lightning.pytorch.callbacks import ModelSummary
+
+from clearml import Task, Dataset, Model
+
+
+def main():
+ args = get_parser().parse_args()
+ cfg = get_cfg(args)
+
+ task = Task.init(project_name=cfg.CML_PROJECT, task_name=cfg.CML_TASK, task_type=cfg.CML_TYPE, tags=cfg.TAG)
+ task.connect(cfg)
+ cml_logger = task.get_logger()
+ #
+ # dataset_root = Dataset.get(dataset_project=cfg.CML_PROJECT,
+ # dataset_name=cfg.CML_DATASET,
+ # ).get_local_copy()
+
+ # data = DataModule(cfg, dataset_root=dataset_root)
+ data = DataModule(cfg)
+ data.setup()
+
+ input_model = Model(model_id='').get_local_copy() if cfg.PRETRAINED.CML_MODEL else None
+ model = WorldModelTrainer(cfg.convert_to_dict(), pretrained_path=input_model)
+ # model.get_cml_logger(cml_logger)
+
+ save_dir = os.path.join(
+ cfg.LOG_DIR, time.strftime('%d%B%Yat%H:%M:%S%Z') + '_' + socket.gethostname() + '_' + cfg.TAG
+ )
+ # writer = SummaryWriter(log_dir=save_dir)
+
+ dataloader = data.test_dataloader()[2]
+
+ pbar = tqdm(total=len(dataloader), desc='Prediction')
+ model.cuda()
+
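+ # Keep the model in train mode overall, but switch all dropout layers to eval so they are disabled.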
+ model.train()
+ for module in model.modules():
+ if isinstance(module, torch.nn.Dropout):
+ module.eval()
+
+ # n_prediction_samples = model.cfg.PREDICTION.N_SAMPLES
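+ # Buffer ground truth and model outputs; the buffer is uploaded as a ClearML artifact every 500 batches and then reset.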
+ upload_data = {
+ 'rgb_label': list(),
+ 'throttle_brake': list(),
+ 'steering': list(),
+ 'pcd_label': list(),
+ 'voxel_label': list(),
+ 'rgb_re': list(),
+ 'pcd_re': list(),
+ 'voxel_re': list(),
+ 'rgb_im': list(),
+ 'pcd_im': list(),
+ 'voxel_im': list(),
+ }
+
+ for i, batch in enumerate(dataloader):
+ batch = {key: value.cuda() for key, value in batch.items()}
+ with torch.no_grad():
+ batch = model.preprocess(batch)
+ output, output_imagine = model.model.sim_forward(batch, is_dreaming=False)
+
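+ # Convert the dense voxel grids into sparse arrays of non-empty cell indices.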
+ voxel_label = torch.where(batch['voxel_label_1'].squeeze()[0].cpu() != 0)
+ voxel_label = torch.stack(voxel_label).transpose(0, 1).numpy()
+
+ voxel_re = torch.where(torch.argmax(output['voxel_1'][0][0].detach(), dim=-4).cpu() != 0)
+ voxel_re = torch.stack(voxel_re).transpose(0, 1).numpy()
+
+ voxel_im = torch.where(torch.argmax(output_imagine['voxel_1'][0][(0, 3, 9), ...].detach(), dim=-4).cpu() != 0)
+ voxel_im = torch.stack(voxel_im).transpose(0, 1).numpy()
+
+ upload_data['rgb_label'].append((batch['rgb_label_1'][0][0].cpu().numpy() * 255).astype(np.uint8))
+ upload_data['throttle_brake'].append(batch['throttle_brake'][0][0].cpu().numpy())
+ upload_data['steering'].append(batch['steering'][0][0].cpu().numpy())
+ upload_data['pcd_label'].append(batch['range_view_label_1'][0][0].cpu().numpy())
+ upload_data['voxel_label'].append(voxel_label)
+ upload_data['rgb_re'].append((output['rgb_1'][0][0].detach().cpu().numpy() * 255).astype(np.uint8))
+ upload_data['pcd_re'].append(output['lidar_reconstruction_1'][0][0].detach().cpu().numpy())
+ upload_data['voxel_re'].append(voxel_re)
+ upload_data['rgb_im'].append((output_imagine['rgb_1'][0][(0, 3, 9), ...].detach().cpu().numpy() * 255).astype(np.uint8))
+ upload_data['pcd_im'].append(output_imagine['lidar_reconstruction_1'][0][(0, 3, 9), ...].detach().cpu().numpy())
+ upload_data['voxel_im'].append(voxel_im)
+
+ if i % 500 == 0 and i != 0:
+ print(f'Uploading data {i}')
+ task.upload_artifact(f'data_{i}', np.array(upload_data))
+ upload_data = {
+ 'rgb_label': list(),
+ 'throttle_brake': list(),
+ 'steering': list(),
+ 'pcd_label': list(),
+ 'voxel_label': list(),
+ 'rgb_re': list(),
+ 'pcd_re': list(),
+ 'voxel_re': list(),
+ 'rgb_im': list(),
+ 'pcd_im': list(),
+ 'voxel_im': list(),
+ }
+
+ pbar.update(1)
+
+ if i % 500 != 0:
+ task.upload_artifact(f'data_{i}', np.array(upload_data))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..f423e93
--- /dev/null
+++ b/train.py
@@ -0,0 +1,119 @@
+import git
+import os
+import socket
+import time
+from weakref import proxy
+
+import torch
+import lightning.pytorch as pl
+from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
+from lightning.pytorch.callbacks import ModelSummary, LearningRateMonitor
+
+from muvo.config import get_parser, get_cfg
+from muvo.data.dataset import DataModule
+from muvo.trainer import WorldModelTrainer
+
+from clearml import Task, Dataset, Model
+
+
+class SaveGitDiffHashCallback(pl.Callback):
+ def setup(self, trainer, pl_model, stage):
+ repo = git.Repo()
+ trainer.git_hash = repo.head.object.hexsha
+ trainer.git_diff = repo.git.diff(repo.head.commit.tree)
+
+ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+ checkpoint['world_size'] = trainer.world_size
+ checkpoint['git_hash'] = trainer.git_hash
+ checkpoint['git_diff'] = trainer.git_diff
+
+
+class MyModelCheckpoint(ModelCheckpoint):
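+ # Saves checkpoints into the current working directory; if dumping fails because 'hyper_parameters'
+ # is not picklable, that key is dropped and the save is retried.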
+ def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
+ filename = filepath.split('/')[-1]
+ _checkpoint = trainer._checkpoint_connector.dump_checkpoint(self.save_weights_only)
+ try:
+ torch.save(_checkpoint, filename)
+ except AttributeError as err:
+ key = "hyper_parameters"
+ _checkpoint.pop(key, None)
+ print(f"Warning, `{key}` dropped from checkpoint. An attribute is not picklable: {err}")
+ torch.save(_checkpoint, filename)
+
+ self._last_global_step_saved = trainer.global_step
+
+ # notify loggers
+ if trainer.is_global_zero:
+ for logger in trainer.loggers:
+ logger.after_save_checkpoint(proxy(self))
+
+
+def main():
+ args = get_parser().parse_args()
+ cfg = get_cfg(args)
+
+ # task = Task.init(project_name=cfg.CML_PROJECT, task_name=cfg.CML_TASK, task_type=cfg.CML_TYPE, tags=cfg.TAG)
+ # task.connect(cfg)
+ # cml_logger = task.get_logger()
+ #
+ # dataset_root = Dataset.get(dataset_project=cfg.CML_PROJECT,
+ # dataset_name=cfg.CML_DATASET,
+ # ).get_local_copy()
+
+ # data = DataModule(cfg, dataset_root=dataset_root)
+ data = DataModule(cfg)
+
+ input_model = Model(model_id='').get_local_copy() if cfg.PRETRAINED.CML_MODEL else None
+ # input_model = cfg.PRETRAINED.PATH
+ model = WorldModelTrainer(cfg.convert_to_dict(), pretrained_path=input_model)
+ # model = WorldModelTrainer.load_from_checkpoint(checkpoint_path=input_model)
+ # model.get_cml_logger(cml_logger)
+
+ save_dir = os.path.join(
+ cfg.LOG_DIR, time.strftime('%d%B%Yat%H:%M:%S%Z') + '_' + socket.gethostname() + '_' + cfg.TAG
+ )
+ logger = pl.loggers.TensorBoardLogger(save_dir=save_dir)
+
+ callbacks = [
+ ModelSummary(),
+ SaveGitDiffHashCallback(),
+ LearningRateMonitor(),
+ MyModelCheckpoint(
+ save_dir, every_n_train_steps=cfg.VAL_CHECK_INTERVAL,
+ ),
+ ]
+
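+ # Lightning interprets an int as a number of batches and a float as a fraction, so 0/1 are cast to float to mean none/all.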
+ if cfg.LIMIT_VAL_BATCHES in [0, 1]:
+ limit_val_batches = float(cfg.LIMIT_VAL_BATCHES)
+ else:
+ limit_val_batches = cfg.LIMIT_VAL_BATCHES
+
+ replace_sampler_ddp = not cfg.SAMPLER.ENABLED
+
+ trainer = pl.Trainer(
+ # devices=cfg.GPUS,
+ accelerator='auto',
+ # strategy='ddp',
+ precision=cfg.PRECISION,
+ # sync_batchnorm=True,
+ max_epochs=None,
+ max_steps=cfg.STEPS,
+ callbacks=callbacks,
+ logger=logger,
+ log_every_n_steps=cfg.LOGGING_INTERVAL,
+ val_check_interval=cfg.VAL_CHECK_INTERVAL * cfg.OPTIMIZER.ACCUMULATE_GRAD_BATCHES,
+ check_val_every_n_epoch=None,
+ # limit_val_batches=limit_val_batches,
+ limit_val_batches=3,
+ # use_distributed_sampler=replace_sampler_ddp,
+ accumulate_grad_batches=cfg.OPTIMIZER.ACCUMULATE_GRAD_BATCHES,
+ num_sanity_val_steps=2,
+ profiler='simple',
+ )
+
+ trainer.fit(model, datamodule=data)
+ trainer.test(model, dataloaders=data.test_dataloader())
+
+
+if __name__ == '__main__':
+ main()
diff --git a/utils/saving_utils.py b/utils/saving_utils.py
new file mode 100644
index 0000000..9b0053d
--- /dev/null
+++ b/utils/saving_utils.py
@@ -0,0 +1,343 @@
+"""Adapted from https://github.com/zhejz/carla-roach CC-BY-NC 4.0 license."""
+
+import os
+import numpy as np
+import pandas as pd
+import logging
+import cv2
+from PIL import Image
+from tqdm import tqdm
+import tempfile
+import shutil
+
+from data.dataset_utils import preprocess_birdview_and_routemap, binary_to_integer
+from constants import CARLA_FPS
+
+log = logging.getLogger(__name__)
+
+
+def report_dataset_size(dataset_dir):
+ list_runs = list(dataset_dir.glob('*'))
+
+ n_frames = 0
+ for run in list_runs:
+ n_frames += len(os.listdir(os.path.join(run, 'image')))
+
+ log.warning(f'{dataset_dir}: {len(list_runs)} episodes, '
+ f'{n_frames} saved frames={n_frames / (CARLA_FPS * 3600):.2f} hours')
+
+
+class DataWriter:
+ def __init__(self, dir_path, ev_id, im_stack_idx=[-1], run_info=None, save_birdview_label=False,
+ render_image=False):
+ self._dir_path = dir_path
+ self._ev_id = ev_id
+ self._im_stack_idx = np.array(im_stack_idx)
+ self.run_info = run_info
+ self.weather_keys = [
+ 'cloudiness', 'fog_density', 'fog_distance', 'fog_falloff', 'precipitation', 'precipitation_deposits',
+ 'sun_altitude_angle', 'sun_azimuth_angle', 'wetness', 'wind_intensity',
+ ]
+
+ assert self._im_stack_idx[0] == -1, 'Not handled'
+ self.save_birdview_label = save_birdview_label
+ self.render_image = render_image
+
+ os.makedirs(self._dir_path, exist_ok=True)
+ self._tmp_dir = tempfile.mkdtemp(dir=self._dir_path)
+ print(f'tempdir: {self._tmp_dir}')
+
+ self._data_list = []
+
+ def write(self, timestamp, obs, supervision, reward, control_diff=None, weather=None):
+ assert self._ev_id in obs and self._ev_id in supervision
+ obs = obs[self._ev_id]
+ render_rgb = None
+
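+ # One record per frame; sensor entries stay None when the corresponding observation is not available.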
+ data_dict = {
+ 'step': timestamp['step'],
+ 'obs': {
+ 'central_rgb': None,
+ 'left_rgb': None,
+ 'right_rgb': None,
+ 'depth_semantic': None,
+ # 'all_rgb': None,
+ 'gnss': None,
+ 'speed': None,
+ 'route_plan': None,
+ 'birdview': None,
+ 'point_cloud': None,
+ 'point_cloud_multi': None,
+ 'point_cloud_semantic': None,
+ },
+ 'supervision': None,
+ 'control_diff': None,
+ 'weather': None,
+ 'reward': None,
+ 'critical': True,
+ }
+
+ # central_rgb
+ data_dict['obs']['central_rgb'] = obs['central_rgb']
+ # gnss speed
+ data_dict['obs']['gnss'] = obs['gnss']
+ data_dict['obs']['speed'] = obs['speed']
+
+ # Route plan and birdview
+ data_dict['obs']['route_plan'] = obs['route_plan']
+
+ if self.save_birdview_label:
+ data_dict['obs']['birdview'] = obs['birdview_label']
+ else:
+ data_dict['obs']['birdview'] = obs['birdview']
+
+ # left_rgb & right_rgb
+ if 'left_rgb' in obs and 'right_rgb' in obs:
+ data_dict['obs']['left_rgb'] = obs['left_rgb']
+ data_dict['obs']['right_rgb'] = obs['right_rgb']
+
+ if self.render_image:
+ render_rgb = np.concatenate([obs['central_rgb']['data'],
+ obs['left_rgb']['data'],
+ obs['right_rgb']['data']], axis=0)
+ elif self.render_image:
+ render_rgb = obs['central_rgb']['data']
+
+ # depth_semantic
+ if 'depth_semantic' in obs:
+ data_dict['obs']['depth_semantic'] = obs['depth_semantic']
+
+ # point cloud
+ if 'lidar_points' in obs:
+ data_dict['obs']['point_cloud'] = obs['lidar_points']
+
+ if 'lidar_points_semantic' in obs:
+ data_dict['obs']['point_cloud_semantic'] = obs['lidar_points_semantic']
+
+ if 'lidar_points_multi' in obs:
+ data_dict['obs']['point_cloud_multi'] = obs['lidar_points_multi']
+
+ # supervision
+ data_dict['supervision'] = supervision[self._ev_id]
+ # Add reward in supervision
+ data_dict['supervision']['reward'] = reward[self._ev_id]
+
+ # reward
+ data_dict['reward'] = reward[self._ev_id]
+
+ # control_diff
+ if control_diff is not None:
+ data_dict['control_diff'] = control_diff[self._ev_id]
+
+ if weather is not None:
+ data_dict['weather'] = self.convert_weather_to_dict(weather)
+
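+ # Spill each frame to a temporary .npy file to keep memory bounded; files are read back and deleted in save_files().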
+ tmp = tempfile.NamedTemporaryFile(dir=self._tmp_dir, delete=False)
+ np.save(tmp, data_dict)
+ tmp.close()
+ self._data_list.append(tmp.name)
+ # self._data_list.append(data_dict)
+
+ if self.render_image:
+ # put text
+ action_str = np.array2string(supervision[self._ev_id]['action'],
+ precision=2, separator=',', suppress_small=True)
+ speed = supervision[self._ev_id]['speed']
+ txt_1 = f'{action_str} spd:{speed[0]:5.2f}'
+ render_rgb = cv2.putText(render_rgb, txt_1, (0, 12), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (255, 255, 255), 1)
+ return render_rgb
+
+ def convert_weather_to_dict(self, weather):
+ weather_dict = {}
+
+ for key in self.weather_keys:
+ weather_dict[key] = getattr(weather, key)
+
+ return weather_dict
+
+ @staticmethod
+ def _write_dict_to_group(group, key, my_dict):
+ group_key = group.create_group(key)
+ for k, v in my_dict.items():
+ if isinstance(v, np.ndarray) and v.size > 2000:
+ group_key.create_dataset(k, data=v, compression="gzip", compression_opts=4)
+ else:
+ group_key.create_dataset(k, data=v)
+
+ def close(self, terminal_debug, remove_final_steps, last_value=None):
+ # clean up data
+ log.info(f'Episode finished, len={len(self._data_list)}')
+
+ # behaviour cloning dataset
+ valid = True
+ if remove_final_steps:
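+ # Discard the steps leading up to an infraction and mark episodes that become too short as invalid.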
+ if terminal_debug['traffic_rule_violated']:
+ step_to_delete = min(300, len(self._data_list))
+ del self._data_list[-step_to_delete:]
+ if len(self._data_list) < 300:
+ valid = False
+ log.warning(f'traffic_rule_violated, valid={valid}, len={len(self._data_list)}')
+
+ if terminal_debug['blocked']:
+ step_to_delete = min(600, len(self._data_list))
+ del self._data_list[-step_to_delete:]
+ if len(self._data_list) < 300:
+ valid = False
+ log.warning(f'blocked, valid={valid}, len={len(self._data_list)}')
+
+ if terminal_debug['route_deviation']:
+ valid = False
+ log.warning(f'route deviation, valid={valid}')
+
+ if valid:
+ self.save_files()
+
+ self._data_list.clear()
+ shutil.rmtree(self._tmp_dir)
+ return valid
+
+ def save_files(self):
+ os.makedirs(os.path.join(self._dir_path, 'image'), exist_ok=True)
+ os.makedirs(os.path.join(self._dir_path, 'image_left'), exist_ok=True)
+ os.makedirs(os.path.join(self._dir_path, 'image_right'), exist_ok=True)
+ os.makedirs(os.path.join(self._dir_path, 'image_all'), exist_ok=True)
+ os.makedirs(os.path.join(self._dir_path, 'depth_semantic'), exist_ok=True)
+ os.makedirs(os.path.join(self._dir_path, 'birdview'), exist_ok=True)
+ os.makedirs(os.path.join(self._dir_path, 'routemap'), exist_ok=True)
+ # os.makedirs(os.path.join(self._dir_path, 'points'), exist_ok=True)
+ os.makedirs(os.path.join(self._dir_path, 'points_semantic'), exist_ok=True)
+
+ dict_dataframe = {
+ 'action_mu': [],
+ 'action_sigma': [],
+ 'action': [],
+ 'speed': [],
+ 'reward': [],
+ 'value': [],
+ 'features': [],
+ 'gnss': [],
+ 'target_gps': [],
+ 'imu': [],
+ 'command': [],
+ 'target_gps_next': [],
+ 'command_next': [],
+ 'image_path': [],
+ 'depth_semantic_path': [],
+ # 'depth_semantic_trans': [],
+ 'birdview_path': [],
+ 'routemap_path': [],
+ # 'points_path': [],
+ # 'point_cloud_multi_path': [],
+ 'points_semantic_path': [],
+ 'n_classes': [], # Number of classes in the bev
+ }
+
+ for k in self.run_info.keys():
+ dict_dataframe[k] = []
+ for k in self.weather_keys:
+ dict_dataframe[k] = []
+
+ points_list = {}
+ # points_list_multi = {}
+ points_list_semantic = {}
+
+ log.info(f'Saving {self._dir_path}, data_len={len(self._data_list)}')
+
+ for i, data_name in enumerate(tqdm(self._data_list, desc='Saving data')):
+ data = np.load(data_name, allow_pickle=True).item()
+ os.remove(data_name)
+
+ obs = data['obs']
+ supervision = data['supervision']
+
+ for k, v in supervision.items():
+ dict_dataframe[k].append(v)
+
+ if 'action_mu' not in supervision.keys():
+ # Using autopilot, fill with dummy values
+ for k in ['action_mu', 'action_sigma', 'value', 'features']:
+ dict_dataframe[k].append(np.zeros(1))
+
+ for k, v in obs['gnss'].items():
+ dict_dataframe[k].append(v)
+
+ # Add weather information
+ for k, v in data['weather'].items():
+ dict_dataframe[k].append(v)
+
+ # Add run information
+ for k, v in self.run_info.items():
+ dict_dataframe[k].append(v)
+
+ image = obs['central_rgb']['data']
+
+ if obs['left_rgb'] is not None and obs['right_rgb'] is not None:
+ image_left = obs['left_rgb']['data']
+ image_right = obs['right_rgb']['data']
+ image_all = np.concatenate([obs['left_rgb']['data'],
+ obs['central_rgb']['data'],
+ obs['right_rgb']['data']], axis=1)
+ else:
+ image_all, image_left, image_right = None, None, None
+
+ if obs['depth_semantic'] is not None:
+ depth_semantic = obs['depth_semantic']['data']
+ else:
+ depth_semantic = None
+
+ if obs['point_cloud'] is not None:
+ points = obs['point_cloud']['data']
+ else:
+ points = None
+
+ if obs['point_cloud_semantic'] is not None:
+ points_semantic = obs['point_cloud_semantic']['data']
+ else:
+ points_semantic = None
+
+ # Process birdview and save as png
+ birdview, route_map = preprocess_birdview_and_routemap(obs['birdview']['masks'])
+ birdview, route_map = birdview.numpy(), route_map.numpy()
+ n_bits, h, w = birdview.shape
+ birdview = birdview.reshape(n_bits, -1)
+ birdview = birdview.transpose((1, 0))
+ # Convert bits to integer for storage
+ birdview = binary_to_integer(birdview, n_bits).reshape(h, w)
+
+ image_path = os.path.join(f'image', f'image_{i:09d}.png')
+ birdview_path = os.path.join(f'birdview', f'birdview_{i:09d}.png')
+ routemap_path = os.path.join(f'routemap', f'routemap_{i:09d}.png')
+ dict_dataframe['image_path'].append(image_path)
+ dict_dataframe['birdview_path'].append(birdview_path)
+ dict_dataframe['routemap_path'].append(routemap_path)
+ dict_dataframe['n_classes'].append(n_bits)
+ # Save RGB images
+ Image.fromarray(image).save(os.path.join(self._dir_path, image_path))
+ Image.fromarray(birdview, mode='I').save(os.path.join(self._dir_path, birdview_path))
+ Image.fromarray(route_map, mode='L').save(os.path.join(self._dir_path, routemap_path))
+ if image_all is not None:
+ image_left_path = os.path.join(f'image_left', f'image_left_{i:09d}.png')
+ image_right_path = os.path.join(f'image_right', f'image_right_{i:09d}.png')
+ image_all_path = os.path.join(f'image_all', f'image_all_{i:09d}.png')
+ Image.fromarray(image_left).save(os.path.join(self._dir_path, image_left_path))
+ Image.fromarray(image_right).save(os.path.join(self._dir_path, image_right_path))
+ Image.fromarray(image_all).save(os.path.join(self._dir_path, image_all_path))
+ if depth_semantic is not None:
+ depth_semantic_path = os.path.join(f'depth_semantic', f'depth_semantic_{i:09d}.png')
+ Image.fromarray(depth_semantic).save(os.path.join(self._dir_path, depth_semantic_path))
+ # dict_dataframe['depth_semantic_trans'].append(obs['depth_semantic']['trans'])
+ dict_dataframe['depth_semantic_path'].append(depth_semantic_path)
+
+ # store point cloud
+ if points is not None:
+ points_path = os.path.join(f'points', f'points_{i:09d}.npy')
+ np.save(os.path.join(self._dir_path, points_path), points)
+ dict_dataframe['points_path'].append(points_path)
+ if points_semantic is not None:
+ points_semantic_path = os.path.join(f'points_semantic', f'points_semantic_{i:09d}.npy')
+ np.save(os.path.join(self._dir_path, points_semantic_path), points_semantic)
+ dict_dataframe['points_semantic_path'].append(points_semantic_path)
+
+ pd_dataframe = pd.DataFrame(dict_dataframe)
+ pd_dataframe.to_pickle(os.path.join(self._dir_path, 'pd_dataframe.pkl'))
+
diff --git a/utils/server_utils.py b/utils/server_utils.py
new file mode 100644
index 0000000..b3b6b3e
--- /dev/null
+++ b/utils/server_utils.py
@@ -0,0 +1,65 @@
+"""Adapted from https://github.com/zhejz/carla-roach CC-BY-NC 4.0 license."""
+
+import subprocess
+import os
+import time
+from omegaconf import OmegaConf
+import logging
+log = logging.getLogger(__name__)
+
+from constants import CARLA_FPS
+
+
+def kill_carla(port=2005):
+ # The command below kills ALL carla processes
+ #kill_process = subprocess.Popen('killall -9 -r CarlaUE4-Linux', shell=True)
+
+ # This one only kills processes linked to a certain port
+ kill_process = subprocess.Popen(f'fuser -k {port}/tcp', shell=True)
+ log.info(f"Killed Carla Servers on port {port}!")
+ kill_process.wait()
+ time.sleep(1)
+
+
+class CarlaServerManager:
+ def __init__(self, carla_sh_str, port=2000, configs=None, render_off_screen=False, t_sleep=5):
+ self._carla_sh_str = carla_sh_str
+ self.port = port
+ self._render_off_screen = render_off_screen
+ # self._root_save_dir = root_save_dir
+ self._t_sleep = t_sleep
+ self.env_configs = []
+
+ if configs is None:
+ cfg = {
+ 'gpu': os.environ.get('CUDA_VISIBLE_DEVICES'),
+ 'port': port,
+ }
+ self.env_configs.append(cfg)
+ else:
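+ # Build one server config per listed GPU; RPC ports are spaced 5 apart.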
+ for cfg in configs:
+ for gpu in cfg['gpu']:
+ single_env_cfg = OmegaConf.to_container(cfg)
+ single_env_cfg['gpu'] = gpu
+ single_env_cfg['port'] = port
+ self.env_configs.append(single_env_cfg)
+ port += 5
+
+ def start(self):
+ kill_carla(self.port)
+ for cfg in self.env_configs:
+ if self._render_off_screen:
+ cmd = f'{self._carla_sh_str} ' \
+ f'-fps={CARLA_FPS} -quality-level=Epic -carla-rpc-port={cfg["port"]} -RenderOffScreen'
+ else:
+ cmd = f'CUDA_VISIBLE_DEVICES={cfg["gpu"]} bash {self._carla_sh_str} ' \
+ f'-fps={CARLA_FPS} -quality-level=Epic -carla-rpc-port={cfg["port"]}'
+ # cmd = f'{self._carla_sh_str} ' \
+ # f'-fps={CARLA_FPS} -quality-level=Epic -carla-rpc-port={cfg["port"]}'
+ log.info(cmd)
+ server_process = subprocess.Popen(cmd, shell=True, preexec_fn=os.setsid)
+ time.sleep(self._t_sleep)
+
+ def stop(self):
+ kill_carla(self.port)
+ time.sleep(self._t_sleep)
diff --git a/vis/.gitkeep b/vis/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/vis/graphs.py b/vis/graphs.py
new file mode 100644
index 0000000..dd37f7f
--- /dev/null
+++ b/vis/graphs.py
@@ -0,0 +1,91 @@
+import json
+import matplotlib
+matplotlib.use("pgf")
+import matplotlib.pyplot as plt
+
+matplotlib.rcParams.update({
+ "pgf.texsystem": "pdflatex",
+ 'font.family': 'serif',
+ 'text.usetex': True,
+ 'pgf.rcfonts': False,
+ 'font.size': 6
+})
+
+def filter_data(task, threshold=50002):
+ filtered_x = [x for x in task["x"] if x <= threshold]
+ filtered_y = [task["y"][i] for i, x in enumerate(task["x"]) if x <= threshold]
+ return {"name": task["name"], "x": filtered_x, "y": filtered_y, "type": task["type"], "task": task["task"]}
+
+def load_data(file_path):
+ with open(file_path, 'r') as file:
+ return json.load(file)
+
+def plot_data(ax, task_data, linestyle='-o'):
+ for task_name, task_info in task_data.items():
+ ax.plot(task_info["x"], task_info["y"], linestyle, label=task_name, color=task_info["color"], linewidth=0.5, markersize=2)
+
+# Load data from the files
+file_paths = [
+ 'val_imagine2_chamfer_distance _ val_imagine2_chamfer_distance.json',
+ 'val_imagine1_chamfer_distance_val_imagine1_chamfer_distance.json',
+ 'val_imagine2_psnr _ val_imagine2_psnr.json',
+ 'val_imagine1_psnr _ val_imagine1_psnr.json'
+]
+
+data_1 = [filter_data(task) for task in load_data(file_paths[0])]
+data_2 = [filter_data(task) for task in load_data(file_paths[1])]
+data_3 = [filter_data(task) for task in load_data(file_paths[2])]
+data_4 = [filter_data(task) for task in load_data(file_paths[3])]
+
+# Create a dictionary to store data for each task
+task_data_1 = {}
+task_data_2 = {}
+task_data_3 = {}
+task_data_4 = {}
+
+# Define a list of colors to use for each task
+colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f']
+task_names = ["PP-BEV-AVG", "PP-BEV-FC", "RV-BEV-FC", "PP-RN-TR", "PP-BEV-TR", "RV-BEV-AVG", "RV-RN-TR", "RV-BEV-TR"]
+
+# Process filtered data
+for i, (filtered_data, task_data) in enumerate(zip([data_1, data_2, data_3, data_4], [task_data_1, task_data_2, task_data_3, task_data_4])):
+ for j, item in enumerate(filtered_data):
+ #task_name = item["task"]
+ task_name = task_names[j]
+ color = colors[j % len(colors)]
+ if task_name not in task_data:
+ task_data[task_name] = {"x": [], "y": [], "name": item["name"], "color": color}
+
+ task_data[task_name]["x"].extend(item["x"])
+ task_data[task_name]["y"].extend(item["y"])
+
+# Plot the data for each task in subplots
+subplot_width = 1.71875 # 6.875 inches divided by 4 subplots
+
+fig, axs = plt.subplots(1, 4, figsize=(6.875, subplot_width), gridspec_kw={'width_ratios': [subplot_width] * 4})
+#fig.set_size_inches(w=6.875, h=1.5)
+
+plot_data(axs[0], task_data_1)
+plot_data(axs[1], task_data_2)
+plot_data(axs[2], task_data_3)
+plot_data(axs[3], task_data_4)
+
+# Set titles and labels
+axs[0].set_title(f"$\mathcal{{D}}_{{val}}^{{RL}}$")
+axs[1].set_title(f"$\mathcal{{D}}_{{val}}^{{DS}}$")
+axs[2].set_title(f"$\mathcal{{D}}_{{val}}^{{RL}}$")
+axs[3].set_title(f"$\mathcal{{D}}_{{val}}^{{DS}}$")
+axs[0].set_ylabel(f"Chamfer Distance (Lidar) $\\downarrow$")
+axs[2].set_ylabel(f"PSNR (Camera) $\\uparrow$")
+
+for ax in axs:
+ ax.get_xaxis().set_major_formatter(
+ matplotlib.ticker.FuncFormatter(lambda x, p: format(int(x), ',')))
+
+# Only show legend in the top-left subplot
+axs[0].legend(fontsize=4, fancybox=True)
+
+# Save the plot
+fig.tight_layout()
+plt.savefig('sensor_fusion.pgf')
+