Skip to content

Commit

Permalink
Speedup output [9.3+75]->[2+14]
Browse files Browse the repository at this point in the history
  • Loading branch information
aowenxuan committed Jun 20, 2024
1 parent 7f04921 commit 6c22ef8
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 86 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "python-moss"
version = "0.3.4"
version = "0.3.5"
description = "MObility Simulation System"
authors = [
{ name = "Wenxuan Ao", email = "[email protected]" },
Expand All @@ -26,7 +26,7 @@ dependencies = [
"pycityproto==1.13.1",
"igraph",
"tqdm",
"psycopg",
"psycopg2-binary>=2.5,<3",
"pyproj",
]

Expand Down
181 changes: 108 additions & 73 deletions python/src/moss/export.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import io
from typing import Iterator, Optional

import numpy as np
import psycopg
import psycopg2
import pyproj
from tqdm import tqdm

Expand Down Expand Up @@ -35,6 +38,42 @@ def save(self, filepath):
np.savez_compressed(filepath, data=self.data)


class StringIteratorIO(io.TextIOBase):
def __init__(self, iter: Iterator[str]):
self._iter = iter
self._buff = ''

def readable(self) -> bool:
return True

def _read1(self, n: Optional[int] = None) -> str:
while not self._buff:
try:
self._buff = next(self._iter)
except StopIteration:
break
ret = self._buff[:n]
self._buff = self._buff[len(ret):]
return ret

def read(self, n: Optional[int] = None) -> str:
line = []
if n is None or n < 0:
while True:
m = self._read1()
if not m:
break
line.append(m)
else:
while n > 0:
m = self._read1(n)
if not m:
break
n -= len(m)
line.append(m)
return ''.join(line)


class DBRecorder:
"""
DBRecorder is for web visualization and writes to Postgres Database
Expand Down Expand Up @@ -62,68 +101,61 @@ def save(self, db_url: str, mongo_map: str, output_name: str, batch_size=1000, u
xs = []
ys = []
proj = pyproj.Proj(self.eng._e.get_map_projection())
for step, vs, ts in self.data:
for step, (vs, vx, vy), (ts, tx, ty) in self.data:
if vs:
p, l, x, y, d, v = zip(*vs)
x, y = proj(x, y, True)
xs += x
ys += y
for p, l, x, y, d, v in zip(p, l, x, y, d, v):
vehs.append(f"({step},{p},{l},{round(d,3)},{x},{y},'',0,0,{round(v,3)})")
x, y = proj(vx, vy, True)
xs.extend(x)
ys.extend(y)
vehs.append([step, vs, x, y])
if ts:
p, s, x, y = zip(*ts)
x, y = proj(x, y, True)
xs += x
ys += y
for p, s, x, y in zip(p, s, x, y):
tls.append(f"({step},{p},{s},{x},{y})")
x, y = proj(tx, ty, True)
xs.extend(x)
ys.extend(y)
tls.append([step, ts, x, y])
if xs:
min_lon, max_lon, min_lat, max_lat = min(xs), max(xs), min(ys), max(ys)
else:
x1, y1, x2, y2 = self.eng._e.get_map_bbox()
min_lon, min_lat = proj(x1, y1, True)
max_lon, max_lat = proj(x2, y2, True)
with psycopg.connect(db_url) as conn:
with psycopg2.connect(db_url) as conn:
with conn.cursor() as cur:
cur.execute("""
CREATE TABLE IF NOT EXISTS public.meta_simple (
"name" text NOT NULL,
"start" int4 NOT NULL,
steps int4 NOT NULL,
"time" float8 NOT NULL,
total_agents int4 NOT NULL,
"map" text NOT NULL,
min_lng float8 NOT NULL,
min_lat float8 NOT NULL,
max_lng float8 NOT NULL,
max_lat float8 NOT NULL,
road_status_v_min float8 NULL,
road_status_interval int4 NULL,
CONSTRAINT meta_simple_pkey PRIMARY KEY (name)
)""")
cur.execute(f"DELETE FROM public.meta_simple WHERE name='{output_name}'")
cur.execute(
f"INSERT INTO public.meta_simple VALUES ('{output_name}', {self.eng.start_step}, {len(self.data)}, 1, 1, '{mongo_map}', {min_lon}, {min_lat}, {max_lon}, {max_lat}, 0, 300)")
cur.execute(f"DROP TABLE IF EXISTS {output_name}_s_cars")
cur.execute(f"DROP TABLE IF EXISTS {output_name}_s_people")
cur.execute(f"DROP TABLE IF EXISTS {output_name}_s_traffic_light")
cur.execute(f"DROP TABLE IF EXISTS {output_name}_s_road")
cur.execute(f"""
CREATE TABLE public.{output_name}_s_cars (
step int4 NOT NULL,
id int4 NOT NULL,
parent_id int4 NOT NULL,
direction float8 NOT NULL,
lng float8 NOT NULL,
lat float8 NOT NULL,
model text NOT NULL,
z float8 NOT NULL,
pitch float8 NOT NULL,
v float8 NOT NULL
)""")
cur.execute(f"CREATE INDEX {output_name}_s_cars_step_lng_lat_idx ON public.{output_name}_s_cars USING btree (step, lng, lat)")
cur.execute(f"""
CREATE TABLE public.{output_name}_s_people (
CREATE TABLE IF NOT EXISTS public.meta_simple (
"name" text NOT NULL,
"start" int4 NOT NULL,
steps int4 NOT NULL,
"time" float8 NOT NULL,
total_agents int4 NOT NULL,
"map" text NOT NULL,
min_lng float8 NOT NULL,
min_lat float8 NOT NULL,
max_lng float8 NOT NULL,
max_lat float8 NOT NULL,
road_status_v_min float8 NULL,
road_status_interval int4 NULL,
CONSTRAINT meta_simple_pkey PRIMARY KEY (name)
);
DELETE FROM public.meta_simple WHERE name='{output_name}';
INSERT INTO public.meta_simple VALUES ('{output_name}', {self.eng.start_step}, {len(self.data)}, 1, 1, '{mongo_map}', {min_lon}, {min_lat}, {max_lon}, {max_lat}, 0, 300);
DROP TABLE IF EXISTS {output_name}_s_cars;
DROP TABLE IF EXISTS {output_name}_s_people;
DROP TABLE IF EXISTS {output_name}_s_traffic_light;
DROP TABLE IF EXISTS {output_name}_s_road;
CREATE TABLE {output_name}_s_cars (
step int4 NOT NULL,
id int4 NOT NULL,
parent_id int4 NOT NULL,
direction float8 NOT NULL,
lng float8 NOT NULL,
lat float8 NOT NULL,
model text NOT NULL,
z float8 NOT NULL,
pitch float8 NOT NULL,
v float8 NOT NULL
);
CREATE INDEX {output_name}_s_cars_step_lng_lat_idx ON {output_name}_s_cars USING btree (step, lng, lat);
CREATE TABLE {output_name}_s_people (
step int4 NOT NULL,
id int4 NOT NULL,
parent_id int4 NOT NULL,
Expand All @@ -133,35 +165,38 @@ def save(self, db_url: str, mongo_map: str, output_name: str, batch_size=1000, u
z float8 NOT NULL,
v float8 NOT NULL,
model text NOT NULL
)""")
cur.execute(f"CREATE INDEX {output_name}_s_people_step_lng_lat_idx ON public.{output_name}_s_people USING btree (step, lng, lat)")
cur.execute(f"""
CREATE TABLE public.{output_name}_s_traffic_light (
);
CREATE INDEX {output_name}_s_people_step_lng_lat_idx ON {output_name}_s_people USING btree (step, lng, lat);
CREATE TABLE {output_name}_s_traffic_light (
step int4 NOT NULL,
id int4 NOT NULL,
state int4 NOT NULL,
lng float8 NOT NULL,
lat float8 NOT NULL
)""")
cur.execute(f"CREATE INDEX {output_name}_s_traffic_light_step_lng_lat_idx ON public.{output_name}_s_traffic_light USING btree (step, lng, lat)")
cur.execute(f"""
CREATE TABLE public.{output_name}_s_road (
);
CREATE INDEX {output_name}_s_traffic_light_step_lng_lat_idx ON {output_name}_s_traffic_light USING btree (step, lng, lat);
CREATE TABLE {output_name}_s_road (
step int4 NOT NULL,
id int4 NOT NULL,
"level" int4 NOT NULL,
v float8 NOT NULL,
in_vehicle_cnt int4 NOT NULL,
out_vehicle_cnt int4 NOT NULL,
cnt int4 NOT NULL
)""")
cur.execute(f"CREATE INDEX {output_name}_s_road_step_idx ON public.{output_name}_s_road USING btree (step)")
for i in tqdm(range(0, len(vehs), batch_size), ncols=90, disable=not use_tqdm):
cur.execute(
f"INSERT INTO public.{output_name}_s_cars VALUES "
+ ",".join(vehs[i: i + batch_size])
)
for i in tqdm(range(0, len(tls), batch_size), ncols=90, disable=not use_tqdm):
cur.execute(
f"INSERT INTO public.{output_name}_s_traffic_light VALUES "
+ ",".join(tls[i: i + batch_size])
)
);
CREATE INDEX {output_name}_s_road_step_idx ON {output_name}_s_road USING btree (step);
""")
cur.copy_from(
StringIteratorIO(
f"{step},{p},{l},{round(d,3)},{x},{y},'',0,0,{round(v,3)}\n"
for step, vs, x, y in tqdm(vehs, ncols=90, disable=not use_tqdm) for (p, l, d, v), x, y in zip(vs, x, y)
),
f'{output_name}_s_cars', sep=','
)
cur.copy_from(
StringIteratorIO(
f"{step},{p},{s},{x},{y}\n"
for step, ts, x, y in tqdm(tls, ncols=90, disable=not use_tqdm) for (p, s), x, y in zip(ts, x, y)
),
f'{output_name}_s_traffic_light', sep=','
)
34 changes: 23 additions & 11 deletions src/python_api.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ using no_gil = py::call_guard<py::gil_scoped_release>;
template <typename T>
inline py::array_t<typename T::value_type> asarray(T* ptr) {
py::gil_scoped_acquire _;
auto capsule =
py::capsule(ptr, [](void* p) { delete reinterpret_cast<T*>(p); });
return py::array(ptr->size(), ptr->data(), capsule);
return py::array(ptr->size(), ptr->data(), py::capsule(ptr, [](void* p) {
delete reinterpret_cast<T*>(p);
}));
}
template <class T>
vec<T>& remove_duplicate(vec<T>& arr) {
Expand Down Expand Up @@ -485,25 +485,37 @@ class Engine {
}
// 获取用于输出的车辆信息 (id, parent_id, x, y, dir)
auto get_output_vehicles() {
vec<std::tuple<int, int, float, float, float, float>> out;
vec<std::tuple<int, int, float, float>> out;
auto xs = new vec<float>;
auto ys = new vec<float>;
out.reserve(get_running_vehicle_count());
xs->reserve(get_running_vehicle_count());
ys->reserve(get_running_vehicle_count());
for (auto& p : S.person.persons) {
if (p.runtime.status == PersonStatus::DRIVING) {
out.emplace_back(
p.id, p.runtime.lane ? p.runtime.lane->id : unsigned(-1),
p.runtime.x, p.runtime.y, p.runtime.dir, p.runtime.speed);
out.emplace_back(p.id,
p.runtime.lane ? p.runtime.lane->id : unsigned(-1),
p.runtime.dir, p.runtime.speed);
xs->push_back(p.runtime.x);
ys->push_back(p.runtime.y);
}
}
return out;
return std::make_tuple(out, asarray(xs), asarray(ys));
}
// 获取用于输出的信号灯信息 (id, state, x, y)
auto get_output_tls() {
vec<std::tuple<int, int, float, float>> out;
vec<std::tuple<int, int>> out;
auto xs = new vec<float>;
auto ys = new vec<float>;
out.reserve(S.lane.output_lanes.size);
xs->reserve(S.lane.output_lanes.size);
ys->reserve(S.lane.output_lanes.size);
for (auto* l : S.lane.output_lanes) {
out.emplace_back(l->id, l->light_state, l->center_x, l->center_y);
out.emplace_back(l->id, l->light_state);
xs->push_back(l->center_x);
ys->push_back(l->center_y);
}
return out;
return std::make_tuple(out, asarray(xs), asarray(ys));
}
// 设置人禁用
void set_vehicle_enable(uint vehicle_index, bool enable) {
Expand Down

0 comments on commit 6c22ef8

Please sign in to comment.