From 986fd187e5796ad81f89824ca58bc6cbfd787edf Mon Sep 17 00:00:00 2001 From: Dustin Carlino Date: Thu, 12 Sep 2024 11:00:09 +0100 Subject: [PATCH] Use the weights when picking points in a zone --- Cargo.lock | 48 +++++++++++++++++++++++++++++++++++--------- od2net/Cargo.toml | 2 +- od2net/src/od.rs | 51 ++++++++++++++++++++++------------------------- 3 files changed, 64 insertions(+), 37 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8dde409..e97421f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -588,9 +588,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.153" +version = "0.2.158" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" +checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" [[package]] name = "libm" @@ -653,12 +653,6 @@ dependencies = [ "adler", ] -[[package]] -name = "nanorand" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" - [[package]] name = "num-traits" version = "0.2.16" @@ -704,9 +698,9 @@ dependencies = [ "itertools 0.12.1", "log", "lts", - "nanorand", "num_cpus", "osm-reader", + "rand", "rayon", "rstar", "serde", @@ -750,6 +744,15 @@ version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f32154ba0af3a075eefa1eda8bb414ee928f62303a54ea85b8d6638ff1a6ee9e" +[[package]] +name = "ppv-lite86" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] + [[package]] name = "priority-queue" version = "2.0.2" @@ -830,6 +833,32 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" + [[package]] name = "rayon" version = "1.9.0" @@ -1348,6 +1377,7 @@ version = "0.7.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "74d4d3961e53fa4c9a25a8637fc2bfaf2595b3d3ae34875568a5cf64787716be" dependencies = [ + "byteorder", "zerocopy-derive", ] diff --git a/od2net/Cargo.toml b/od2net/Cargo.toml index acaf4c5..c15dd72 100644 --- a/od2net/Cargo.toml +++ b/od2net/Cargo.toml @@ -19,8 +19,8 @@ itertools = "0.12.1" log = "0.4.21" lts = { path = "../lts" } osm-reader = { git = "https://github.com/a-b-street/osm-reader", features = ["serde"] } -nanorand = { version = "0.7.0", default-features = false, features = ["wyrand"] } num_cpus = "1.16.0" +rand = { version = "0.8.5", default-features = false, features = ["alloc", "std_rng"] } rayon = "1.9.0" rstar = "0.12.0" serde = { version = "1.0.197", features = ["derive"] } diff --git a/od2net/src/od.rs b/od2net/src/od.rs index 4e054ed..0b9ad0f 100644 --- a/od2net/src/od.rs +++ b/od2net/src/od.rs @@ -6,7 +6,7 @@ use fs_err::File; use geo::{BoundingRect, Centroid, Contains, MultiPolygon, Point}; use geojson::{FeatureReader, Value}; use indicatif::HumanCount; -use nanorand::{Rng, WyRand}; +use rand::{prelude::SliceRandom, rngs::StdRng, SeedableRng}; use rstar::{PointDistance, RTree, RTreeObject, AABB}; use serde::Deserialize; @@ -104,23 +104,23 @@ pub fn generate_requests( timer.stop(); timer.start(format!("Generating requests from {csv_path}")); - let mut rng = WyRand::new_seed(rng_seed); + let mut rng = StdRng::seed_from_u64(rng_seed); for rec in csv::Reader::from_reader(File::open(csv_path)?).deserialize() { let row: BetweenZonesRow = rec?; + let Some(from_points) = origins_per_zone.get(&row.from) else { + bail!("Unknown zone {}", row.from); + }; + let Some(to_points) = destinations_per_zone.get(&row.to) else { + bail!("Unknown zone {}", row.to); + }; for _ in 0..row.count { - let from = match origins_per_zone.get(&row.from) { - Some(points) => points[rng.generate_range(0..points.len())], - None => { - bail!("Unknown zone {}", row.from); - } - }; - let to = match destinations_per_zone.get(&row.to) { - Some(points) => points[rng.generate_range(0..points.len())], - None => { - bail!("Unknown zone {}", row.to); - } - }; + // TODO choose_weighted is O(n); there are alternatives if this ever becomes a + // problem. + let from = from_points + .choose_weighted(&mut rng, |pt| pt.weight) + .unwrap(); + let to = to_points.choose_weighted(&mut rng, |pt| pt.weight).unwrap(); requests.push(Request { x1: from.lon, y1: from.lat, @@ -154,23 +154,20 @@ pub fn generate_requests( timer.stop(); timer.start(format!("Generating requests from {csv_path}")); - let mut rng = WyRand::new_seed(rng_seed); + let mut rng = StdRng::seed_from_u64(rng_seed); for rec in csv::Reader::from_reader(File::open(csv_path)?).deserialize() { let row: BetweenZonesRow = rec?; + let Some(from_points) = origins_per_zone.get(&row.from) else { + bail!("Unknown zone {}", row.from); + }; + let Some(to) = destinations.get(&row.to) else { + bail!("Unknown destination {}", row.to); + }; for _ in 0..row.count { - let from = match origins_per_zone.get(&row.from) { - Some(points) => points[rng.generate_range(0..points.len())], - None => { - bail!("Unknown zone {}", row.from); - } - }; - let to = match destinations.get(&row.to) { - Some(pt) => *pt, - None => { - bail!("Unknown destination {}", row.to); - } - }; + let from = from_points + .choose_weighted(&mut rng, |pt| pt.weight) + .unwrap(); requests.push(Request { x1: from.lon, y1: from.lat,