diff --git a/pyproject.toml b/pyproject.toml index 38bfea1..c6daae0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta" [project] name = "python-moss" -version = "0.3.4" +version = "0.3.5" description = "MObility Simulation System" authors = [ { name = "Wenxuan Ao", email = "aowx21@outlook.com" }, @@ -26,7 +26,7 @@ dependencies = [ "pycityproto==1.13.1", "igraph", "tqdm", - "psycopg", + "psycopg2-binary>=2.5,<3", "pyproj", ] diff --git a/python/src/moss/export.py b/python/src/moss/export.py index 938d51e..74f8726 100644 --- a/python/src/moss/export.py +++ b/python/src/moss/export.py @@ -1,5 +1,8 @@ +import io +from typing import Iterator, Optional + import numpy as np -import psycopg +import psycopg2 import pyproj from tqdm import tqdm @@ -35,6 +38,42 @@ def save(self, filepath): np.savez_compressed(filepath, data=self.data) +class StringIteratorIO(io.TextIOBase): + def __init__(self, iter: Iterator[str]): + self._iter = iter + self._buff = '' + + def readable(self) -> bool: + return True + + def _read1(self, n: Optional[int] = None) -> str: + while not self._buff: + try: + self._buff = next(self._iter) + except StopIteration: + break + ret = self._buff[:n] + self._buff = self._buff[len(ret):] + return ret + + def read(self, n: Optional[int] = None) -> str: + line = [] + if n is None or n < 0: + while True: + m = self._read1() + if not m: + break + line.append(m) + else: + while n > 0: + m = self._read1(n) + if not m: + break + n -= len(m) + line.append(m) + return ''.join(line) + + class DBRecorder: """ DBRecorder is for web visualization and writes to Postgres Database @@ -62,68 +101,61 @@ def save(self, db_url: str, mongo_map: str, output_name: str, batch_size=1000, u xs = [] ys = [] proj = pyproj.Proj(self.eng._e.get_map_projection()) - for step, vs, ts in self.data: + for step, (vs, vx, vy), (ts, tx, ty) in self.data: if vs: - p, l, x, y, d, v = zip(*vs) - x, y = proj(x, y, True) - xs += x - ys += y - for p, l, x, y, d, v in zip(p, l, x, y, d, v): - vehs.append(f"({step},{p},{l},{round(d,3)},{x},{y},'',0,0,{round(v,3)})") + x, y = proj(vx, vy, True) + xs.extend(x) + ys.extend(y) + vehs.append([step, vs, x, y]) if ts: - p, s, x, y = zip(*ts) - x, y = proj(x, y, True) - xs += x - ys += y - for p, s, x, y in zip(p, s, x, y): - tls.append(f"({step},{p},{s},{x},{y})") + x, y = proj(tx, ty, True) + xs.extend(x) + ys.extend(y) + tls.append([step, ts, x, y]) if xs: min_lon, max_lon, min_lat, max_lat = min(xs), max(xs), min(ys), max(ys) else: x1, y1, x2, y2 = self.eng._e.get_map_bbox() min_lon, min_lat = proj(x1, y1, True) max_lon, max_lat = proj(x2, y2, True) - with psycopg.connect(db_url) as conn: + with psycopg2.connect(db_url) as conn: with conn.cursor() as cur: - cur.execute(""" -CREATE TABLE IF NOT EXISTS public.meta_simple ( - "name" text NOT NULL, - "start" int4 NOT NULL, - steps int4 NOT NULL, - "time" float8 NOT NULL, - total_agents int4 NOT NULL, - "map" text NOT NULL, - min_lng float8 NOT NULL, - min_lat float8 NOT NULL, - max_lng float8 NOT NULL, - max_lat float8 NOT NULL, - road_status_v_min float8 NULL, - road_status_interval int4 NULL, - CONSTRAINT meta_simple_pkey PRIMARY KEY (name) -)""") - cur.execute(f"DELETE FROM public.meta_simple WHERE name='{output_name}'") - cur.execute( - f"INSERT INTO public.meta_simple VALUES ('{output_name}', {self.eng.start_step}, {len(self.data)}, 1, 1, '{mongo_map}', {min_lon}, {min_lat}, {max_lon}, {max_lat}, 0, 300)") - cur.execute(f"DROP TABLE IF EXISTS {output_name}_s_cars") - cur.execute(f"DROP TABLE IF EXISTS {output_name}_s_people") - cur.execute(f"DROP TABLE IF EXISTS {output_name}_s_traffic_light") - cur.execute(f"DROP TABLE IF EXISTS {output_name}_s_road") cur.execute(f""" -CREATE TABLE public.{output_name}_s_cars ( - step int4 NOT NULL, - id int4 NOT NULL, - parent_id int4 NOT NULL, - direction float8 NOT NULL, - lng float8 NOT NULL, - lat float8 NOT NULL, - model text NOT NULL, - z float8 NOT NULL, - pitch float8 NOT NULL, - v float8 NOT NULL -)""") - cur.execute(f"CREATE INDEX {output_name}_s_cars_step_lng_lat_idx ON public.{output_name}_s_cars USING btree (step, lng, lat)") - cur.execute(f""" -CREATE TABLE public.{output_name}_s_people ( +CREATE TABLE IF NOT EXISTS public.meta_simple ( + "name" text NOT NULL, + "start" int4 NOT NULL, + steps int4 NOT NULL, + "time" float8 NOT NULL, + total_agents int4 NOT NULL, + "map" text NOT NULL, + min_lng float8 NOT NULL, + min_lat float8 NOT NULL, + max_lng float8 NOT NULL, + max_lat float8 NOT NULL, + road_status_v_min float8 NULL, + road_status_interval int4 NULL, + CONSTRAINT meta_simple_pkey PRIMARY KEY (name) +); +DELETE FROM public.meta_simple WHERE name='{output_name}'; +INSERT INTO public.meta_simple VALUES ('{output_name}', {self.eng.start_step}, {len(self.data)}, 1, 1, '{mongo_map}', {min_lon}, {min_lat}, {max_lon}, {max_lat}, 0, 300); +DROP TABLE IF EXISTS {output_name}_s_cars; +DROP TABLE IF EXISTS {output_name}_s_people; +DROP TABLE IF EXISTS {output_name}_s_traffic_light; +DROP TABLE IF EXISTS {output_name}_s_road; +CREATE TABLE {output_name}_s_cars ( + step int4 NOT NULL, + id int4 NOT NULL, + parent_id int4 NOT NULL, + direction float8 NOT NULL, + lng float8 NOT NULL, + lat float8 NOT NULL, + model text NOT NULL, + z float8 NOT NULL, + pitch float8 NOT NULL, + v float8 NOT NULL +); +CREATE INDEX {output_name}_s_cars_step_lng_lat_idx ON {output_name}_s_cars USING btree (step, lng, lat); +CREATE TABLE {output_name}_s_people ( step int4 NOT NULL, id int4 NOT NULL, parent_id int4 NOT NULL, @@ -133,19 +165,17 @@ def save(self, db_url: str, mongo_map: str, output_name: str, batch_size=1000, u z float8 NOT NULL, v float8 NOT NULL, model text NOT NULL -)""") - cur.execute(f"CREATE INDEX {output_name}_s_people_step_lng_lat_idx ON public.{output_name}_s_people USING btree (step, lng, lat)") - cur.execute(f""" -CREATE TABLE public.{output_name}_s_traffic_light ( +); +CREATE INDEX {output_name}_s_people_step_lng_lat_idx ON {output_name}_s_people USING btree (step, lng, lat); +CREATE TABLE {output_name}_s_traffic_light ( step int4 NOT NULL, id int4 NOT NULL, state int4 NOT NULL, lng float8 NOT NULL, lat float8 NOT NULL -)""") - cur.execute(f"CREATE INDEX {output_name}_s_traffic_light_step_lng_lat_idx ON public.{output_name}_s_traffic_light USING btree (step, lng, lat)") - cur.execute(f""" -CREATE TABLE public.{output_name}_s_road ( +); +CREATE INDEX {output_name}_s_traffic_light_step_lng_lat_idx ON {output_name}_s_traffic_light USING btree (step, lng, lat); +CREATE TABLE {output_name}_s_road ( step int4 NOT NULL, id int4 NOT NULL, "level" int4 NOT NULL, @@ -153,15 +183,20 @@ def save(self, db_url: str, mongo_map: str, output_name: str, batch_size=1000, u in_vehicle_cnt int4 NOT NULL, out_vehicle_cnt int4 NOT NULL, cnt int4 NOT NULL -)""") - cur.execute(f"CREATE INDEX {output_name}_s_road_step_idx ON public.{output_name}_s_road USING btree (step)") - for i in tqdm(range(0, len(vehs), batch_size), ncols=90, disable=not use_tqdm): - cur.execute( - f"INSERT INTO public.{output_name}_s_cars VALUES " - + ",".join(vehs[i: i + batch_size]) - ) - for i in tqdm(range(0, len(tls), batch_size), ncols=90, disable=not use_tqdm): - cur.execute( - f"INSERT INTO public.{output_name}_s_traffic_light VALUES " - + ",".join(tls[i: i + batch_size]) - ) +); +CREATE INDEX {output_name}_s_road_step_idx ON {output_name}_s_road USING btree (step); +""") + cur.copy_from( + StringIteratorIO( + f"{step},{p},{l},{round(d,3)},{x},{y},'',0,0,{round(v,3)}\n" + for step, vs, x, y in tqdm(vehs, ncols=90, disable=not use_tqdm) for (p, l, d, v), x, y in zip(vs, x, y) + ), + f'{output_name}_s_cars', sep=',' + ) + cur.copy_from( + StringIteratorIO( + f"{step},{p},{s},{x},{y}\n" + for step, ts, x, y in tqdm(tls, ncols=90, disable=not use_tqdm) for (p, s), x, y in zip(ts, x, y) + ), + f'{output_name}_s_traffic_light', sep=',' + ) diff --git a/src/python_api.cu b/src/python_api.cu index eee54ea..ff8eba6 100644 --- a/src/python_api.cu +++ b/src/python_api.cu @@ -21,9 +21,9 @@ using no_gil = py::call_guard; template inline py::array_t asarray(T* ptr) { py::gil_scoped_acquire _; - auto capsule = - py::capsule(ptr, [](void* p) { delete reinterpret_cast(p); }); - return py::array(ptr->size(), ptr->data(), capsule); + return py::array(ptr->size(), ptr->data(), py::capsule(ptr, [](void* p) { + delete reinterpret_cast(p); + })); } template vec& remove_duplicate(vec& arr) { @@ -485,25 +485,37 @@ class Engine { } // 获取用于输出的车辆信息 (id, parent_id, x, y, dir) auto get_output_vehicles() { - vec> out; + vec> out; + auto xs = new vec; + auto ys = new vec; out.reserve(get_running_vehicle_count()); + xs->reserve(get_running_vehicle_count()); + ys->reserve(get_running_vehicle_count()); for (auto& p : S.person.persons) { if (p.runtime.status == PersonStatus::DRIVING) { - out.emplace_back( - p.id, p.runtime.lane ? p.runtime.lane->id : unsigned(-1), - p.runtime.x, p.runtime.y, p.runtime.dir, p.runtime.speed); + out.emplace_back(p.id, + p.runtime.lane ? p.runtime.lane->id : unsigned(-1), + p.runtime.dir, p.runtime.speed); + xs->push_back(p.runtime.x); + ys->push_back(p.runtime.y); } } - return out; + return std::make_tuple(out, asarray(xs), asarray(ys)); } // 获取用于输出的信号灯信息 (id, state, x, y) auto get_output_tls() { - vec> out; + vec> out; + auto xs = new vec; + auto ys = new vec; out.reserve(S.lane.output_lanes.size); + xs->reserve(S.lane.output_lanes.size); + ys->reserve(S.lane.output_lanes.size); for (auto* l : S.lane.output_lanes) { - out.emplace_back(l->id, l->light_state, l->center_x, l->center_y); + out.emplace_back(l->id, l->light_state); + xs->push_back(l->center_x); + ys->push_back(l->center_y); } - return out; + return std::make_tuple(out, asarray(xs), asarray(ys)); } // 设置人禁用 void set_vehicle_enable(uint vehicle_index, bool enable) {