Skip to content

Commit

Permalink
fix bug in python
Browse files Browse the repository at this point in the history
  • Loading branch information
Sweetnow committed Jan 6, 2025
1 parent d5cf0a1 commit d681c50
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions python/src/moss/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,10 +375,7 @@ def get_running_person_count(self) -> int:
"""
persons = self.fetch_persons()
status: NDArray[np.uint8] = persons["status"]
return (
(status == PersonStatus.DRIVING.value)
| (status == PersonStatus.WALKING.value)
).sum()
return ((status == DRIVING) | (status == WALKING)).sum()

def get_lane_statuses(self) -> NDArray[np.int8]:
"""
Expand All @@ -402,7 +399,7 @@ def get_lane_waiting_vehicle_counts(
lane_id = persons["lane_id"]
status = persons["status"]
v = persons["v"]
filter = (status == PersonStatus.DRIVING.value) & (v < speed_threshold)
filter = (status == DRIVING) & (v < speed_threshold)
filtered_lane_id = lane_id[filter]
# count for the lane id
unique, counts = np.unique(filtered_lane_id, return_counts=True)
Expand All @@ -424,7 +421,7 @@ def get_lane_waiting_at_end_vehicle_counts(
status = persons["status"]
v = persons["v"]
s = persons["s"]
filter = (status == PersonStatus.DRIVING.value) & (v < speed_threshold)
filter = (status == DRIVING) & (v < speed_threshold)
filtered_lane_id = lane_id[filter]
filtered_s = s[filter]
# find the distance to the end of the lane
Expand Down Expand Up @@ -542,7 +539,7 @@ def get_running_person_average_traveling_time(self) -> float:
persons = self.fetch_persons()
status: NDArray[np.uint8] = persons["status"]
traveling_time = persons["traveling_time"]
return traveling_time[status == PersonStatus.DRIVING.value].mean()
return traveling_time[status == DRIVING].mean()

def get_departed_person_average_traveling_time(self) -> float:
"""
Expand All @@ -569,7 +566,7 @@ def get_road_vehicle_counts(self) -> Dict[int, int]:
persons = self.fetch_persons()
road_id = persons["lane_parent_id"]
status = persons["status"]
filter = status == PersonStatus.DRIVING.value
filter = status == DRIVING
filtered_road_id = road_id[filter]
# count for the road id
unique, counts = np.unique(filtered_road_id, return_counts=True)
Expand All @@ -590,7 +587,7 @@ def get_road_waiting_vehicle_counts(
status = persons["status"]
v = persons["v"]
filter = (
(status == PersonStatus.DRIVING.value)
(status == DRIVING)
& (v < speed_threshold)
& (road_id < 3_0000_0000) # the road id ranges [2_0000_0000, 3_0000_0000)
)
Expand Down

0 comments on commit d681c50

Please sign in to comment.