Skip to content

Commit

Permalink
v1.2.4: use numba to speed up numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
Sweetnow committed Jan 10, 2025
1 parent f74112e commit 6f271c7
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 16 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "python-moss"
version = "1.2.3"
version = "1.2.4"
description = "MObility Simulation System"
authors = [
{ name = "Jun Zhang", email = "[email protected]" },
Expand All @@ -18,6 +18,7 @@ dependencies = [
"protobuf>=3.20,<5",
"numpy>=1.20,<2",
"pycityproto>=2,<3",
"numba>=0.60",
"tqdm",
"psycopg[binary,pool]>=3.0,<4",
"pymongo",
Expand Down
86 changes: 71 additions & 15 deletions python/src/moss/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from warnings import warn

import numpy as np
from numba import njit
from numpy.typing import NDArray
from pycityproto.city.map.v2.map_pb2 import Map
from pycityproto.city.person.v2.person_pb2 import Persons
Expand All @@ -31,6 +32,40 @@ def _import():
_thread_local = threading.local()


@njit
def _populate_waiting_at_lane_end(
enable: np.ndarray,
status: np.ndarray,
lane_id: np.ndarray,
lane_length_array: np.ndarray,
v: np.ndarray,
s: np.ndarray,
speed_threshold: float,
distance_to_end: float,
):
filter = (enable == 1) & (status == DRIVING) & (v < speed_threshold)
filtered_lane_id = lane_id[filter]
filtered_s = s[filter]
# find the distance to the end of the lane
lane_ids_for_count = []
for i, s in zip(filtered_lane_id, filtered_s):
if lane_length_array[i] - s < distance_to_end:
lane_ids_for_count.append(i)
return lane_ids_for_count


@njit
def _populate_waiting_at_lane(
enable: np.ndarray,
status: np.ndarray,
lane_id: np.ndarray,
v: np.ndarray,
speed_threshold: float,
):
filter = (enable == 1) & (status == DRIVING) & (v < speed_threshold)
return lane_id[filter]


class TlPolicy(Enum):
MANUAL = 0
FIXED_TIME = 1
Expand Down Expand Up @@ -115,13 +150,15 @@ def __init__(
The interval of speed statistics. Set to `0` to disable speed statistics.
"""
# check parameters

if step_interval <= 0:
raise ValueError("step_interval should be greater than 0")
if step_interval > 1:
warn("step_interval is greater than 1, the simulation may not be accurate")
if person_limit < -1:
raise ValueError("person_limit should be greater than -1, -1 means no limit")
raise ValueError(
"person_limit should be greater than -1, -1 means no limit"
)
if junction_yellow_time < 0:
raise ValueError("junction_yellow_time should be greater than 0")
if phase_pressure_coeff <= 0:
Expand Down Expand Up @@ -200,6 +237,12 @@ def __init__(
"""
Dictionary of road index indexed by road id
"""
self.lane_length_array = np.array(
[l.length for l in self._map.lanes], dtype=np.float32
)
"""
Numpy array of lane length indexed by lane index
"""

self._persons = Persons()
with open(person_file, "rb") as f:
Expand Down Expand Up @@ -356,7 +399,11 @@ def fetch_persons(self, fields: List[str] = []) -> Dict[str, NDArray]:
"traveling_time",
"total_distance",
]
has_fields = set() if self._fetched_persons is None else set(self._fetched_persons.keys())
has_fields = (
set()
if self._fetched_persons is None
else set(self._fetched_persons.keys())
)
delta_fields = set(fields) - has_fields
if len(delta_fields) > 0:
(
Expand Down Expand Up @@ -437,7 +484,7 @@ def get_running_person_count(self) -> int:
Get the total number of running persons (including driving and walking)
"""
persons = self.fetch_persons(["enable", "status"])
enable = persons["enable"] # type: NDArray[np.uint8]
enable = persons["enable"] # type: NDArray[np.uint8]
status: NDArray[np.uint8] = persons["status"]
return ((enable == 1) & ((status == DRIVING) | (status == WALKING))).sum()

Expand All @@ -464,8 +511,13 @@ def get_lane_waiting_vehicle_counts(
lane_id = persons["lane_id"]
status = persons["status"]
v = persons["v"]
filter = (enable == 1) & (status == DRIVING) & (v < speed_threshold)
filtered_lane_id = lane_id[filter]
filtered_lane_id = _populate_waiting_at_lane(
enable=enable,
status=status,
lane_id=lane_id,
v=v,
speed_threshold=speed_threshold,
)
# count for the lane id
unique, counts = np.unique(filtered_lane_id, return_counts=True)
return unique, counts
Expand All @@ -486,14 +538,16 @@ def get_lane_waiting_at_end_vehicle_counts(
status = persons["status"]
v = persons["v"]
s = persons["s"]
filter = (enable == 1) & (status == DRIVING) & (v < speed_threshold)
filtered_lane_id = lane_id[filter]
filtered_s = s[filter]
# find the distance to the end of the lane
lane_ids_for_count = []
for i, s in zip(filtered_lane_id, filtered_s):
if self.id2lanes[i].length - s < distance_to_end:
lane_ids_for_count.append(i)
lane_ids_for_count = _populate_waiting_at_lane_end(
enable=enable,
status=status,
lane_id=lane_id,
lane_length_array=self.lane_length_array,
v=v,
s=s,
speed_threshold=speed_threshold,
distance_to_end=distance_to_end,
)
# count for the lane id
unique, counts = np.unique(lane_ids_for_count, return_counts=True)
return unique, counts
Expand Down Expand Up @@ -639,7 +693,9 @@ def get_road_vehicle_counts(self) -> Tuple[NDArray[np.int32], NDArray[np.int32]]
enable = persons["enable"]
road_id = persons["lane_parent_id"]
status = persons["status"]
filter = (enable == 1) & (status == DRIVING)
filter = (
(enable == 1) & (status == DRIVING) & (road_id < 3_0000_0000)
) # the road id ranges [2_0000_0000, 3_0000_0000)
filtered_road_id = road_id[filter]
# count for the road id
unique, counts = np.unique(filtered_road_id, return_counts=True)
Expand Down

0 comments on commit 6f271c7

Please sign in to comment.