From 686ee217117dd14b07a5d4eccb1581dd4060fc93 Mon Sep 17 00:00:00 2001 From: Anant Date: Fri, 13 Sep 2024 15:56:38 -0600 Subject: [PATCH] Implemented Ball Tree using Haversine distance as mentioned in #237 This implementation does not utilize the SpatialQueries Trait, but can be modified after some discussion on how to handle the differing parameters. This does implement all the functions utilized in query_knn.py. There is demo code at the bottom of basics.ipynb. There are also some Rust tests to test the functioning of the ball tree. This ball tree was influenced by [this repo](https://github.com/grantslatton/ball-tree). --- Cargo.lock | 1 + Cargo.toml | 9 +- examples/basics.ipynb | 996 +++++++++++++++----- python/polars_ds/__init__.py | 1 + python/polars_ds/query_balltree.py | 421 +++++++++ src/arkadia/leaf.rs | 2 +- src/num_ext/ball_tree.rs | 1368 ++++++++++++++++++++++++++++ src/num_ext/mod.rs | 1 + 8 files changed, 2551 insertions(+), 248 deletions(-) create mode 100644 python/polars_ds/query_balltree.py create mode 100644 src/num_ext/ball_tree.rs diff --git a/Cargo.lock b/Cargo.lock index f9f896a4..a45e8376 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1999,6 +1999,7 @@ dependencies = [ "pyo3", "pyo3-polars", "rand", + "rand_chacha", "rand_distr", "rapidfuzz", "rayon", diff --git a/Cargo.toml b/Cargo.toml index 46fe5e38..8304383f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,9 @@ crate-type = ["cdylib"] [dependencies] pyo3 = {version = "*", features = ["abi3-py38", "extension-module"]} pyo3-polars = {version = "0.15", features = ["derive"]} -polars = {version = "0.41.3", features = ["performant", "cse", "lazy", "parquet", "dtype-array", "diff", "array_count", "abs", "cross_join", "rank", "ndarray", "log", "cum_agg", "round_series", "nightly"]} +polars = {version = "0.41.3", features = ["performant", "cse", "lazy", +"parquet", "dtype-array", "diff", "array_count", "abs", "cross_join", "rank", "ndarray", "log", +"cum_agg", "round_series", "nightly", 
"dtype-struct"]} num = "0.4.1" faer = {version = "0.19", features = ["nightly"]} faer-ext = {version = "0.2.0", features = ["ndarray"]} @@ -43,3 +45,8 @@ jemallocator = { version = "0.5", features = ["disable_initial_exec_tls"] } codegen-units = 1 strip = "symbols" # lto = "fat" + + +[dev-dependencies] +rand = "0.8" +rand_chacha = "0.3" \ No newline at end of file diff --git a/examples/basics.ipynb b/examples/basics.ipynb index fc3375d3..32df18ab 100644 --- a/examples/basics.ipynb +++ b/examples/basics.ipynb @@ -46,7 +46,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 12)
ftime_idxdummyactualpredicteddummy_groupsx1x2x3aby
f64i64stri32f64strf64f64f64f64f64f64
0.00"a"10.359186"a"0.0562390.1554660.0443280.6257090.858837-0.011383
0.8414711"a"00.92849"a"0.7765320.4119950.1853990.0856830.057331-0.037931
0.9092972"a"10.439717"a"0.3302520.3512940.52270.080720.557959-0.629124
0.141123"a"00.996919"a"0.3779560.1230350.6869290.4482040.415162-0.93677
-0.7568024"a"00.025227"a"0.3038690.8224850.4125840.7657670.272874-0.326544
" + "shape: (5, 12)
ftime_idxdummyactualpredicteddummy_groupsx1x2x3aby
f64i64stri32f64strf64f64f64f64f64f64
0.00"a"10.892165"a"0.9117930.6951280.0942180.8182950.8159660.203999
0.8414711"a"10.948313"a"0.1257250.9563510.7763090.3256620.589716-0.858681
0.9092972"a"10.98999"a"0.6212170.933510.0021520.6611890.4837860.370057
0.141123"a"10.063476"a"0.553960.8686330.5330630.7814150.392511-0.455826
-0.7568024"a"00.337291"a"0.1072610.2887370.4650190.0364170.10985-0.594726
" ], "text/plain": [ "shape: (5, 12)\n", @@ -55,11 +55,11 @@ "│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │\n", "│ f64 ┆ i64 ┆ str ┆ i32 ┆ ┆ f64 ┆ f64 ┆ f64 ┆ f64 │\n", "╞═══════════╪══════════╪═══════╪════════╪═══╪══════════╪══════════╪══════════╪═══════════╡\n", - "│ 0.0 ┆ 0 ┆ a ┆ 1 ┆ … ┆ 0.044328 ┆ 0.625709 ┆ 0.858837 ┆ -0.011383 │\n", - "│ 0.841471 ┆ 1 ┆ a ┆ 0 ┆ … ┆ 0.185399 ┆ 0.085683 ┆ 0.057331 ┆ -0.037931 │\n", - "│ 0.909297 ┆ 2 ┆ a ┆ 1 ┆ … ┆ 0.5227 ┆ 0.08072 ┆ 0.557959 ┆ -0.629124 │\n", - "│ 0.14112 ┆ 3 ┆ a ┆ 0 ┆ … ┆ 0.686929 ┆ 0.448204 ┆ 0.415162 ┆ -0.93677 │\n", - "│ -0.756802 ┆ 4 ┆ a ┆ 0 ┆ … ┆ 0.412584 ┆ 0.765767 ┆ 0.272874 ┆ -0.326544 │\n", + "│ 0.0 ┆ 0 ┆ a ┆ 1 ┆ … ┆ 0.094218 ┆ 0.818295 ┆ 0.815966 ┆ 0.203999 │\n", + "│ 0.841471 ┆ 1 ┆ a ┆ 1 ┆ … ┆ 0.776309 ┆ 0.325662 ┆ 0.589716 ┆ -0.858681 │\n", + "│ 0.909297 ┆ 2 ┆ a ┆ 1 ┆ … ┆ 0.002152 ┆ 0.661189 ┆ 0.483786 ┆ 0.370057 │\n", + "│ 0.14112 ┆ 3 ┆ a ┆ 1 ┆ … ┆ 0.533063 ┆ 0.781415 ┆ 0.392511 ┆ -0.455826 │\n", + "│ -0.756802 ┆ 4 ┆ a ┆ 0 ┆ … ┆ 0.465019 ┆ 0.036417 ┆ 0.10985 ┆ -0.594726 │\n", "└───────────┴──────────┴───────┴────────┴───┴──────────┴──────────┴──────────┴───────────┘" ] }, @@ -217,21 +217,21 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 3)
fab
f64f64f64
1.3944e-15-0.625709-0.858837
-0.841471-0.085683-0.057331
-0.909297-0.08072-0.557959
-0.14112-0.448204-0.415162
0.756802-0.1400580.585963
" + "shape: (5, 3)
fab
f64f64f64
-4.2882e-16-0.818295-0.815966
-0.841471-0.325662-0.589716
-0.909297-0.661189-0.483786
-0.14112-0.781415-0.392511
0.7568020.7818780.706116
" ], "text/plain": [ "shape: (5, 3)\n", - "┌────────────┬───────────┬───────────┐\n", - "│ f ┆ a ┆ b │\n", - "│ --- ┆ --- ┆ --- │\n", - "│ f64 ┆ f64 ┆ f64 │\n", - "╞════════════╪═══════════╪═══════════╡\n", - "│ 1.3944e-15 ┆ -0.625709 ┆ -0.858837 │\n", - "│ -0.841471 ┆ -0.085683 ┆ -0.057331 │\n", - "│ -0.909297 ┆ -0.08072 ┆ -0.557959 │\n", - "│ -0.14112 ┆ -0.448204 ┆ -0.415162 │\n", - "│ 0.756802 ┆ -0.140058 ┆ 0.585963 │\n", - "└────────────┴───────────┴───────────┘" + "┌─────────────┬───────────┬───────────┐\n", + "│ f ┆ a ┆ b │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ f64 ┆ f64 ┆ f64 │\n", + "╞═════════════╪═══════════╪═══════════╡\n", + "│ -4.2882e-16 ┆ -0.818295 ┆ -0.815966 │\n", + "│ -0.841471 ┆ -0.325662 ┆ -0.589716 │\n", + "│ -0.909297 ┆ -0.661189 ┆ -0.483786 │\n", + "│ -0.14112 ┆ -0.781415 ┆ -0.392511 │\n", + "│ 0.756802 ┆ 0.781878 ┆ 0.706116 │\n", + "└─────────────┴───────────┴───────────┘" ] }, "execution_count": 6, @@ -267,7 +267,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (1, 1)
y
list[f64]
[-0.490676, -0.361612]
" + "shape: (1, 1)
y
list[f64]
[-0.490017, -0.355404]
" ], "text/plain": [ "shape: (1, 1)\n", @@ -276,7 +276,7 @@ "│ --- │\n", "│ list[f64] │\n", "╞════════════════════════╡\n", - "│ [-0.490676, -0.361612] │\n", + "│ [-0.490017, -0.355404] │\n", "└────────────────────────┘" ] }, @@ -312,20 +312,20 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (4, 7)
featuresbetastd_errtp>|t|0.0250.975
strf64f64f64f64f64f64
"ln(x1+1)"0.2183220.001686129.511110.00.2150170.221626
"exp(x2)"0.174690.000679257.4137490.00.173360.17602
"sin(x3)"-1.7435070.00134-1300.6490750.0-1.746135-1.74088
"__bias__"-0.1080790.001497-72.1853580.0-0.111013-0.105144
" + "shape: (4, 7)
featuresbetastd_errtp>|t|0.0250.975
strf64f64f64f64f64f64
"ln(x1+1)"0.2198440.001645133.6804250.00.2166210.223068
"exp(x2)"0.1756150.000671261.5494710.00.1742990.176931
"sin(x3)"-1.7439930.001334-1306.993750.0-1.746608-1.741377
"__bias__"-0.1092990.001498-72.9742180.0-0.112235-0.106363
" ], "text/plain": [ "shape: (4, 7)\n", - "┌──────────┬───────────┬──────────┬──────────────┬───────┬───────────┬───────────┐\n", - "│ features ┆ beta ┆ std_err ┆ t ┆ p>|t| ┆ 0.025 ┆ 0.975 │\n", - "│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", - "│ str ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 │\n", - "╞══════════╪═══════════╪══════════╪══════════════╪═══════╪═══════════╪═══════════╡\n", - "│ ln(x1+1) ┆ 0.218322 ┆ 0.001686 ┆ 129.51111 ┆ 0.0 ┆ 0.215017 ┆ 0.221626 │\n", - "│ exp(x2) ┆ 0.17469 ┆ 0.000679 ┆ 257.413749 ┆ 0.0 ┆ 0.17336 ┆ 0.17602 │\n", - "│ sin(x3) ┆ -1.743507 ┆ 0.00134 ┆ -1300.649075 ┆ 0.0 ┆ -1.746135 ┆ -1.74088 │\n", - "│ __bias__ ┆ -0.108079 ┆ 0.001497 ┆ -72.185358 ┆ 0.0 ┆ -0.111013 ┆ -0.105144 │\n", - "└──────────┴───────────┴──────────┴──────────────┴───────┴───────────┴───────────┘" + "┌──────────┬───────────┬──────────┬─────────────┬───────┬───────────┬───────────┐\n", + "│ features ┆ beta ┆ std_err ┆ t ┆ p>|t| ┆ 0.025 ┆ 0.975 │\n", + "│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", + "│ str ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 │\n", + "╞══════════╪═══════════╪══════════╪═════════════╪═══════╪═══════════╪═══════════╡\n", + "│ ln(x1+1) ┆ 0.219844 ┆ 0.001645 ┆ 133.680425 ┆ 0.0 ┆ 0.216621 ┆ 0.223068 │\n", + "│ exp(x2) ┆ 0.175615 ┆ 0.000671 ┆ 261.549471 ┆ 0.0 ┆ 0.174299 ┆ 0.176931 │\n", + "│ sin(x3) ┆ -1.743993 ┆ 0.001334 ┆ -1306.99375 ┆ 0.0 ┆ -1.746608 ┆ -1.741377 │\n", + "│ __bias__ ┆ -0.109299 ┆ 0.001498 ┆ -72.974218 ┆ 0.0 ┆ -0.112235 ┆ -0.106363 │\n", + "└──────────┴───────────┴──────────┴─────────────┴───────┴───────────┴───────────┘" ] }, "execution_count": 8, @@ -361,7 +361,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (1, 1)
y
list[f64]
[-0.490676, -0.361612]
" + "shape: (1, 1)
y
list[f64]
[-0.490017, -0.355404]
" ], "text/plain": [ "shape: (1, 1)\n", @@ -370,7 +370,7 @@ "│ --- │\n", "│ list[f64] │\n", "╞════════════════════════╡\n", - "│ [-0.490676, -0.361612] │\n", + "│ [-0.490017, -0.355404] │\n", "└────────────────────────┘" ] }, @@ -405,7 +405,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (10_000, 2)
dummycoeffs
strlist[f64]
"a"[-0.485351, -0.364542]
"a"[-0.485351, -0.364542]
"a"[-0.485351, -0.364542]
"a"[-0.485351, -0.364542]
"a"[-0.485351, -0.364542]
"b"[-0.496067, -0.358636]
"b"[-0.496067, -0.358636]
"b"[-0.496067, -0.358636]
"b"[-0.496067, -0.358636]
"b"[-0.496067, -0.358636]
" + "shape: (10_000, 2)
dummycoeffs
strlist[f64]
"a"[-0.491637, -0.363114]
"a"[-0.491637, -0.363114]
"a"[-0.491637, -0.363114]
"a"[-0.491637, -0.363114]
"a"[-0.491637, -0.363114]
"b"[-0.488274, -0.347916]
"b"[-0.488274, -0.347916]
"b"[-0.488274, -0.347916]
"b"[-0.488274, -0.347916]
"b"[-0.488274, -0.347916]
" ], "text/plain": [ "shape: (10_000, 2)\n", @@ -414,17 +414,17 @@ "│ --- ┆ --- │\n", "│ str ┆ list[f64] │\n", "╞═══════╪════════════════════════╡\n", - "│ a ┆ [-0.485351, -0.364542] │\n", - "│ a ┆ [-0.485351, -0.364542] │\n", - "│ a ┆ [-0.485351, -0.364542] │\n", - "│ a ┆ [-0.485351, -0.364542] │\n", - "│ a ┆ [-0.485351, -0.364542] │\n", + "│ a ┆ [-0.491637, -0.363114] │\n", + "│ a ┆ [-0.491637, -0.363114] │\n", + "│ a ┆ [-0.491637, -0.363114] │\n", + "│ a ┆ [-0.491637, -0.363114] │\n", + "│ a ┆ [-0.491637, -0.363114] │\n", "│ … ┆ … │\n", - "│ b ┆ [-0.496067, -0.358636] │\n", - "│ b ┆ [-0.496067, -0.358636] │\n", - "│ b ┆ [-0.496067, -0.358636] │\n", - "│ b ┆ [-0.496067, -0.358636] │\n", - "│ b ┆ [-0.496067, -0.358636] │\n", + "│ b ┆ [-0.488274, -0.347916] │\n", + "│ b ┆ [-0.488274, -0.347916] │\n", + "│ b ┆ [-0.488274, -0.347916] │\n", + "│ b ┆ [-0.488274, -0.347916] │\n", + "│ b ┆ [-0.488274, -0.347916] │\n", "└───────┴────────────────────────┘" ] }, @@ -460,7 +460,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 5)
x1x2ypredresid
f64f64f64f64f64
0.0562390.155466-0.011383-0.0838130.07243
0.7765320.411995-0.037931-0.5300080.492077
0.3302520.351294-0.629124-0.289079-0.340045
0.3779560.123035-0.93677-0.229945-0.706825
0.3038690.822485-0.326544-0.4465220.119977
" + "shape: (5, 5)
x1x2ypredresid
f64f64f64f64f64
0.9117930.6951280.203999-0.6938460.897845
0.1257250.956351-0.858681-0.401499-0.457183
0.6212170.933510.370057-0.6361811.006238
0.553960.868633-0.455826-0.5801660.124339
0.1072610.288737-0.594726-0.155178-0.439548
" ], "text/plain": [ "shape: (5, 5)\n", @@ -469,11 +469,11 @@ "│ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", "│ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 │\n", "╞══════════╪══════════╪═══════════╪═══════════╪═══════════╡\n", - "│ 0.056239 ┆ 0.155466 ┆ -0.011383 ┆ -0.083813 ┆ 0.07243 │\n", - "│ 0.776532 ┆ 0.411995 ┆ -0.037931 ┆ -0.530008 ┆ 0.492077 │\n", - "│ 0.330252 ┆ 0.351294 ┆ -0.629124 ┆ -0.289079 ┆ -0.340045 │\n", - "│ 0.377956 ┆ 0.123035 ┆ -0.93677 ┆ -0.229945 ┆ -0.706825 │\n", - "│ 0.303869 ┆ 0.822485 ┆ -0.326544 ┆ -0.446522 ┆ 0.119977 │\n", + "│ 0.911793 ┆ 0.695128 ┆ 0.203999 ┆ -0.693846 ┆ 0.897845 │\n", + "│ 0.125725 ┆ 0.956351 ┆ -0.858681 ┆ -0.401499 ┆ -0.457183 │\n", + "│ 0.621217 ┆ 0.93351 ┆ 0.370057 ┆ -0.636181 ┆ 1.006238 │\n", + "│ 0.55396 ┆ 0.868633 ┆ -0.455826 ┆ -0.580166 ┆ 0.124339 │\n", + "│ 0.107261 ┆ 0.288737 ┆ -0.594726 ┆ -0.155178 ┆ -0.439548 │\n", "└──────────┴──────────┴───────────┴───────────┴───────────┘" ] }, @@ -513,7 +513,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (2, 2)
dummycoeffs
strlist[f64]
"b"[-0.496067, -0.358636]
"a"[-0.485351, -0.364542]
" + "shape: (2, 2)
dummycoeffs
strlist[f64]
"a"[-0.491637, -0.363114]
"b"[-0.488274, -0.347916]
" ], "text/plain": [ "shape: (2, 2)\n", @@ -522,8 +522,8 @@ "│ --- ┆ --- │\n", "│ str ┆ list[f64] │\n", "╞═══════╪════════════════════════╡\n", - "│ b ┆ [-0.496067, -0.358636] │\n", - "│ a ┆ [-0.485351, -0.364542] │\n", + "│ a ┆ [-0.491637, -0.363114] │\n", + "│ b ┆ [-0.488274, -0.347916] │\n", "└───────┴────────────────────────┘" ] }, @@ -558,7 +558,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (2, 2)
dummycoeffs
strlist[f64]
"b"[-0.323816, -0.185978]
"a"[-0.31474, -0.192814]
" + "shape: (2, 2)
dummycoeffs
strlist[f64]
"a"[-0.327664, -0.184011]
"b"[-0.318486, -0.177007]
" ], "text/plain": [ "shape: (2, 2)\n", @@ -567,8 +567,8 @@ "│ --- ┆ --- │\n", "│ str ┆ list[f64] │\n", "╞═══════╪════════════════════════╡\n", - "│ b ┆ [-0.323816, -0.185978] │\n", - "│ a ┆ [-0.31474, -0.192814] │\n", + "│ a ┆ [-0.327664, -0.184011] │\n", + "│ b ┆ [-0.318486, -0.177007] │\n", "└───────┴────────────────────────┘" ] }, @@ -605,7 +605,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (2, 2)
dummylasso_r2
strf64
"b"-0.541371
"a"-0.515024
" + "shape: (2, 2)
dummylasso_r2
strf64
"a"-0.533136
"b"-0.548727
" ], "text/plain": [ "shape: (2, 2)\n", @@ -614,8 +614,8 @@ "│ --- ┆ --- │\n", "│ str ┆ f64 │\n", "╞═══════╪═══════════╡\n", - "│ b ┆ -0.541371 │\n", - "│ a ┆ -0.515024 │\n", + "│ a ┆ -0.533136 │\n", + "│ b ┆ -0.548727 │\n", "└───────┴───────────┘" ] }, @@ -656,7 +656,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (10_000, 5)
yx1x2coeffspred
f64f64f64list[f64]f64
-0.0113830.0562390.155466nullnull
-0.0379310.7765320.411995nullnull
-0.6291240.3302520.351294nullnull
-0.936770.3779560.123035nullnull
-0.3265440.3038690.822485[-0.576516, -0.193246]-0.334127
-1.2598860.0345610.752261[0.282414, -1.23973]-0.922841
-0.208190.0408980.021563[0.614343, -1.462836]-0.006418
-0.6674310.2508750.402446[0.648749, -1.760264]-0.545656
-0.8841420.0291430.854246[0.12887, -1.343459]-1.143889
-0.4891730.448210.973485[1.102579, -1.289347]-0.760973
" + "shape: (10_000, 5)
yx1x2coeffspred
f64f64f64list[f64]f64
0.2039990.9117930.695128nullnull
-0.8586810.1257250.956351nullnull
0.3700570.6212170.93351nullnull
-0.4558260.553960.868633nullnull
-0.5947260.1072610.288737[1.160027, -0.979407]-0.158365
0.2204080.2309110.658026[-1.13021, 0.56658]0.111846
-0.9943650.8894470.924552[-0.924553, -0.002686]-0.824824
-1.4559820.2076590.020715[-1.495076, 0.657265]-0.29685
-0.980380.9395370.244478[-1.675845, 0.691202]-1.405534
-1.3069040.1192830.533054[-1.079334, -0.22549]-0.248944
" ], "text/plain": [ "shape: (10_000, 5)\n", @@ -665,17 +665,17 @@ "│ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", "│ f64 ┆ f64 ┆ f64 ┆ list[f64] ┆ f64 │\n", "╞═══════════╪══════════╪══════════╪════════════════════════╪═══════════╡\n", - "│ -0.011383 ┆ 0.056239 ┆ 0.155466 ┆ null ┆ null │\n", - "│ -0.037931 ┆ 0.776532 ┆ 0.411995 ┆ null ┆ null │\n", - "│ -0.629124 ┆ 0.330252 ┆ 0.351294 ┆ null ┆ null │\n", - "│ -0.93677 ┆ 0.377956 ┆ 0.123035 ┆ null ┆ null │\n", - "│ -0.326544 ┆ 0.303869 ┆ 0.822485 ┆ [-0.576516, -0.193246] ┆ -0.334127 │\n", + "│ 0.203999 ┆ 0.911793 ┆ 0.695128 ┆ null ┆ null │\n", + "│ -0.858681 ┆ 0.125725 ┆ 0.956351 ┆ null ┆ null │\n", + "│ 0.370057 ┆ 0.621217 ┆ 0.93351 ┆ null ┆ null │\n", + "│ -0.455826 ┆ 0.55396 ┆ 0.868633 ┆ null ┆ null │\n", + "│ -0.594726 ┆ 0.107261 ┆ 0.288737 ┆ [1.160027, -0.979407] ┆ -0.158365 │\n", "│ … ┆ … ┆ … ┆ … ┆ … │\n", - "│ -1.259886 ┆ 0.034561 ┆ 0.752261 ┆ [0.282414, -1.23973] ┆ -0.922841 │\n", - "│ -0.20819 ┆ 0.040898 ┆ 0.021563 ┆ [0.614343, -1.462836] ┆ -0.006418 │\n", - "│ -0.667431 ┆ 0.250875 ┆ 0.402446 ┆ [0.648749, -1.760264] ┆ -0.545656 │\n", - "│ -0.884142 ┆ 0.029143 ┆ 0.854246 ┆ [0.12887, -1.343459] ┆ -1.143889 │\n", - "│ -0.489173 ┆ 0.44821 ┆ 0.973485 ┆ [1.102579, -1.289347] ┆ -0.760973 │\n", + "│ 0.220408 ┆ 0.230911 ┆ 0.658026 ┆ [-1.13021, 0.56658] ┆ 0.111846 │\n", + "│ -0.994365 ┆ 0.889447 ┆ 0.924552 ┆ [-0.924553, -0.002686] ┆ -0.824824 │\n", + "│ -1.455982 ┆ 0.207659 ┆ 0.020715 ┆ [-1.495076, 0.657265] ┆ -0.29685 │\n", + "│ -0.98038 ┆ 0.939537 ┆ 0.244478 ┆ [-1.675845, 0.691202] ┆ -1.405534 │\n", + "│ -1.306904 ┆ 0.119283 ┆ 0.533054 ┆ [-1.079334, -0.22549] ┆ -0.248944 │\n", "└───────────┴──────────┴──────────┴────────────────────────┴───────────┘" ] }, @@ -756,7 +756,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (1, 1)
a
list[f64]
[29.227034, 28.841191, 28.673175]
" + "shape: (1, 1)
a
list[f64]
[29.313388, 28.98794, 28.657594]
" ], "text/plain": [ "shape: (1, 1)\n", @@ -765,7 +765,7 @@ "│ --- │\n", "│ list[f64] │\n", "╞═════════════════════════════════╡\n", - "│ [29.227034, 28.841191, 28.6731… │\n", + "│ [29.313388, 28.98794, 28.65759… │\n", "└─────────────────────────────────┘" ] }, @@ -797,7 +797,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (2, 2)
singular_valueweight_vector
f64list[f64]
29.213465[0.780529, -0.625119]
28.729148[0.625119, 0.780529]
" + "shape: (2, 2)
singular_valueweight_vector
f64list[f64]
29.122366[-0.462112, 0.886822]
28.657989[0.886822, 0.462112]
" ], "text/plain": [ "shape: (2, 2)\n", @@ -806,8 +806,8 @@ "│ --- ┆ --- │\n", "│ f64 ┆ list[f64] │\n", "╞════════════════╪═══════════════════════╡\n", - "│ 29.213465 ┆ [0.780529, -0.625119] │\n", - "│ 28.729148 ┆ [0.625119, 0.780529] │\n", + "│ 29.122366 ┆ [-0.462112, 0.886822] │\n", + "│ 28.657989 ┆ [0.886822, 0.462112] │\n", "└────────────────┴───────────────────────┘" ] }, @@ -839,7 +839,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 1)
pc1
f64
-0.123833
-0.044303
-0.361128
0.014969
0.351783
" + "shape: (5, 1)
pc1
f64
0.133943
0.160952
-0.08804
-0.224543
-0.130941
" ], "text/plain": [ "shape: (5, 1)\n", @@ -848,11 +848,11 @@ "│ --- │\n", "│ f64 │\n", "╞═══════════╡\n", - "│ -0.123833 │\n", - "│ -0.044303 │\n", - "│ -0.361128 │\n", - "│ 0.014969 │\n", - "│ 0.351783 │\n", + "│ 0.133943 │\n", + "│ 0.160952 │\n", + "│ -0.08804 │\n", + "│ -0.224543 │\n", + "│ -0.130941 │\n", "└───────────┘" ] }, @@ -892,7 +892,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (2, 8)
dummy_groupsl2log lossprecisionrecallfaverage_precisionroc_auc
strf64f64f64f64f64f64f64
"b"0.3339981.009880.5016140.493450.4974980.5058390.499459
"a"0.3390751.0127610.5008220.4746690.4873950.5016930.488649
" + "shape: (2, 8)
dummy_groupsl2log lossprecisionrecallfaverage_precisionroc_auc
strf64f64f64f64f64f64f64
"b"0.3321890.998670.4967450.4917440.4942320.5030490.50233
"a"0.3436861.0367970.4918230.4883170.4900640.4911710.480539
" ], "text/plain": [ "shape: (2, 8)\n", @@ -902,8 +902,8 @@ "│ str ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ --- ┆ f64 │\n", "│ ┆ ┆ ┆ ┆ ┆ ┆ f64 ┆ │\n", "╞══════════════╪══════════╪══════════╪═══════════╪══════════╪══════════╪════════════════╪══════════╡\n", - "│ b ┆ 0.333998 ┆ 1.00988 ┆ 0.501614 ┆ 0.49345 ┆ 0.497498 ┆ 0.505839 ┆ 0.499459 │\n", - "│ a ┆ 0.339075 ┆ 1.012761 ┆ 0.500822 ┆ 0.474669 ┆ 0.487395 ┆ 0.501693 ┆ 0.488649 │\n", + "│ b ┆ 0.332189 ┆ 0.99867 ┆ 0.496745 ┆ 0.491744 ┆ 0.494232 ┆ 0.503049 ┆ 0.50233 │\n", + "│ a ┆ 0.343686 ┆ 1.036797 ┆ 0.491823 ┆ 0.488317 ┆ 0.490064 ┆ 0.491171 ┆ 0.480539 │\n", "└──────────────┴──────────┴──────────┴───────────┴──────────┴──────────┴────────────────┴──────────┘" ] }, @@ -991,7 +991,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 1)
sen
str
"hello"
"world"
"church"
"to"
"going"
" + "shape: (5, 1)
sen
str
"world"
"hello"
"going"
"to"
"church"
" ], "text/plain": [ "shape: (5, 1)\n", @@ -1000,11 +1000,11 @@ "│ --- │\n", "│ str │\n", "╞════════╡\n", - "│ hello │\n", "│ world │\n", - "│ church │\n", - "│ to │\n", + "│ hello │\n", "│ going │\n", + "│ to │\n", + "│ church │\n", "└────────┘" ] }, @@ -1036,7 +1036,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 1)
sen
str
"church"
"go"
""
"world"
"hello"
" + "shape: (5, 1)
sen
str
"go"
""
"hello"
"church"
"world"
" ], "text/plain": [ "shape: (5, 1)\n", @@ -1045,11 +1045,11 @@ "│ --- │\n", "│ str │\n", "╞════════╡\n", - "│ church │\n", "│ go │\n", "│ │\n", - "│ world │\n", "│ hello │\n", + "│ church │\n", + "│ world │\n", "└────────┘" ] }, @@ -1377,7 +1377,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 1)
a
f64
null
null
-0.868297
0.552773
-1.152573
" + "shape: (5, 1)
a
f64
null
null
-0.063853
0.589552
-0.946684
" ], "text/plain": [ "shape: (5, 1)\n", @@ -1388,9 +1388,9 @@ "╞═══════════╡\n", "│ null │\n", "│ null │\n", - "│ -0.868297 │\n", - "│ 0.552773 │\n", - "│ -1.152573 │\n", + "│ -0.063853 │\n", + "│ 0.589552 │\n", + "│ -0.946684 │\n", "└───────────┘" ] }, @@ -1424,7 +1424,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 3)
arandom_normalrandom_normal_that_respects_null_of_a
f64f64f64
null0.744406null
null0.936033null
-0.8682970.5564360.723798
0.5527731.4458860.453344
-1.152573-1.195777-0.006817
" + "shape: (5, 3)
arandom_normalrandom_normal_that_respects_null_of_a
f64f64f64
null2.515477null
null-1.507373null
-0.0638530.1325070.784535
0.589552-1.1250481.717568
-0.9466841.3467410.454957
" ], "text/plain": [ "shape: (5, 3)\n", @@ -1433,11 +1433,11 @@ "│ --- ┆ --- ┆ --- │\n", "│ f64 ┆ f64 ┆ f64 │\n", "╞═══════════╪═══════════════╪═════════════════════════════════╡\n", - "│ null ┆ 0.744406 ┆ null │\n", - "│ null ┆ 0.936033 ┆ null │\n", - "│ -0.868297 ┆ 0.556436 ┆ 0.723798 │\n", - "│ 0.552773 ┆ 1.445886 ┆ 0.453344 │\n", - "│ -1.152573 ┆ -1.195777 ┆ -0.006817 │\n", + "│ null ┆ 2.515477 ┆ null │\n", + "│ null ┆ -1.507373 ┆ null │\n", + "│ -0.063853 ┆ 0.132507 ┆ 0.784535 │\n", + "│ 0.589552 ┆ -1.125048 ┆ 1.717568 │\n", + "│ -0.946684 ┆ 1.346741 ┆ 0.454957 │\n", "└───────────┴───────────────┴─────────────────────────────────┘" ] }, @@ -1472,7 +1472,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 3)
arandom_strrandom_str_that_respects_null_of_a
f64strstr
null"Pi3"null
null"ZCnaW"null
-0.868297"2""2"
0.552773"e""e"
-1.152573"7mg3Z""7mg3Z"
" + "shape: (5, 3)
arandom_strrandom_str_that_respects_null_of_a
f64strstr
null"p9q"null
null"R1UZ"null
-0.063853"q5Vog""o"
0.589552"vQMM""AH"
-0.946684"Ydh""0"
" ], "text/plain": [ "shape: (5, 3)\n", @@ -1481,11 +1481,11 @@ "│ --- ┆ --- ┆ --- │\n", "│ f64 ┆ str ┆ str │\n", "╞═══════════╪════════════╪═════════════════════════════════╡\n", - "│ null ┆ Pi3 ┆ null │\n", - "│ null ┆ ZCnaW ┆ null │\n", - "│ -0.868297 ┆ 2 ┆ 2 │\n", - "│ 0.552773 ┆ e ┆ e │\n", - "│ -1.152573 ┆ 7mg3Z ┆ 7mg3Z │\n", + "│ null ┆ p9q ┆ null │\n", + "│ null ┆ R1UZ ┆ null │\n", + "│ -0.063853 ┆ q5Vog ┆ o │\n", + "│ 0.589552 ┆ vQMM ┆ AH │\n", + "│ -0.946684 ┆ Ydh ┆ 0 │\n", "└───────────┴────────────┴─────────────────────────────────┘" ] }, @@ -1520,7 +1520,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 2)
arandom_str
f64str
nullnull
nullnull
-0.868297"IeA07"
0.552773"3ZtJz"
-1.152573"NoAKh"
" + "shape: (5, 2)
arandom_str
f64str
nullnull
nullnull
-0.063853"qhb5I"
0.589552"Dfmvg"
-0.946684"sL1xx"
" ], "text/plain": [ "shape: (5, 2)\n", @@ -1531,9 +1531,9 @@ "╞═══════════╪════════════╡\n", "│ null ┆ null │\n", "│ null ┆ null │\n", - "│ -0.868297 ┆ IeA07 │\n", - "│ 0.552773 ┆ 3ZtJz │\n", - "│ -1.152573 ┆ NoAKh │\n", + "│ -0.063853 ┆ qhb5I │\n", + "│ 0.589552 ┆ Dfmvg │\n", + "│ -0.946684 ┆ sL1xx │\n", "└───────────┴────────────┘" ] }, @@ -1567,7 +1567,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 4)
atest1literaltest1_perturbed
f64f64f64f64
null-0.314265null-0.314012
null-1.438597null-1.438247
-0.868297-0.3874460.37435-0.387077
0.552773-2.3357122.323368-2.336171
-1.152573-1.9289850.281368-1.92879
" + "shape: (5, 4)
atest1literaltest1_perturbed
f64f64f64f64
null0.278036null0.277542
null1.817391null1.817711
-0.063853-0.4984640.833706-0.498104
0.589552-0.9762791.275155-0.976284
-0.946684-0.7503242.289387-0.750434
" ], "text/plain": [ "shape: (5, 4)\n", @@ -1576,11 +1576,11 @@ "│ --- ┆ --- ┆ --- ┆ --- │\n", "│ f64 ┆ f64 ┆ f64 ┆ f64 │\n", "╞═══════════╪═══════════╪══════════╪═════════════════╡\n", - "│ null ┆ -0.314265 ┆ null ┆ -0.314012 │\n", - "│ null ┆ -1.438597 ┆ null ┆ -1.438247 │\n", - "│ -0.868297 ┆ -0.387446 ┆ 0.37435 ┆ -0.387077 │\n", - "│ 0.552773 ┆ -2.335712 ┆ 2.323368 ┆ -2.336171 │\n", - "│ -1.152573 ┆ -1.928985 ┆ 0.281368 ┆ -1.92879 │\n", + "│ null ┆ 0.278036 ┆ null ┆ 0.277542 │\n", + "│ null ┆ 1.817391 ┆ null ┆ 1.817711 │\n", + "│ -0.063853 ┆ -0.498464 ┆ 0.833706 ┆ -0.498104 │\n", + "│ 0.589552 ┆ -0.976279 ┆ 1.275155 ┆ -0.976284 │\n", + "│ -0.946684 ┆ -0.750324 ┆ 2.289387 ┆ -0.750434 │\n", "└───────────┴───────────┴──────────┴─────────────────┘" ] }, @@ -1619,7 +1619,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 4)
a[0, 1)NormalInt from [0, 10)
f64f64f64i32
null0.7508272.0369695
null0.3452180.4869222
-0.8682970.5370810.2575833
0.5527730.07438-1.3138623
-1.1525730.7021530.0965945
" + "shape: (5, 4)
a[0, 1)NormalInt from [0, 10)
f64f64f64i32
null0.484578-0.3736550
null0.096912-0.2716845
-0.0638530.380539-1.0883717
0.5895520.112595-1.8948234
-0.9466840.770406-1.2559294
" ], "text/plain": [ "shape: (5, 4)\n", @@ -1628,11 +1628,11 @@ "│ --- ┆ --- ┆ --- ┆ --- │\n", "│ f64 ┆ f64 ┆ f64 ┆ i32 │\n", "╞═══════════╪══════════╪═══════════╪══════════════════╡\n", - "│ null ┆ 0.750827 ┆ 2.036969 ┆ 5 │\n", - "│ null ┆ 0.345218 ┆ 0.486922 ┆ 2 │\n", - "│ -0.868297 ┆ 0.537081 ┆ 0.257583 ┆ 3 │\n", - "│ 0.552773 ┆ 0.07438 ┆ -1.313862 ┆ 3 │\n", - "│ -1.152573 ┆ 0.702153 ┆ 0.096594 ┆ 5 │\n", + "│ null ┆ 0.484578 ┆ -0.373655 ┆ 0 │\n", + "│ null ┆ 0.096912 ┆ -0.271684 ┆ 5 │\n", + "│ -0.063853 ┆ 0.380539 ┆ -1.088371 ┆ 7 │\n", + "│ 0.589552 ┆ 0.112595 ┆ -1.894823 ┆ 4 │\n", + "│ -0.946684 ┆ 0.770406 ┆ -1.255929 ┆ 4 │\n", "└───────────┴──────────┴───────────┴──────────────────┘" ] }, @@ -1667,7 +1667,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (1, 4)
t-tests: statisticst-tests: pvaluenormality_test: statisticsnormality_test: pvalue
f64f64f64f64
-0.7669880.4432150.754490.685748
" + "shape: (1, 4)
t-tests: statisticst-tests: pvaluenormality_test: statisticsnormality_test: pvalue
f64f64f64f64
-0.1092820.9129942.3119370.314753
" ], "text/plain": [ "shape: (1, 4)\n", @@ -1676,7 +1676,7 @@ "│ --- ┆ --- ┆ --- ┆ --- │\n", "│ f64 ┆ f64 ┆ f64 ┆ f64 │\n", "╞═════════════════════╪═════════════════╪════════════════════════════╪════════════════════════╡\n", - "│ -0.766988 ┆ 0.443215 ┆ 0.75449 ┆ 0.685748 │\n", + "│ -0.109282 ┆ 0.912994 ┆ 2.311937 ┆ 0.314753 │\n", "└─────────────────────┴─────────────────┴────────────────────────────┴────────────────────────┘" ] }, @@ -1720,7 +1720,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 5)
market_idvar1var2category_1category_2
i64f64f64i32i32
00.9689690.48266921
10.4547120.50095849
20.2756680.3722346
00.2444160.46794803
10.142920.5405911
" + "shape: (5, 5)
market_idvar1var2category_1category_2
i64f64f64i32i32
00.84980.10476227
10.8354230.30452132
20.4514550.91563521
00.1989610.84035307
10.0660470.93680523
" ], "text/plain": [ "shape: (5, 5)\n", @@ -1729,11 +1729,11 @@ "│ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", "│ i64 ┆ f64 ┆ f64 ┆ i32 ┆ i32 │\n", "╞═══════════╪══════════╪══════════╪════════════╪════════════╡\n", - "│ 0 ┆ 0.968969 ┆ 0.482669 ┆ 2 ┆ 1 │\n", - "│ 1 ┆ 0.454712 ┆ 0.500958 ┆ 4 ┆ 9 │\n", - "│ 2 ┆ 0.275668 ┆ 0.37223 ┆ 4 ┆ 6 │\n", - "│ 0 ┆ 0.244416 ┆ 0.467948 ┆ 0 ┆ 3 │\n", - "│ 1 ┆ 0.14292 ┆ 0.54059 ┆ 1 ┆ 1 │\n", + "│ 0 ┆ 0.8498 ┆ 0.104762 ┆ 2 ┆ 7 │\n", + "│ 1 ┆ 0.835423 ┆ 0.304521 ┆ 3 ┆ 2 │\n", + "│ 2 ┆ 0.451455 ┆ 0.915635 ┆ 2 ┆ 1 │\n", + "│ 0 ┆ 0.198961 ┆ 0.840353 ┆ 0 ┆ 7 │\n", + "│ 1 ┆ 0.066047 ┆ 0.936805 ┆ 2 ┆ 3 │\n", "└───────────┴──────────┴──────────┴────────────┴────────────┘" ] }, @@ -1773,17 +1773,17 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (1, 3)
t-testchi2-testf-test
struct[2]struct[2]struct[2]
{0.242506,0.808393}{30.627761,0.721731}{1.144874,0.333362}
" + "shape: (1, 3)
t-testchi2-testf-test
struct[2]struct[2]struct[2]
{-0.64793,0.517045}{22.170505,0.965739}{0.50483,0.732206}
" ], "text/plain": [ "shape: (1, 3)\n", - "┌─────────────────────┬──────────────────────┬─────────────────────┐\n", - "│ t-test ┆ chi2-test ┆ f-test │\n", - "│ --- ┆ --- ┆ --- │\n", - "│ struct[2] ┆ struct[2] ┆ struct[2] │\n", - "╞═════════════════════╪══════════════════════╪═════════════════════╡\n", - "│ {0.242506,0.808393} ┆ {30.627761,0.721731} ┆ {1.144874,0.333362} │\n", - "└─────────────────────┴──────────────────────┴─────────────────────┘" + "┌─────────────────────┬──────────────────────┬────────────────────┐\n", + "│ t-test ┆ chi2-test ┆ f-test │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ struct[2] ┆ struct[2] ┆ struct[2] │\n", + "╞═════════════════════╪══════════════════════╪════════════════════╡\n", + "│ {-0.64793,0.517045} ┆ {22.170505,0.965739} ┆ {0.50483,0.732206} │\n", + "└─────────────────────┴──────────────────────┴────────────────────┘" ] }, "execution_count": 39, @@ -1816,9 +1816,9 @@ "│ --- ┆ --- ┆ --- ┆ --- │\n", "│ i64 ┆ struct[2] ┆ struct[2] ┆ struct[2] │\n", "╞═══════════╪══════════════════════╪══════════════════════╪═════════════════════╡\n", - "│ 0 ┆ {-0.508801,0.610925} ┆ {46.709664,0.108988} ┆ {1.720471,0.14282} │\n", - "│ 1 ┆ {0.493627,0.621602} ┆ {39.48074,0.317128} ┆ {0.318692,0.865595} │\n", - "│ 2 ┆ {0.437398,0.661851} ┆ {24.599585,0.924818} ┆ {0.553822,0.696258} │\n", + "│ 0 ┆ {-1.04167,0.29764} ┆ {27.019093,0.860273} ┆ {0.542672,0.70442} │\n", + "│ 1 ┆ {-0.146552,0.883494} ┆ {37.653112,0.393468} ┆ {0.377929,0.824525} │\n", + "│ 2 ┆ {0.064001,0.948973} ┆ {26.863301,0.865169} ┆ {0.05765,0.993836} │\n", "└───────────┴──────────────────────┴──────────────────────┴─────────────────────┘\n" ] } @@ -1850,7 +1850,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (9, 2)
first_digit_cntfirst_digit_distribution
u32f64
5570.1114
5830.1166
5730.1146
5400.108
5260.1052
5720.1144
5370.1074
5530.1106
5590.1118
" + "shape: (9, 2)
first_digit_cntfirst_digit_distribution
u32f64
5660.1132
5850.117
5910.1182
5380.1076
5530.1106
5560.1112
5560.1112
4960.0992
5590.1118
" ], "text/plain": [ "shape: (9, 2)\n", @@ -1859,14 +1859,14 @@ "│ --- ┆ --- │\n", "│ u32 ┆ f64 │\n", "╞═════════════════╪══════════════════════════╡\n", - "│ 557 ┆ 0.1114 │\n", - "│ 583 ┆ 0.1166 │\n", - "│ 573 ┆ 0.1146 │\n", - "│ 540 ┆ 0.108 │\n", - "│ 526 ┆ 0.1052 │\n", - "│ 572 ┆ 0.1144 │\n", - "│ 537 ┆ 0.1074 │\n", + "│ 566 ┆ 0.1132 │\n", + "│ 585 ┆ 0.117 │\n", + "│ 591 ┆ 0.1182 │\n", + "│ 538 ┆ 0.1076 │\n", "│ 553 ┆ 0.1106 │\n", + "│ 556 ┆ 0.1112 │\n", + "│ 556 ┆ 0.1112 │\n", + "│ 496 ┆ 0.0992 │\n", "│ 559 ┆ 0.1118 │\n", "└─────────────────┴──────────────────────────┘" ] @@ -1933,7 +1933,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 7)
idvar1var2var3rrhnb_l_inf_cnt
u32f64f64f64f64f64u32
00.9732530.7845980.0322570.4135645.8829228
10.9649420.9349590.1254310.5662961.29994410
20.7676430.8110020.1100480.1125193.81431217
30.0973680.9299460.1760870.4935873.61023311
40.5747190.0927240.5301830.5930041.33270917
" + "shape: (5, 7)
idvar1var2var3rrhnb_l_inf_cnt
u32f64f64f64f64f64u32
00.6908010.1701210.4412140.9414386.09243312
10.8122470.0952820.3050640.6983634.98288613
20.8872740.7654510.3958540.4973460.9946819
30.312950.121560.3132210.8504280.94233617
40.7923790.0792880.8098530.4039861.95290510
" ], "text/plain": [ "shape: (5, 7)\n", @@ -1942,11 +1942,11 @@ "│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", "│ u32 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ u32 │\n", "╞═════╪══════════╪══════════╪══════════╪══════════╪══════════╪══════════════╡\n", - "│ 0 ┆ 0.973253 ┆ 0.784598 ┆ 0.032257 ┆ 0.413564 ┆ 5.882922 ┆ 8 │\n", - "│ 1 ┆ 0.964942 ┆ 0.934959 ┆ 0.125431 ┆ 0.566296 ┆ 1.299944 ┆ 10 │\n", - "│ 2 ┆ 0.767643 ┆ 0.811002 ┆ 0.110048 ┆ 0.112519 ┆ 3.814312 ┆ 17 │\n", - "│ 3 ┆ 0.097368 ┆ 0.929946 ┆ 0.176087 ┆ 0.493587 ┆ 3.610233 ┆ 11 │\n", - "│ 4 ┆ 0.574719 ┆ 0.092724 ┆ 0.530183 ┆ 0.593004 ┆ 1.332709 ┆ 17 │\n", + "│ 0 ┆ 0.690801 ┆ 0.170121 ┆ 0.441214 ┆ 0.941438 ┆ 6.092433 ┆ 12 │\n", + "│ 1 ┆ 0.812247 ┆ 0.095282 ┆ 0.305064 ┆ 0.698363 ┆ 4.982886 ┆ 13 │\n", + "│ 2 ┆ 0.887274 ┆ 0.765451 ┆ 0.395854 ┆ 0.497346 ┆ 0.99468 ┆ 19 │\n", + "│ 3 ┆ 0.31295 ┆ 0.12156 ┆ 0.313221 ┆ 0.850428 ┆ 0.942336 ┆ 17 │\n", + "│ 4 ┆ 0.792379 ┆ 0.079288 ┆ 0.809853 ┆ 0.403986 ┆ 1.952905 ┆ 10 │\n", "└─────┴──────────┴──────────┴──────────┴──────────┴──────────┴──────────────┘" ] }, @@ -1983,7 +1983,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 7)
idvar1var2var3rrhnb_l1_r_cnt
u32f64f64f64f64f64u32
00.9732530.7845980.0322570.4135645.88292256
10.9649420.9349590.1254310.5662961.299944133
20.7676430.8110020.1100480.1125193.8143125
30.0973680.9299460.1760870.4935873.610233146
40.5747190.0927240.5301830.5930041.332709367
" + "shape: (5, 7)
idvar1var2var3rrhnb_l1_r_cnt
u32f64f64f64f64f64u32
00.6908010.1701210.4412140.9414386.0924331142
10.8122470.0952820.3050640.6983634.982886437
20.8872740.7654510.3958540.4973460.99468226
30.312950.121560.3132210.8504280.942336810
40.7923790.0792880.8098530.4039861.952905112
" ], "text/plain": [ "shape: (5, 7)\n", @@ -1992,11 +1992,11 @@ "│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", "│ u32 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ u32 │\n", "╞═════╪══════════╪══════════╪══════════╪══════════╪══════════╪═════════════╡\n", - "│ 0 ┆ 0.973253 ┆ 0.784598 ┆ 0.032257 ┆ 0.413564 ┆ 5.882922 ┆ 56 │\n", - "│ 1 ┆ 0.964942 ┆ 0.934959 ┆ 0.125431 ┆ 0.566296 ┆ 1.299944 ┆ 133 │\n", - "│ 2 ┆ 0.767643 ┆ 0.811002 ┆ 0.110048 ┆ 0.112519 ┆ 3.814312 ┆ 5 │\n", - "│ 3 ┆ 0.097368 ┆ 0.929946 ┆ 0.176087 ┆ 0.493587 ┆ 3.610233 ┆ 146 │\n", - "│ 4 ┆ 0.574719 ┆ 0.092724 ┆ 0.530183 ┆ 0.593004 ┆ 1.332709 ┆ 367 │\n", + "│ 0 ┆ 0.690801 ┆ 0.170121 ┆ 0.441214 ┆ 0.941438 ┆ 6.092433 ┆ 1142 │\n", + "│ 1 ┆ 0.812247 ┆ 0.095282 ┆ 0.305064 ┆ 0.698363 ┆ 4.982886 ┆ 437 │\n", + "│ 2 ┆ 0.887274 ┆ 0.765451 ┆ 0.395854 ┆ 0.497346 ┆ 0.99468 ┆ 226 │\n", + "│ 3 ┆ 0.31295 ┆ 0.12156 ┆ 0.313221 ┆ 0.850428 ┆ 0.942336 ┆ 810 │\n", + "│ 4 ┆ 0.792379 ┆ 0.079288 ┆ 0.809853 ┆ 0.403986 ┆ 1.952905 ┆ 112 │\n", "└─────┴──────────┴──────────┴──────────┴──────────┴──────────┴─────────────┘" ] }, @@ -2032,7 +2032,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 7)
idvar1var2var3rrhbest friends
u32f64f64f64f64f64list[u32]
00.9732530.7845980.0322570.4135645.882922[0, 1569, … 709]
10.9649420.9349590.1254310.5662961.299944[1, 339, … 1014]
20.7676430.8110020.1100480.1125193.814312[2, 1106, … 1819]
30.0973680.9299460.1760870.4935873.610233[3, 53, … 213]
40.5747190.0927240.5301830.5930041.332709[4, 1174, … 642]
" + "shape: (5, 7)
idvar1var2var3rrhbest friends
u32f64f64f64f64f64list[u32]
00.6908010.1701210.4412140.9414386.092433[0, 1302, … 1845]
10.8122470.0952820.3050640.6983634.982886[1, 341, … 1938]
20.8872740.7654510.3958540.4973460.99468[2, 1331, … 259]
30.312950.121560.3132210.8504280.942336[3, 1198, … 1277]
40.7923790.0792880.8098530.4039861.952905[4, 957, … 553]
" ], "text/plain": [ "shape: (5, 7)\n", @@ -2041,11 +2041,11 @@ "│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", "│ u32 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ list[u32] │\n", "╞═════╪══════════╪══════════╪══════════╪══════════╪══════════╪═══════════════════╡\n", - "│ 0 ┆ 0.973253 ┆ 0.784598 ┆ 0.032257 ┆ 0.413564 ┆ 5.882922 ┆ [0, 1569, … 709] │\n", - "│ 1 ┆ 0.964942 ┆ 0.934959 ┆ 0.125431 ┆ 0.566296 ┆ 1.299944 ┆ [1, 339, … 1014] │\n", - "│ 2 ┆ 0.767643 ┆ 0.811002 ┆ 0.110048 ┆ 0.112519 ┆ 3.814312 ┆ [2, 1106, … 1819] │\n", - "│ 3 ┆ 0.097368 ┆ 0.929946 ┆ 0.176087 ┆ 0.493587 ┆ 3.610233 ┆ [3, 53, … 213] │\n", - "│ 4 ┆ 0.574719 ┆ 0.092724 ┆ 0.530183 ┆ 0.593004 ┆ 1.332709 ┆ [4, 1174, … 642] │\n", + "│ 0 ┆ 0.690801 ┆ 0.170121 ┆ 0.441214 ┆ 0.941438 ┆ 6.092433 ┆ [0, 1302, … 1845] │\n", + "│ 1 ┆ 0.812247 ┆ 0.095282 ┆ 0.305064 ┆ 0.698363 ┆ 4.982886 ┆ [1, 341, … 1938] │\n", + "│ 2 ┆ 0.887274 ┆ 0.765451 ┆ 0.395854 ┆ 0.497346 ┆ 0.99468 ┆ [2, 1331, … 259] │\n", + "│ 3 ┆ 0.31295 ┆ 0.12156 ┆ 0.313221 ┆ 0.850428 ┆ 0.942336 ┆ [3, 1198, … 1277] │\n", + "│ 4 ┆ 0.792379 ┆ 0.079288 ┆ 0.809853 ┆ 0.403986 ┆ 1.952905 ┆ [4, 957, … 553] │\n", "└─────┴──────────┴──────────┴──────────┴──────────┴──────────┴───────────────────┘" ] }, @@ -2084,11 +2084,11 @@ "│ --- ┆ --- ┆ --- │\n", "│ u32 ┆ list[u32] ┆ u32 │\n", "╞═════╪═══════════════════╪════════════════════╡\n", - "│ 0 ┆ [0, 1569, … 709] ┆ 3 │\n", - "│ 1 ┆ [1, 339, … 709] ┆ 7 │\n", - "│ 2 ┆ [2, 1106, … 902] ┆ 8 │\n", - "│ 3 ┆ [3, 53, … 1316] ┆ 6 │\n", - "│ 4 ┆ [4, 1174, … 1498] ┆ 9 │\n", + "│ 0 ┆ [0, 1302, … 1655] ┆ 5 │\n", + "│ 1 ┆ [1, 341, … 700] ┆ 8 │\n", + "│ 2 ┆ [2, 1331, … 946] ┆ 5 │\n", + "│ 3 ┆ [3, 1198, … 1047] ┆ 7 │\n", + "│ 4 ┆ [4, 957, … 246] ┆ 4 │\n", "└─────┴───────────────────┴────────────────────┘\n" ] } @@ -2129,7 +2129,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 8)
idvar1var2var3rrhidxdist
u32f64f64f64f64f64list[u32]list[f64]
00.9732530.7845980.0322570.4135645.882922[0, 1569, … 709][0.0, 0.064215, … 0.096617]
10.9649420.9349590.1254310.5662961.299944[1, 339, … 1014][0.0, 0.036057, … 0.063321]
20.7676430.8110020.1100480.1125193.814312[2, 1106, … 1819][0.0, 0.029892, … 0.062706]
30.0973680.9299460.1760870.4935873.610233[3, 53, … 213][0.0, 0.021567, … 0.070992]
40.5747190.0927240.5301830.5930041.332709[4, 1174, … 642][0.0, 0.05006, … 0.08659]
" + "shape: (5, 8)
idvar1var2var3rrhidxdist
u32f64f64f64f64f64list[u32]list[f64]
00.6908010.1701210.4412140.9414386.092433[0, 1302, … 1845][0.0, 0.0724, … 0.097051]
10.8122470.0952820.3050640.6983634.982886[1, 341, … 1938][0.0, 0.040774, … 0.063491]
20.8872740.7654510.3958540.4973460.99468[2, 1331, … 259][0.0, 0.053612, … 0.076934]
30.312950.121560.3132210.8504280.942336[3, 1198, … 1277][0.0, 0.04743, … 0.065846]
40.7923790.0792880.8098530.4039861.952905[4, 957, … 553][0.0, 0.062671, … 0.074585]
" ], "text/plain": [ "shape: (5, 8)\n", @@ -2138,16 +2138,16 @@ "│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", "│ u32 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ list[u32] ┆ list[f64] │\n", "╞═════╪══════════╪══════════╪══════════╪══════════╪══════════╪══════════════════╪══════════════════╡\n", - "│ 0 ┆ 0.973253 ┆ 0.784598 ┆ 0.032257 ┆ 0.413564 ┆ 5.882922 ┆ [0, 1569, … 709] ┆ [0.0, 0.064215, │\n", - "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ … 0.096617] │\n", - "│ 1 ┆ 0.964942 ┆ 0.934959 ┆ 0.125431 ┆ 0.566296 ┆ 1.299944 ┆ [1, 339, … 1014] ┆ [0.0, 0.036057, │\n", - "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ … 0.063321] │\n", - "│ 2 ┆ 0.767643 ┆ 0.811002 ┆ 0.110048 ┆ 0.112519 ┆ 3.814312 ┆ [2, 1106, … ┆ [0.0, 0.029892, │\n", - "│ ┆ ┆ ┆ ┆ ┆ ┆ 1819] ┆ … 0.062706] │\n", - "│ 3 ┆ 0.097368 ┆ 0.929946 ┆ 0.176087 ┆ 0.493587 ┆ 3.610233 ┆ [3, 53, … 213] ┆ [0.0, 0.021567, │\n", - "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ … 0.070992] │\n", - "│ 4 ┆ 0.574719 ┆ 0.092724 ┆ 0.530183 ┆ 0.593004 ┆ 1.332709 ┆ [4, 1174, … 642] ┆ [0.0, 0.05006, … │\n", - "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 0.08659] │\n", + "│ 0 ┆ 0.690801 ┆ 0.170121 ┆ 0.441214 ┆ 0.941438 ┆ 6.092433 ┆ [0, 1302, … ┆ [0.0, 0.0724, … │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ 1845] ┆ 0.097051] │\n", + "│ 1 ┆ 0.812247 ┆ 0.095282 ┆ 0.305064 ┆ 0.698363 ┆ 4.982886 ┆ [1, 341, … 1938] ┆ [0.0, 0.040774, │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ … 0.063491] │\n", + "│ 2 ┆ 0.887274 ┆ 0.765451 ┆ 0.395854 ┆ 0.497346 ┆ 0.99468 ┆ [2, 1331, … 259] ┆ [0.0, 0.053612, │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ … 0.076934] │\n", + "│ 3 ┆ 0.31295 ┆ 0.12156 ┆ 0.313221 ┆ 0.850428 ┆ 0.942336 ┆ [3, 1198, … ┆ [0.0, 0.04743, … │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ 1277] ┆ 0.065846] │\n", + "│ 4 ┆ 0.792379 ┆ 0.079288 ┆ 0.809853 ┆ 0.403986 ┆ 1.952905 ┆ [4, 957, … 553] ┆ [0.0, 0.062671, │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ … 0.074585] │\n", "└─────┴──────────┴──────────┴──────────┴──────────┴──────────┴──────────────────┴──────────────────┘" ] }, @@ -2187,7 +2187,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 6)
idvar1var2var3rrh
u32f64f64f64f64f64
40.5747190.0927240.5301830.5930041.332709
50.2844240.2710620.5535310.4179664.595913
60.7650660.5592270.7937380.9708977.312772
70.2153920.6012530.422710.8575962.016605
100.4873380.2341030.4574890.4828226.715263
" + "shape: (5, 6)
idvar1var2var3rrh
u32f64f64f64f64f64
00.6908010.1701210.4412140.9414386.092433
50.4889170.5695470.7191970.2225728.093664
60.7849010.5998070.6625870.0835769.275924
70.6557480.455770.5524080.7117638.581199
90.1898030.3366740.5211460.31859.84388
" ], "text/plain": [ "shape: (5, 6)\n", @@ -2196,11 +2196,11 @@ "│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", "│ u32 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 │\n", "╞═════╪══════════╪══════════╪══════════╪══════════╪══════════╡\n", - "│ 4 ┆ 0.574719 ┆ 0.092724 ┆ 0.530183 ┆ 0.593004 ┆ 1.332709 │\n", - "│ 5 ┆ 0.284424 ┆ 0.271062 ┆ 0.553531 ┆ 0.417966 ┆ 4.595913 │\n", - "│ 6 ┆ 0.765066 ┆ 0.559227 ┆ 0.793738 ┆ 0.970897 ┆ 7.312772 │\n", - "│ 7 ┆ 0.215392 ┆ 0.601253 ┆ 0.42271 ┆ 0.857596 ┆ 2.016605 │\n", - "│ 10 ┆ 0.487338 ┆ 0.234103 ┆ 0.457489 ┆ 0.482822 ┆ 6.715263 │\n", + "│ 0 ┆ 0.690801 ┆ 0.170121 ┆ 0.441214 ┆ 0.941438 ┆ 6.092433 │\n", + "│ 5 ┆ 0.488917 ┆ 0.569547 ┆ 0.719197 ┆ 0.222572 ┆ 8.093664 │\n", + "│ 6 ┆ 0.784901 ┆ 0.599807 ┆ 0.662587 ┆ 0.083576 ┆ 9.275924 │\n", + "│ 7 ┆ 0.655748 ┆ 0.45577 ┆ 0.552408 ┆ 0.711763 ┆ 8.581199 │\n", + "│ 9 ┆ 0.189803 ┆ 0.336674 ┆ 0.521146 ┆ 0.3185 ┆ 9.84388 │\n", "└─────┴──────────┴──────────┴──────────┴──────────┴──────────┘" ] }, @@ -2237,7 +2237,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 6)
idvar1var2var3rrh
u32f64f64f64f64f64
130.5543960.4882170.4709830.6787257.485183
520.5594810.563120.3215180.9442068.376489
740.5508880.4963130.6743240.6961425.191851
920.4667050.5282520.2550160.8977568.242887
930.5556640.4891690.3488290.4248420.734708
" + "shape: (5, 6)
idvar1var2var3rrh
u32f64f64f64f64f64
50.4889170.5695470.7191970.2225728.093664
250.5523140.5371430.2283360.3224787.359873
480.4582310.5389320.2512320.4248173.99862
540.5076410.5163750.4733930.7584250.310381
700.5167250.5470320.1764680.000182.904882
" ], "text/plain": [ "shape: (5, 6)\n", @@ -2246,11 +2246,11 @@ "│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", "│ u32 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 │\n", "╞═════╪══════════╪══════════╪══════════╪══════════╪══════════╡\n", - "│ 13 ┆ 0.554396 ┆ 0.488217 ┆ 0.470983 ┆ 0.678725 ┆ 7.485183 │\n", - "│ 52 ┆ 0.559481 ┆ 0.56312 ┆ 0.321518 ┆ 0.944206 ┆ 8.376489 │\n", - "│ 74 ┆ 0.550888 ┆ 0.496313 ┆ 0.674324 ┆ 0.696142 ┆ 5.191851 │\n", - "│ 92 ┆ 0.466705 ┆ 0.528252 ┆ 0.255016 ┆ 0.897756 ┆ 8.242887 │\n", - "│ 93 ┆ 0.555664 ┆ 0.489169 ┆ 0.348829 ┆ 0.424842 ┆ 0.734708 │\n", + "│ 5 ┆ 0.488917 ┆ 0.569547 ┆ 0.719197 ┆ 0.222572 ┆ 8.093664 │\n", + "│ 25 ┆ 0.552314 ┆ 0.537143 ┆ 0.228336 ┆ 0.322478 ┆ 7.359873 │\n", + "│ 48 ┆ 0.458231 ┆ 0.538932 ┆ 0.251232 ┆ 0.424817 ┆ 3.99862 │\n", + "│ 54 ┆ 0.507641 ┆ 0.516375 ┆ 0.473393 ┆ 0.758425 ┆ 0.310381 │\n", + "│ 70 ┆ 0.516725 ┆ 0.547032 ┆ 0.176468 ┆ 0.00018 ┆ 2.904882 │\n", "└─────┴──────────┴──────────┴──────────┴──────────┴──────────┘" ] }, @@ -2287,7 +2287,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 6)
idvar1var2var3rrh
u32f64f64f64f64f64
130.5543960.4882170.4709830.6787257.485183
920.4667050.5282520.2550160.8977568.242887
1600.4633810.4275640.1766330.6043859.981861
2980.4402540.5355040.7179830.0593259.33173
4960.5255960.4491780.2195020.5683558.999516
" + "shape: (5, 6)
idvar1var2var3rrh
u32f64f64f64f64f64
50.4889170.5695470.7191970.2225728.093664
250.5523140.5371430.2283360.3224787.359873
2240.4963320.5128510.0165620.9105879.142325
2730.4727580.4200410.2038520.266539.412762
5400.4844010.5066270.866040.8902586.919762
" ], "text/plain": [ "shape: (5, 6)\n", @@ -2296,11 +2296,11 @@ "│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", "│ u32 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 │\n", "╞═════╪══════════╪══════════╪══════════╪══════════╪══════════╡\n", - "│ 13 ┆ 0.554396 ┆ 0.488217 ┆ 0.470983 ┆ 0.678725 ┆ 7.485183 │\n", - "│ 92 ┆ 0.466705 ┆ 0.528252 ┆ 0.255016 ┆ 0.897756 ┆ 8.242887 │\n", - "│ 160 ┆ 0.463381 ┆ 0.427564 ┆ 0.176633 ┆ 0.604385 ┆ 9.981861 │\n", - "│ 298 ┆ 0.440254 ┆ 0.535504 ┆ 0.717983 ┆ 0.059325 ┆ 9.33173 │\n", - "│ 496 ┆ 0.525596 ┆ 0.449178 ┆ 0.219502 ┆ 0.568355 ┆ 8.999516 │\n", + "│ 5 ┆ 0.488917 ┆ 0.569547 ┆ 0.719197 ┆ 0.222572 ┆ 8.093664 │\n", + "│ 25 ┆ 0.552314 ┆ 0.537143 ┆ 0.228336 ┆ 0.322478 ┆ 7.359873 │\n", + "│ 224 ┆ 0.496332 ┆ 0.512851 ┆ 0.016562 ┆ 0.910587 ┆ 9.142325 │\n", + "│ 273 ┆ 0.472758 ┆ 0.420041 ┆ 0.203852 ┆ 0.26653 ┆ 9.412762 │\n", + "│ 540 ┆ 0.484401 ┆ 0.506627 ┆ 0.86604 ┆ 0.890258 ┆ 6.919762 │\n", "└─────┴──────────┴──────────┴──────────┴──────────┴──────────┘" ] }, @@ -2337,21 +2337,21 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 3)
idfriendscount
u64list[u32]u32
0[0]1
1[1, 1533, 339]3
2[2, 788]2
3[3, 53, … 1592]4
4[4, 533]2
" + "shape: (5, 3)
idfriendscount
u64list[u32]u32
0[0, 1728]2
1[1, 827]2
2[2, 679]2
3[3, 1618, … 137]4
4[4, 341, … 871]4
" ], "text/plain": [ "shape: (5, 3)\n", - "┌─────┬─────────────────┬───────┐\n", - "│ id ┆ friends ┆ count │\n", - "│ --- ┆ --- ┆ --- │\n", - "│ u64 ┆ list[u32] ┆ u32 │\n", - "╞═════╪═════════════════╪═══════╡\n", - "│ 0 ┆ [0] ┆ 1 │\n", - "│ 1 ┆ [1, 1533, 339] ┆ 3 │\n", - "│ 2 ┆ [2, 788] ┆ 2 │\n", - "│ 3 ┆ [3, 53, … 1592] ┆ 4 │\n", - "│ 4 ┆ [4, 533] ┆ 2 │\n", - "└─────┴─────────────────┴───────┘" + "┌─────┬──────────────────┬───────┐\n", + "│ id ┆ friends ┆ count │\n", + "│ --- ┆ --- ┆ --- │\n", + "│ u64 ┆ list[u32] ┆ u32 │\n", + "╞═════╪══════════════════╪═══════╡\n", + "│ 0 ┆ [0, 1728] ┆ 2 │\n", + "│ 1 ┆ [1, 827] ┆ 2 │\n", + "│ 2 ┆ [2, 679] ┆ 2 │\n", + "│ 3 ┆ [3, 1618, … 137] ┆ 4 │\n", + "│ 4 ┆ [4, 341, … 871] ┆ 4 │\n", + "└─────┴──────────────────┴───────┘" ] }, "execution_count": 51, @@ -2625,7 +2625,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 4)
actualpredicted0-20-9
f64f64i32i32
0.00.07202705
0.00.59254517
1.00.23218317
1.00.71553301
1.00.28955527
" + "shape: (5, 4)
actualpredicted0-20-9
f64f64i32i32
1.00.18475400
0.00.33277623
1.00.5461922
1.00.4184908
1.00.50400910
" ], "text/plain": [ "shape: (5, 4)\n", @@ -2634,11 +2634,11 @@ "│ --- ┆ --- ┆ --- ┆ --- │\n", "│ f64 ┆ f64 ┆ i32 ┆ i32 │\n", "╞════════╪═══════════╪═════╪═════╡\n", - "│ 0.0 ┆ 0.072027 ┆ 0 ┆ 5 │\n", - "│ 0.0 ┆ 0.592545 ┆ 1 ┆ 7 │\n", - "│ 1.0 ┆ 0.232183 ┆ 1 ┆ 7 │\n", - "│ 1.0 ┆ 0.715533 ┆ 0 ┆ 1 │\n", - "│ 1.0 ┆ 0.289555 ┆ 2 ┆ 7 │\n", + "│ 1.0 ┆ 0.184754 ┆ 0 ┆ 0 │\n", + "│ 0.0 ┆ 0.332776 ┆ 2 ┆ 3 │\n", + "│ 1.0 ┆ 0.54619 ┆ 2 ┆ 2 │\n", + "│ 1.0 ┆ 0.41849 ┆ 0 ┆ 8 │\n", + "│ 1.0 ┆ 0.504009 ┆ 1 ┆ 0 │\n", "└────────┴───────────┴─────┴─────┘" ] }, @@ -2714,17 +2714,17 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (1, 5)
precisionrecallfaverage_precisionroc_auc
f64f64f64f64f64
0.499790.500120.4999550.500360.502358
" + "shape: (1, 5)
precisionrecallfaverage_precisionroc_auc
f64f64f64f64f64
0.4984740.4982930.4983840.4981710.500488
" ], "text/plain": [ "shape: (1, 5)\n", - "┌───────────┬─────────┬──────────┬───────────────────┬──────────┐\n", - "│ precision ┆ recall ┆ f ┆ average_precision ┆ roc_auc │\n", - "│ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", - "│ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 │\n", - "╞═══════════╪═════════╪══════════╪═══════════════════╪══════════╡\n", - "│ 0.49979 ┆ 0.50012 ┆ 0.499955 ┆ 0.50036 ┆ 0.502358 │\n", - "└───────────┴─────────┴──────────┴───────────────────┴──────────┘" + "┌───────────┬──────────┬──────────┬───────────────────┬──────────┐\n", + "│ precision ┆ recall ┆ f ┆ average_precision ┆ roc_auc │\n", + "│ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n", + "│ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 │\n", + "╞═══════════╪══════════╪══════════╪═══════════════════╪══════════╡\n", + "│ 0.498474 ┆ 0.498293 ┆ 0.498384 ┆ 0.498171 ┆ 0.500488 │\n", + "└───────────┴──────────┴──────────┴───────────────────┴──────────┘" ] }, "execution_count": 59, @@ -2755,7 +2755,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (5, 4)
cnt<=baseline_pctactual_pctpsi_bin
f64f64f64f64
0.169530.20.1620.008007
0.3590350.20.1970.000045
0.5728230.20.2240.00272
0.7995980.20.230.004193
inf0.20.1870.000874
" + "shape: (5, 4)
cnt<=baseline_pctactual_pctpsi_bin
f64f64f64f64
0.1839270.20.1820.001698
0.3838860.20.230.004193
0.5832030.20.1850.001169
0.7882150.20.1950.000127
inf0.20.2080.000314
" ], "text/plain": [ "shape: (5, 4)\n", @@ -2764,11 +2764,11 @@ "│ --- ┆ --- ┆ --- ┆ --- │\n", "│ f64 ┆ f64 ┆ f64 ┆ f64 │\n", "╞══════════╪══════════════╪════════════╪══════════╡\n", - "│ 0.16953 ┆ 0.2 ┆ 0.162 ┆ 0.008007 │\n", - "│ 0.359035 ┆ 0.2 ┆ 0.197 ┆ 0.000045 │\n", - "│ 0.572823 ┆ 0.2 ┆ 0.224 ┆ 0.00272 │\n", - "│ 0.799598 ┆ 0.2 ┆ 0.23 ┆ 0.004193 │\n", - "│ inf ┆ 0.2 ┆ 0.187 ┆ 0.000874 │\n", + "│ 0.183927 ┆ 0.2 ┆ 0.182 ┆ 0.001698 │\n", + "│ 0.383886 ┆ 0.2 ┆ 0.23 ┆ 0.004193 │\n", + "│ 0.583203 ┆ 0.2 ┆ 0.185 ┆ 0.001169 │\n", + "│ 0.788215 ┆ 0.2 ┆ 0.195 ┆ 0.000127 │\n", + "│ inf ┆ 0.2 ┆ 0.208 ┆ 0.000314 │\n", "└──────────┴──────────────┴────────────┴──────────┘" ] }, @@ -2804,17 +2804,17 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (1, 1)
cid_ce
f64
13.02354
" + "shape: (1, 1)
cid_ce
f64
12.759065
" ], "text/plain": [ "shape: (1, 1)\n", - "┌──────────┐\n", - "│ cid_ce │\n", - "│ --- │\n", - "│ f64 │\n", - "╞══════════╡\n", - "│ 13.02354 │\n", - "└──────────┘" + "┌───────────┐\n", + "│ cid_ce │\n", + "│ --- │\n", + "│ f64 │\n", + "╞═══════════╡\n", + "│ 12.759065 │\n", + "└───────────┘" ] }, "execution_count": 61, @@ -2845,7 +2845,7 @@ " white-space: pre-wrap;\n", "}\n", "\n", - "shape: (1, 1)
c3_stats
f64
0.137019
" + "shape: (1, 1)
c3_stats
f64
0.123318
" ], "text/plain": [ "shape: (1, 1)\n", @@ -2854,7 +2854,7 @@ "│ --- │\n", "│ f64 │\n", "╞══════════╡\n", - "│ 0.137019 │\n", + "│ 0.123318 │\n", "└──────────┘" ] }, @@ -2871,11 +2871,515 @@ ")" ] }, + { + "cell_type": "markdown", + "id": "bc7ee650", + "metadata": {}, + "source": [ + "# Examples using Ball Tree Features" + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 65, "id": "cc89e553", "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (5, 13)
ftime_idxdummyactualpredicteddummy_groupsx1x2x3abyindex
f64i64stri32f64strf64f64f64f64f64f64u32
0.00"a"00.491635"a"0.7116890.6244970.036810.516180.7149250.2388940
0.8414711"a"10.01081"a"0.9903790.2739850.6918630.4489780.258321-0.8069791
0.9092972"a"10.016883"a"0.0798270.3496220.0125410.3488810.607950.0981342
0.141123"a"10.738627"a"0.9838780.3424830.9150680.7223410.762508-1.1222523
-0.7568024"a"00.596465"a"0.7607990.3092110.6870140.6561740.942946-0.823564
" + ], + "text/plain": [ + "shape: (5, 13)\n", + "┌───────────┬──────────┬───────┬────────┬───┬──────────┬──────────┬───────────┬───────┐\n", + "│ f ┆ time_idx ┆ dummy ┆ actual ┆ … ┆ a ┆ b ┆ y ┆ index │\n", + "│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │\n", + "│ f64 ┆ i64 ┆ str ┆ i32 ┆ ┆ f64 ┆ f64 ┆ f64 ┆ u32 │\n", + "╞═══════════╪══════════╪═══════╪════════╪═══╪══════════╪══════════╪═══════════╪═══════╡\n", + "│ 0.0 ┆ 0 ┆ a ┆ 0 ┆ … ┆ 0.51618 ┆ 0.714925 ┆ 0.238894 ┆ 0 │\n", + "│ 0.841471 ┆ 1 ┆ a ┆ 1 ┆ … ┆ 0.448978 ┆ 0.258321 ┆ -0.806979 ┆ 1 │\n", + "│ 0.909297 ┆ 2 ┆ a ┆ 1 ┆ … ┆ 0.348881 ┆ 0.60795 ┆ 0.098134 ┆ 2 │\n", + "│ 0.14112 ┆ 3 ┆ a ┆ 1 ┆ … ┆ 0.722341 ┆ 0.762508 ┆ -1.122252 ┆ 3 │\n", + "│ -0.756802 ┆ 4 ┆ a ┆ 0 ┆ … ┆ 0.656174 ┆ 0.942946 ┆ -0.82356 ┆ 4 │\n", + "└───────────┴──────────┴───────┴────────┴───┴──────────┴──────────┴───────────┴───────┘" + ] + }, + "execution_count": 65, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "size = 1_000\n", + "df = pl.DataFrame({\n", + " \"f\": np.sin(list(range(size)))\n", + " , \"time_idx\": range(size)\n", + " , \"dummy\": [\"a\"] * (size // 2) + [\"b\"] * (size // 2)\n", + " , \"actual\": np.round(np.random.random(size=size)).astype(np.int32)\n", + " , \"predicted\": np.random.random(size=size)\n", + " , \"dummy_groups\":[\"a\"] * (size//2) + [\"b\"] * (size//2) \n", + "}).with_columns(\n", + " pds.random(0., 1.).alias(\"x1\")\n", + " , pds.random(0., 1.).alias(\"x2\")\n", + " , pds.random(0., 1.).alias(\"x3\")\n", + " , pds.random(0., 1.).alias(\"a\")\n", + " , pds.random(0., 1.).alias(\"b\")\n", + ").with_columns(\n", + " y = pl.col(\"x1\") * 0.15 + pl.col(\"x2\") * 0.3 - pl.col(\"x3\") * 1.5 + pds.random() * 0.0001\n", + ")\n", + "df = df.with_columns(pl.int_range(0, size).cast(dtype=pl.UInt32).alias(\"index\"))\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "id": "728a9bc1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (1_000, 14)
ftime_idxdummyactualpredicteddummy_groupsx1x2x3abyindexball_tree_knn_ptwise
f64i64stri32f64strf64f64f64f64f64f64u32list[u32]
0.00"a"00.491635"a"0.7116890.6244970.036810.516180.7149250.2388940[0, 309, … 383]
0.8414711"a"10.01081"a"0.9903790.2739850.6918630.4489780.258321-0.8069791[1, 215, … 4]
0.9092972"a"10.016883"a"0.0798270.3496220.0125410.3488810.607950.0981342[2, 837, … 213]
0.141123"a"10.738627"a"0.9838780.3424830.9150680.7223410.762508-1.1222523[3, 366, … 78]
-0.7568024"a"00.596465"a"0.7607990.3092110.6870140.6561740.942946-0.823564[4, 196, … 942]
0.773833995"b"00.81024"b"0.8085940.7986730.7761580.7992830.96809-0.80334995[995, 234, … 915]
-0.114875996"b"00.435576"b"0.1828520.9522570.6639210.685660.079458-0.682739996[996, 490, … 827]
-0.897967997"b"10.799476"b"0.1310820.3890230.654030.6698340.847043-0.844666997[997, 456, … 116]
-0.855473998"b"10.12903"b"0.7948420.7948820.2971780.1137990.990062-0.088007998[998, 769, … 321]
-0.026461999"b"10.209619"b"0.5471870.216120.0063550.1310240.0380850.137398999[999, 107, … 586]
" + ], + "text/plain": [ + "shape: (1_000, 14)\n", + "┌───────────┬──────────┬───────┬────────┬───┬──────────┬───────────┬───────┬──────────────────────┐\n", + "│ f ┆ time_idx ┆ dummy ┆ actual ┆ … ┆ b ┆ y ┆ index ┆ ball_tree_knn_ptwise │\n", + "│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │\n", + "│ f64 ┆ i64 ┆ str ┆ i32 ┆ ┆ f64 ┆ f64 ┆ u32 ┆ list[u32] │\n", + "╞═══════════╪══════════╪═══════╪════════╪═══╪══════════╪═══════════╪═══════╪══════════════════════╡\n", + "│ 0.0 ┆ 0 ┆ a ┆ 0 ┆ … ┆ 0.714925 ┆ 0.238894 ┆ 0 ┆ [0, 309, … 383] │\n", + "│ 0.841471 ┆ 1 ┆ a ┆ 1 ┆ … ┆ 0.258321 ┆ -0.806979 ┆ 1 ┆ [1, 215, … 4] │\n", + "│ 0.909297 ┆ 2 ┆ a ┆ 1 ┆ … ┆ 0.60795 ┆ 0.098134 ┆ 2 ┆ [2, 837, … 213] │\n", + "│ 0.14112 ┆ 3 ┆ a ┆ 1 ┆ … ┆ 0.762508 ┆ -1.122252 ┆ 3 ┆ [3, 366, … 78] │\n", + "│ -0.756802 ┆ 4 ┆ a ┆ 0 ┆ … ┆ 0.942946 ┆ -0.82356 ┆ 4 ┆ [4, 196, … 942] │\n", + "│ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │\n", + "│ 0.773833 ┆ 995 ┆ b ┆ 0 ┆ … ┆ 0.96809 ┆ -0.80334 ┆ 995 ┆ [995, 234, … 915] │\n", + "│ -0.114875 ┆ 996 ┆ b ┆ 0 ┆ … ┆ 0.079458 ┆ -0.682739 ┆ 996 ┆ [996, 490, … 827] │\n", + "│ -0.897967 ┆ 997 ┆ b ┆ 1 ┆ … ┆ 0.847043 ┆ -0.844666 ┆ 997 ┆ [997, 456, … 116] │\n", + "│ -0.855473 ┆ 998 ┆ b ┆ 1 ┆ … ┆ 0.990062 ┆ -0.088007 ┆ 998 ┆ [998, 769, … 321] │\n", + "│ -0.026461 ┆ 999 ┆ b ┆ 1 ┆ … ┆ 0.038085 ┆ 0.137398 ┆ 999 ┆ [999, 107, … 586] │\n", + "└───────────┴──────────┴───────┴────────┴───┴──────────┴───────────┴───────┴──────────────────────┘" + ] + }, + "execution_count": 66, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Pointwise Nearest Neighbors\n", + "df.with_columns(\n", + " pds.query_bt_knn_ptwise(\n", + " pl.col(\"x2\"), \n", + " pl.col(\"x3\"), \n", + " index=pl.col(\"index\"),\n", + " r=999.0,\n", + " k =5,\n", + " distance_metric=\"haversine\",\n", + " parallel=True).alias(\"ball_tree_knn_ptwise\")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "id": "74f1e763", + "metadata": {}, + "outputs": [ + { + "data": { + 
"text/html": [ + "
\n", + "shape: (1_000, 15)
ftime_idxdummyactualpredicteddummy_groupsx1x2x3abyindexidsdistances
f64i64stri32f64strf64f64f64f64f64f64u32list[u32]list[f64]
0.00"a"00.491635"a"0.7116890.6244970.036810.516180.7149250.2388940[0, 309, … 383][0.0, 0.227001, … 4.571094]
0.8414711"a"10.01081"a"0.9903790.2739850.6918630.4489780.258321-0.8069791[1, 215, … 4][0.0, 0.472495, … 3.953833]
0.9092972"a"10.016883"a"0.0798270.3496220.0125410.3488810.607950.0981342[2, 837, … 213][0.0, 0.507418, … 2.867486]
0.141123"a"10.738627"a"0.9838780.3424830.9150680.7223410.762508-1.1222523[3, 366, … 78][0.0, 0.094053, … 3.338146]
-0.7568024"a"00.596465"a"0.7607990.3092110.6870140.6561740.942946-0.823564[4, 196, … 942][0.0, 2.545978, … 3.247336]
0.773833995"b"00.81024"b"0.8085940.7986730.7761580.7992830.96809-0.80334995[995, 234, … 915][0.0, 0.762144, … 3.949634]
-0.114875996"b"00.435576"b"0.1828520.9522570.6639210.685660.079458-0.682739996[996, 490, … 827][0.0, 0.410204, … 3.064209]
-0.897967997"b"10.799476"b"0.1310820.3890230.654030.6698340.847043-0.844666997[997, 456, … 116][0.0, 0.796457, … 4.244226]
-0.855473998"b"10.12903"b"0.7948420.7948820.2971780.1137990.990062-0.088007998[998, 769, … 321][0.0, 0.507682, … 2.901211]
-0.026461999"b"10.209619"b"0.5471870.216120.0063550.1310240.0380850.137398999[999, 107, … 586][0.0, 1.816998, … 2.902915]
" + ], + "text/plain": [ + "shape: (1_000, 15)\n", + "┌───────────┬──────────┬───────┬────────┬───┬───────────┬───────┬────────────────────┬─────────────┐\n", + "│ f ┆ time_idx ┆ dummy ┆ actual ┆ … ┆ y ┆ index ┆ ids ┆ distances │\n", + "│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │\n", + "│ f64 ┆ i64 ┆ str ┆ i32 ┆ ┆ f64 ┆ u32 ┆ list[u32] ┆ list[f64] │\n", + "╞═══════════╪══════════╪═══════╪════════╪═══╪═══════════╪═══════╪════════════════════╪═════════════╡\n", + "│ 0.0 ┆ 0 ┆ a ┆ 0 ┆ … ┆ 0.238894 ┆ 0 ┆ [0, 309, … 383] ┆ [0.0, │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 0.227001, … │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 4.571094] │\n", + "│ 0.841471 ┆ 1 ┆ a ┆ 1 ┆ … ┆ -0.806979 ┆ 1 ┆ [1, 215, … 4] ┆ [0.0, │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 0.472495, … │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 3.953833] │\n", + "│ 0.909297 ┆ 2 ┆ a ┆ 1 ┆ … ┆ 0.098134 ┆ 2 ┆ [2, 837, … 213] ┆ [0.0, │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 0.507418, … │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 2.867486] │\n", + "│ 0.14112 ┆ 3 ┆ a ┆ 1 ┆ … ┆ -1.122252 ┆ 3 ┆ [3, 366, … 78] ┆ [0.0, │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 0.094053, … │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 3.338146] │\n", + "│ -0.756802 ┆ 4 ┆ a ┆ 0 ┆ … ┆ -0.82356 ┆ 4 ┆ [4, 196, … 942] ┆ [0.0, │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 2.545978, … │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 3.247336] │\n", + "│ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │\n", + "│ 0.773833 ┆ 995 ┆ b ┆ 0 ┆ … ┆ -0.80334 ┆ 995 ┆ [995, 234, … 915] ┆ [0.0, │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 0.762144, … │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 3.949634] │\n", + "│ -0.114875 ┆ 996 ┆ b ┆ 0 ┆ … ┆ -0.682739 ┆ 996 ┆ [996, 490, … 827] ┆ [0.0, │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 0.410204, … │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 3.064209] │\n", + "│ -0.897967 ┆ 997 ┆ b ┆ 1 ┆ … ┆ -0.844666 ┆ 997 ┆ [997, 456, … 116] ┆ [0.0, │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 0.796457, … │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 4.244226] │\n", + "│ -0.855473 ┆ 998 ┆ b ┆ 1 ┆ … ┆ -0.088007 ┆ 998 ┆ [998, 769, … 321] ┆ [0.0, │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 0.507682, … │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 2.901211] │\n", + "│ -0.026461 ┆ 999 ┆ b ┆ 1 ┆ … ┆ 0.137398 ┆ 999 ┆ [999, 107, … 586] ┆ [0.0, 
│\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 1.816998, … │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ 2.902915] │\n", + "└───────────┴──────────┴───────┴────────┴───┴───────────┴───────┴────────────────────┴─────────────┘" + ] + }, + "execution_count": 68, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Pointwise Nearest Neighbors with distances\n", + "# we get a struct with an ids column and a dist column\n", + "\n", + "df.with_columns(\n", + " pds.query_bt_knn_ptwise(\n", + " pl.col(\"x2\"), \n", + " pl.col(\"x3\"), \n", + " index=pl.col(\"index\"),\n", + " r=999.0,\n", + " k =5,\n", + " distance_metric=\"haversine\",\n", + " return_dist=True,\n", + " parallel=True).alias(\"ball_tree_knn_ptwise\")\n", + ").unnest(\"ball_tree_knn_ptwise\")" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "id": "d9ce8e89", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (1_000, 14)
ftime_idxdummyactualpredicteddummy_groupsx1x2x3abyindexball_tree_knn_radius_freq_cnt
f64i64stri32f64strf64f64f64f64f64f64u32struct[2]
0.00"a"00.491635"a"0.7116890.6244970.036810.516180.7149250.2388940{871,10}
0.8414711"a"10.01081"a"0.9903790.2739850.6918630.4489780.258321-0.8069791{701,10}
0.9092972"a"10.016883"a"0.0798270.3496220.0125410.3488810.607950.0981342{167,9}
0.141123"a"10.738627"a"0.9838780.3424830.9150680.7223410.762508-1.1222523{362,9}
-0.7568024"a"00.596465"a"0.7607990.3092110.6870140.6561740.942946-0.823564{147,9}
0.773833995"b"00.81024"b"0.8085940.7986730.7761580.7992830.96809-0.80334995{492,1}
-0.114875996"b"00.435576"b"0.1828520.9522570.6639210.685660.079458-0.682739996{600,1}
-0.897967997"b"10.799476"b"0.1310820.3890230.654030.6698340.847043-0.844666997{610,1}
-0.855473998"b"10.12903"b"0.7948420.7948820.2971780.1137990.990062-0.088007998{828,1}
-0.026461999"b"10.209619"b"0.5471870.216120.0063550.1310240.0380850.137398999{933,1}
" + ], + "text/plain": [ + "shape: (1_000, 14)\n", + "┌───────────┬──────────┬───────┬────────┬───┬──────────┬───────────┬───────┬───────────────────────┐\n", + "│ f ┆ time_idx ┆ dummy ┆ actual ┆ … ┆ b ┆ y ┆ index ┆ ball_tree_knn_radius_ │\n", + "│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ freq_cnt │\n", + "│ f64 ┆ i64 ┆ str ┆ i32 ┆ ┆ f64 ┆ f64 ┆ u32 ┆ --- │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ struct[2] │\n", + "╞═══════════╪══════════╪═══════╪════════╪═══╪══════════╪═══════════╪═══════╪═══════════════════════╡\n", + "│ 0.0 ┆ 0 ┆ a ┆ 0 ┆ … ┆ 0.714925 ┆ 0.238894 ┆ 0 ┆ {871,10} │\n", + "│ 0.841471 ┆ 1 ┆ a ┆ 1 ┆ … ┆ 0.258321 ┆ -0.806979 ┆ 1 ┆ {701,10} │\n", + "│ 0.909297 ┆ 2 ┆ a ┆ 1 ┆ … ┆ 0.60795 ┆ 0.098134 ┆ 2 ┆ {167,9} │\n", + "│ 0.14112 ┆ 3 ┆ a ┆ 1 ┆ … ┆ 0.762508 ┆ -1.122252 ┆ 3 ┆ {362,9} │\n", + "│ -0.756802 ┆ 4 ┆ a ┆ 0 ┆ … ┆ 0.942946 ┆ -0.82356 ┆ 4 ┆ {147,9} │\n", + "│ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │\n", + "│ 0.773833 ┆ 995 ┆ b ┆ 0 ┆ … ┆ 0.96809 ┆ -0.80334 ┆ 995 ┆ {492,1} │\n", + "│ -0.114875 ┆ 996 ┆ b ┆ 0 ┆ … ┆ 0.079458 ┆ -0.682739 ┆ 996 ┆ {600,1} │\n", + "│ -0.897967 ┆ 997 ┆ b ┆ 1 ┆ … ┆ 0.847043 ┆ -0.844666 ┆ 997 ┆ {610,1} │\n", + "│ -0.855473 ┆ 998 ┆ b ┆ 1 ┆ … ┆ 0.990062 ┆ -0.088007 ┆ 998 ┆ {828,1} │\n", + "│ -0.026461 ┆ 999 ┆ b ┆ 1 ┆ … ┆ 0.038085 ┆ 0.137398 ┆ 999 ┆ {933,1} │\n", + "└───────────┴──────────┴───────┴────────┴───┴──────────┴───────────┴───────┴───────────────────────┘" + ] + }, + "execution_count": 69, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Frequency Count\n", + "df.with_columns(\n", + " pds.query_bt_knn_radius_freq_cnt(\n", + " pl.col(\"x2\"), \n", + " pl.col(\"x3\"), \n", + " index=pl.col(\"index\"), \n", + " r=999.0, \n", + " k =5, \n", + " distance_metric=\"haversine\", \n", + " parallel=True\n", + " ).alias(\"ball_tree_knn_radius_freq_cnt\")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "id": "ddf08cd8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (1_000, 14)
ftime_idxdummyactualpredicteddummy_groupsx1x2x3abyindexbtree_knn_avg
f64i64stri32f64strf64f64f64f64f64f64u32f64
0.00"a"00.491635"a"0.7116890.6244970.036810.516180.7149250.2388940131.048471
0.8414711"a"10.01081"a"0.9903790.2739850.6918630.4489780.258321-0.8069791428.336943
0.9092972"a"10.016883"a"0.0798270.3496220.0125410.3488810.607950.098134251.713153
0.141123"a"10.738627"a"0.9838780.3424830.9150680.7223410.762508-1.1222523280.478864
-0.7568024"a"00.596465"a"0.7607990.3092110.6870140.6561740.942946-0.823564108.016655
0.773833995"b"00.81024"b"0.8085940.7986730.7761580.7992830.96809-0.80334995837.903962
-0.114875996"b"00.435576"b"0.1828520.9522570.6639210.685660.079458-0.682739996610.02722
-0.897967997"b"10.799476"b"0.1310820.3890230.654030.6698340.847043-0.844666997723.917782
-0.855473998"b"10.12903"b"0.7948420.7948820.2971780.1137990.990062-0.088007998589.166151
-0.026461999"b"10.209619"b"0.5471870.216120.0063550.1310240.0380850.137398999914.581934
" + ], + "text/plain": [ + "shape: (1_000, 14)\n", + "┌───────────┬──────────┬───────┬────────┬───┬──────────┬───────────┬───────┬───────────────┐\n", + "│ f ┆ time_idx ┆ dummy ┆ actual ┆ … ┆ b ┆ y ┆ index ┆ btree_knn_avg │\n", + "│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │\n", + "│ f64 ┆ i64 ┆ str ┆ i32 ┆ ┆ f64 ┆ f64 ┆ u32 ┆ f64 │\n", + "╞═══════════╪══════════╪═══════╪════════╪═══╪══════════╪═══════════╪═══════╪═══════════════╡\n", + "│ 0.0 ┆ 0 ┆ a ┆ 0 ┆ … ┆ 0.714925 ┆ 0.238894 ┆ 0 ┆ 131.048471 │\n", + "│ 0.841471 ┆ 1 ┆ a ┆ 1 ┆ … ┆ 0.258321 ┆ -0.806979 ┆ 1 ┆ 428.336943 │\n", + "│ 0.909297 ┆ 2 ┆ a ┆ 1 ┆ … ┆ 0.60795 ┆ 0.098134 ┆ 2 ┆ 51.713153 │\n", + "│ 0.14112 ┆ 3 ┆ a ┆ 1 ┆ … ┆ 0.762508 ┆ -1.122252 ┆ 3 ┆ 280.478864 │\n", + "│ -0.756802 ┆ 4 ┆ a ┆ 0 ┆ … ┆ 0.942946 ┆ -0.82356 ┆ 4 ┆ 108.016655 │\n", + "│ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │\n", + "│ 0.773833 ┆ 995 ┆ b ┆ 0 ┆ … ┆ 0.96809 ┆ -0.80334 ┆ 995 ┆ 837.903962 │\n", + "│ -0.114875 ┆ 996 ┆ b ┆ 0 ┆ … ┆ 0.079458 ┆ -0.682739 ┆ 996 ┆ 610.02722 │\n", + "│ -0.897967 ┆ 997 ┆ b ┆ 1 ┆ … ┆ 0.847043 ┆ -0.844666 ┆ 997 ┆ 723.917782 │\n", + "│ -0.855473 ┆ 998 ┆ b ┆ 1 ┆ … ┆ 0.990062 ┆ -0.088007 ┆ 998 ┆ 589.166151 │\n", + "│ -0.026461 ┆ 999 ┆ b ┆ 1 ┆ … ┆ 0.038085 ┆ 0.137398 ┆ 999 ┆ 914.581934 │\n", + "└───────────┴──────────┴───────┴────────┴───┴──────────┴───────────┴───────┴───────────────┘" + ] + }, + "execution_count": 74, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Average distance of k nearest neighbors\n", + "df.with_columns(\n", + " pds.query_bt_knn_avg(\n", + " pl.col(\"x2\"), \n", + " pl.col(\"x3\"), \n", + " pl.col(\"x1\"), \n", + " index=pl.col(\"index\"), \n", + " r=999.0, \n", + " k =1,\n", + " distance_metric=\"euclidean\", \n", + " parallel=True).alias(\"btree_knn_avg\")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "id": "195210f9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (1_000, 14)
ftime_idxdummyactualpredicteddummy_groupsx1x2x3abyindexbtree_within_radius
f64i64stri32f64strf64f64f64f64f64f64u32u32
0.00"a"00.491635"a"0.7116890.6244970.036810.516180.7149250.238894040
0.8414711"a"10.01081"a"0.9903790.2739850.6918630.4489780.258321-0.806979189
0.9092972"a"10.016883"a"0.0798270.3496220.0125410.3488810.607950.098134245
0.141123"a"10.738627"a"0.9838780.3424830.9150680.7223410.762508-1.122252359
-0.7568024"a"00.596465"a"0.7607990.3092110.6870140.6561740.942946-0.82356490
0.773833995"b"00.81024"b"0.8085940.7986730.7761580.7992830.96809-0.8033499581
-0.114875996"b"00.435576"b"0.1828520.9522570.6639210.685660.079458-0.68273999663
-0.897967997"b"10.799476"b"0.1310820.3890230.654030.6698340.847043-0.84466699790
-0.855473998"b"10.12903"b"0.7948420.7948820.2971780.1137990.990062-0.088007998106
-0.026461999"b"10.209619"b"0.5471870.216120.0063550.1310240.0380850.13739899955
" + ], + "text/plain": [ + "shape: (1_000, 14)\n", + "┌───────────┬──────────┬───────┬────────┬───┬──────────┬───────────┬───────┬─────────────────────┐\n", + "│ f ┆ time_idx ┆ dummy ┆ actual ┆ … ┆ b ┆ y ┆ index ┆ btree_within_radius │\n", + "│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │\n", + "│ f64 ┆ i64 ┆ str ┆ i32 ┆ ┆ f64 ┆ f64 ┆ u32 ┆ u32 │\n", + "╞═══════════╪══════════╪═══════╪════════╪═══╪══════════╪═══════════╪═══════╪═════════════════════╡\n", + "│ 0.0 ┆ 0 ┆ a ┆ 0 ┆ … ┆ 0.714925 ┆ 0.238894 ┆ 0 ┆ 40 │\n", + "│ 0.841471 ┆ 1 ┆ a ┆ 1 ┆ … ┆ 0.258321 ┆ -0.806979 ┆ 1 ┆ 89 │\n", + "│ 0.909297 ┆ 2 ┆ a ┆ 1 ┆ … ┆ 0.60795 ┆ 0.098134 ┆ 2 ┆ 45 │\n", + "│ 0.14112 ┆ 3 ┆ a ┆ 1 ┆ … ┆ 0.762508 ┆ -1.122252 ┆ 3 ┆ 59 │\n", + "│ -0.756802 ┆ 4 ┆ a ┆ 0 ┆ … ┆ 0.942946 ┆ -0.82356 ┆ 4 ┆ 90 │\n", + "│ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │\n", + "│ 0.773833 ┆ 995 ┆ b ┆ 0 ┆ … ┆ 0.96809 ┆ -0.80334 ┆ 995 ┆ 81 │\n", + "│ -0.114875 ┆ 996 ┆ b ┆ 0 ┆ … ┆ 0.079458 ┆ -0.682739 ┆ 996 ┆ 63 │\n", + "│ -0.897967 ┆ 997 ┆ b ┆ 1 ┆ … ┆ 0.847043 ┆ -0.844666 ┆ 997 ┆ 90 │\n", + "│ -0.855473 ┆ 998 ┆ b ┆ 1 ┆ … ┆ 0.990062 ┆ -0.088007 ┆ 998 ┆ 106 │\n", + "│ -0.026461 ┆ 999 ┆ b ┆ 1 ┆ … ┆ 0.038085 ┆ 0.137398 ┆ 999 ┆ 55 │\n", + "└───────────┴──────────┴───────┴────────┴───┴──────────┴───────────┴───────┴─────────────────────┘" + ] + }, + "execution_count": 75, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Count neighbors within radius\n", + "df.with_columns(\n", + " pds.query_bt_nb_cnt(\n", + " pl.col(\"x2\"),\n", + " pl.col(\"x3\"),\n", + " index=pl.col(\"index\"),\n", + " r=18.0, \n", + " distance_metric=\"haversine\",\n", + " parallel=True).alias(\"btree_within_radius\")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "id": "026350c4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (1_000, 14)
ftime_idxdummyactualpredicteddummy_groupsx1x2x3abyindexbtree_within_dist_from
f64i64stri32f64strf64f64f64f64f64f64u32bool
0.00"a"00.491635"a"0.7116890.6244970.036810.516180.7149250.2388940true
0.8414711"a"10.01081"a"0.9903790.2739850.6918630.4489780.258321-0.8069791true
0.9092972"a"10.016883"a"0.0798270.3496220.0125410.3488810.607950.0981342true
0.141123"a"10.738627"a"0.9838780.3424830.9150680.7223410.762508-1.1222523true
-0.7568024"a"00.596465"a"0.7607990.3092110.6870140.6561740.942946-0.823564true
0.773833995"b"00.81024"b"0.8085940.7986730.7761580.7992830.96809-0.80334995true
-0.114875996"b"00.435576"b"0.1828520.9522570.6639210.685660.079458-0.682739996true
-0.897967997"b"10.799476"b"0.1310820.3890230.654030.6698340.847043-0.844666997true
-0.855473998"b"10.12903"b"0.7948420.7948820.2971780.1137990.990062-0.088007998true
-0.026461999"b"10.209619"b"0.5471870.216120.0063550.1310240.0380850.137398999true
" + ], + "text/plain": [ + "shape: (1_000, 14)\n", + "┌───────────┬──────────┬───────┬────────┬───┬──────────┬───────────┬───────┬───────────────────────┐\n", + "│ f ┆ time_idx ┆ dummy ┆ actual ┆ … ┆ b ┆ y ┆ index ┆ btree_within_dist_fro │\n", + "│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ m │\n", + "│ f64 ┆ i64 ┆ str ┆ i32 ┆ ┆ f64 ┆ f64 ┆ u32 ┆ --- │\n", + "│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ┆ bool │\n", + "╞═══════════╪══════════╪═══════╪════════╪═══╪══════════╪═══════════╪═══════╪═══════════════════════╡\n", + "│ 0.0 ┆ 0 ┆ a ┆ 0 ┆ … ┆ 0.714925 ┆ 0.238894 ┆ 0 ┆ true │\n", + "│ 0.841471 ┆ 1 ┆ a ┆ 1 ┆ … ┆ 0.258321 ┆ -0.806979 ┆ 1 ┆ true │\n", + "│ 0.909297 ┆ 2 ┆ a ┆ 1 ┆ … ┆ 0.60795 ┆ 0.098134 ┆ 2 ┆ true │\n", + "│ 0.14112 ┆ 3 ┆ a ┆ 1 ┆ … ┆ 0.762508 ┆ -1.122252 ┆ 3 ┆ true │\n", + "│ -0.756802 ┆ 4 ┆ a ┆ 0 ┆ … ┆ 0.942946 ┆ -0.82356 ┆ 4 ┆ true │\n", + "│ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │\n", + "│ 0.773833 ┆ 995 ┆ b ┆ 0 ┆ … ┆ 0.96809 ┆ -0.80334 ┆ 995 ┆ true │\n", + "│ -0.114875 ┆ 996 ┆ b ┆ 0 ┆ … ┆ 0.079458 ┆ -0.682739 ┆ 996 ┆ true │\n", + "│ -0.897967 ┆ 997 ┆ b ┆ 1 ┆ … ┆ 0.847043 ┆ -0.844666 ┆ 997 ┆ true │\n", + "│ -0.855473 ┆ 998 ┆ b ┆ 1 ┆ … ┆ 0.990062 ┆ -0.088007 ┆ 998 ┆ true │\n", + "│ -0.026461 ┆ 999 ┆ b ┆ 1 ┆ … ┆ 0.038085 ┆ 0.137398 ┆ 999 ┆ true │\n", + "└───────────┴──────────┴───────┴────────┴───┴──────────┴───────────┴───────┴───────────────────────┘" + ] + }, + "execution_count": 76, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# returns true if the row is within k nearest neighbors of the given point\n", + "df.with_columns(\n", + " pds.bt_within_dist_from(\n", + " pl.col(\"x2\"),\n", + " pl.col(\"x3\"),\n", + " pt=[123,78.99],\n", + " r=999999, \n", + " distance_metric=\"haversine\",\n", + " parallel=True).alias(\"btree_within_dist_from\")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "id": "dbe0983f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "shape: (1_000, 14)
ftime_idxdummyactualpredicteddummy_groupsx1x2x3abyindexbtree_nn_within
f64i64stri32f64strf64f64f64f64f64f64u32bool
0.00"a"00.491635"a"0.7116890.6244970.036810.516180.7149250.2388940true
0.8414711"a"10.01081"a"0.9903790.2739850.6918630.4489780.258321-0.8069791false
0.9092972"a"10.016883"a"0.0798270.3496220.0125410.3488810.607950.0981342true
0.141123"a"10.738627"a"0.9838780.3424830.9150680.7223410.762508-1.1222523false
-0.7568024"a"00.596465"a"0.7607990.3092110.6870140.6561740.942946-0.823564false
0.773833995"b"00.81024"b"0.8085940.7986730.7761580.7992830.96809-0.80334995false
-0.114875996"b"00.435576"b"0.1828520.9522570.6639210.685660.079458-0.682739996false
-0.897967997"b"10.799476"b"0.1310820.3890230.654030.6698340.847043-0.844666997false
-0.855473998"b"10.12903"b"0.7948420.7948820.2971780.1137990.990062-0.088007998true
-0.026461999"b"10.209619"b"0.5471870.216120.0063550.1310240.0380850.137398999true
" + ], + "text/plain": [ + "shape: (1_000, 14)\n", + "┌───────────┬──────────┬───────┬────────┬───┬──────────┬───────────┬───────┬─────────────────┐\n", + "│ f ┆ time_idx ┆ dummy ┆ actual ┆ … ┆ b ┆ y ┆ index ┆ btree_nn_within │\n", + "│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │\n", + "│ f64 ┆ i64 ┆ str ┆ i32 ┆ ┆ f64 ┆ f64 ┆ u32 ┆ bool │\n", + "╞═══════════╪══════════╪═══════╪════════╪═══╪══════════╪═══════════╪═══════╪═════════════════╡\n", + "│ 0.0 ┆ 0 ┆ a ┆ 0 ┆ … ┆ 0.714925 ┆ 0.238894 ┆ 0 ┆ true │\n", + "│ 0.841471 ┆ 1 ┆ a ┆ 1 ┆ … ┆ 0.258321 ┆ -0.806979 ┆ 1 ┆ false │\n", + "│ 0.909297 ┆ 2 ┆ a ┆ 1 ┆ … ┆ 0.60795 ┆ 0.098134 ┆ 2 ┆ true │\n", + "│ 0.14112 ┆ 3 ┆ a ┆ 1 ┆ … ┆ 0.762508 ┆ -1.122252 ┆ 3 ┆ false │\n", + "│ -0.756802 ┆ 4 ┆ a ┆ 0 ┆ … ┆ 0.942946 ┆ -0.82356 ┆ 4 ┆ false │\n", + "│ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │\n", + "│ 0.773833 ┆ 995 ┆ b ┆ 0 ┆ … ┆ 0.96809 ┆ -0.80334 ┆ 995 ┆ false │\n", + "│ -0.114875 ┆ 996 ┆ b ┆ 0 ┆ … ┆ 0.079458 ┆ -0.682739 ┆ 996 ┆ false │\n", + "│ -0.897967 ┆ 997 ┆ b ┆ 1 ┆ … ┆ 0.847043 ┆ -0.844666 ┆ 997 ┆ false │\n", + "│ -0.855473 ┆ 998 ┆ b ┆ 1 ┆ … ┆ 0.990062 ┆ -0.088007 ┆ 998 ┆ true │\n", + "│ -0.026461 ┆ 999 ┆ b ┆ 1 ┆ … ┆ 0.038085 ┆ 0.137398 ┆ 999 ┆ true │\n", + "└───────────┴──────────┴───────┴────────┴───┴──────────┴───────────┴───────┴─────────────────┘" + ] + }, + "execution_count": 77, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# check if a point is included in the k nearest neighbors of the given point\n", + "# Note: This is an exact match so your point could be quite close but now show up\n", + "# We can pass in an EPSILON to account for this. 
THis is defaulted to the EPSILON of the data type in Rust\n", + "\n", + "# pick a random point\n", + "point = df.select(pl.col(\"x2\"), pl.col(\"x3\"))\n", + "rr = list(point.row(2))\n", + "df.with_columns(\n", + " pds.is_bt_knn_from(\n", + " pl.col(\"x2\"),\n", + " pl.col(\"x3\"),\n", + " pt=rr, \n", + " k=56, \n", + " distance_metric=\"haversine\",\n", + " parallel=True, epsilon=0.5).alias(\"btree_nn_within\")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f7cc8f20", + "metadata": {}, "outputs": [], "source": [] } @@ -2896,7 +3400,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.8" + "version": "3.10.3" } }, "nbformat": 4, diff --git a/python/polars_ds/__init__.py b/python/polars_ds/__init__.py index d127e367..34f064ee 100644 --- a/python/polars_ds/__init__.py +++ b/python/polars_ds/__init__.py @@ -11,6 +11,7 @@ from polars_ds.features import * # noqa: F403 from polars_ds.query_knn import * # noqa: F403 from polars_ds.query_linear import * # noqa: F403 +from polars_ds.query_balltree import * # noqa: F403 __version__ = "0.6.0" diff --git a/python/polars_ds/query_balltree.py b/python/polars_ds/query_balltree.py new file mode 100644 index 00000000..8b7e7b83 --- /dev/null +++ b/python/polars_ds/query_balltree.py @@ -0,0 +1,421 @@ +from __future__ import annotations +from typing import Iterable +import polars as pl +from .type_alias import StrOrExpr, str_to_expr, Distance +from ._utils import pl_plugin + + +def query_bt_knn_ptwise( + *features: str | pl.Expr, + index: str | pl.Expr, + r: float, + distance_metric: Distance = "euclidean", + sort: bool = True, + parallel: bool = False, + k: int = None, + return_dist: bool = False, +) -> pl.Expr: + """ + Takes an index column, uses the feature columns to determine distance and finds all neighbors + within distance r from each id. It returns a list containing (neighbor_id, distance) tuples. + This is bounded by max k neighbors. 
+ index columns must be convertible to u32 + + Parameters + ---------- + *features : str | pl.Expr + The feature columns. + index : str | pl.Expr + The index column. + r : float + The radius. + distance_metric : Distance, optional + The distance metric to use, by default "euclidean". Currently the only available options are "euclidean" and "haversine". + sort : bool, optional + Whether to sort the output, by default True. + parallel : bool, optional + Whether to evaluate this in parallel, by default False. + k : int, optional + """ + if r < 0: + raise ValueError("r must be positive") + elif isinstance(r, pl.Expr): + raise ValueError( + "r must be a scalar float. Expressions are not supported.") + + if distance_metric.lower() not in ("euclidean", "haversine"): + raise ValueError( + "Invalid distance metric. Must be 'euclidean' or 'haversine'.") + if len(features) == 0: + raise ValueError("Must provide at least one feature column.") + elif len(features) != 2 and distance_metric == "haversine": + raise ValueError("Haversine distance requires exactly 2 features.") + + idx = str_to_expr(index).cast(pl.UInt32).rechunk() + + if not k: + # the length of any of the feature columns + k = pl.col(features[0]).len() + if k < 1: + raise ValueError("k must be positive.") + + # Columns to send over to Rust as &[Series] + # First column will always be the index column + cols = [idx] + cols.extend([str_to_expr(f) for f in features]) + if return_dist: + return pl_plugin( + symbol="pl_query_knn_ptwise_wdist", + args=cols, + kwargs={"r": r, "distance_metric": distance_metric.lower( + ), "sort": sort, "parallel": parallel, "k": k}, + is_elementwise=False, + ) + else: + return pl_plugin( + symbol="pl_query_knn_ptwise", + args=cols, + kwargs={"r": r, "distance_metric": distance_metric.lower( + ), "sort": sort, "parallel": parallel, "k": k}, + is_elementwise=False, + ) + + +def query_bt_knn_radius_freq_cnt( + *features: str | pl.Expr, + index: str | pl.Expr, + r: float = None, + 
distance_metric: Distance = "euclidean", + sort: bool = True, + parallel: bool = False, + k: int = None, +) -> pl.Expr: + """ + Returns the frequency count of neighbors within distance r, and within k nearest neighbors for each row using the given distance metric. + The point is always a neighbor of itself. + We need an index column that can be cast to u32. + + Parameters + ---------- + + *features : str | pl.Expr + The feature columns. + index : str | pl.Expr + The index column. + r : float, optional + The radius. If None, it will be set to infinity, by default None. + distance_metric : Distance, optional + The distance metric to use, by default "euclidean". Currently the only available options are "euclidean" and "haversine". + sort : bool, optional + Whether to sort the output, by default True. + parallel : bool, optional + Whether to evaluate this in parallel, by default False. + k : int, optional + The number of nearest neighbors to consider, by default 1. + """ + if not k: + # the length of any of the feature columns + k = pl.col(features[0]).len() + if not r: + r = float("inf") + if r < 0: + raise ValueError("r must be positive") + elif isinstance(r, pl.Expr): + raise ValueError( + "r must be a scalar float. Expressions are not supported.") + + if distance_metric.lower() not in ("euclidean", "haversine"): + raise ValueError( + "Invalid distance metric. 
Must be 'euclidean' or 'haversine'.") + if len(features) == 0: + raise ValueError("Must provide at least one feature column.") + elif len(features) != 2 and distance_metric == "haversine": + raise ValueError("Haversine distance requires exactly 2 features.") + + if k < 1: + raise ValueError("k must be positive.") + idx = str_to_expr(index).cast(pl.UInt32).rechunk() + # Columns to send over to Rust as &[Series] + # First column will always be the index column + cols = [idx] + cols.extend([str_to_expr(f) for f in features]) + knn_expr: pl.Expr = query_bt_knn_ptwise( + *features, index=index, r=r, distance_metric=distance_metric, sort=sort, parallel=parallel, k=k, return_dist=False) + return knn_expr.explode().drop_nulls().value_counts(sort=True, parallel=parallel) + + +def query_bt_knn_avg( + *features: str | pl.Expr, + index: str | pl.Expr, + r: float, + distance_metric: Distance = "euclidean", + sort: bool = True, + parallel: bool = False, + k: int = None, +) -> pl.Expr: + """Takes an index column, uses the feature columns to determine distance and finds all neighbors + within distance r from each id. It returns a list containing (neighbor_id, distance) tuples. + index columns must be convertible to u32 + + Args: + index (str | pl.Expr): _description_ + r (float): _description_ + dist (Distance, optional): _description_. Defaults to "euclidian". + sort (bool, optional): _description_. Defaults to True. + parallel (bool, optional): _description_. Defaults to False. + + Returns: + pl.Expr: _description_ + """ + if not k: + # the length of any of the feature columns + k = pl.col(features[0]).len() + if r < 0: + raise ValueError("r must be positive") + elif isinstance(r, pl.Expr): + raise ValueError( + "r must be a scalar float. Expressions are not supported.") + + if distance_metric.lower() not in ("euclidean", "haversine"): + raise ValueError( + "Invalid distance metric. 
Must be 'euclidean' or 'haversine'.") + if len(features) == 0: + raise ValueError("Must provide at least one feature column.") + elif len(features) != 2 and distance_metric == "haversine": + raise ValueError("Haversine distance requires exactly 2 features.") + + if k < 1: + raise ValueError("k must be positive.") + idx = str_to_expr(index).cast(pl.UInt32).rechunk() + # Columns to send over to Rust as &[Series] + # First column will always be the index column + cols = [idx] + cols.extend([str_to_expr(f) for f in features]) + return pl_plugin( + symbol="pl_ball_tree_knn_avg", + args=cols, + kwargs={"r": r, "distance_metric": distance_metric.lower( + ), "sort": sort, "parallel": parallel, "k": k}, + is_elementwise=False, + ) + + +def bt_within_dist_from( + *features: str | pl.Expr, + pt: Iterable[float], + r: float | str | pl.Expr, + distance_metric: Distance = "euclidean", + parallel: bool = False, +) -> pl.Expr: + """ + Returns a boolean column indicating if the provided point is within distance r. + + Parameters + ---------- + *features : str | pl.Expr + The feature columns. + pt : Iterable[float] + The point to compare against. + r : float | str | pl.Expr + The radius. Either a scalar float, or a 1d array with len = row_count(X). + distance_metric : Distance, optional + The distance metric to use, by default "euclidean". Currently the only available + options are "euclidean" and "haversine". + """ + if distance_metric.lower() not in ("euclidean", "haversine"): + raise ValueError( + "Invalid distance metric. 
Must be 'euclidean' or 'haversine'.") + if len(features) < 2: + raise ValueError("Must provide at least two feature columns.") + if len(pt) != len(features): + raise ValueError( + "Number of features must match the number of dimensions") + + if isinstance(r, (float, int)): + rad = pl.lit(pl.Series(values=[r], dtype=pl.Float64)) + elif isinstance(r, pl.Expr): + rad = r + elif isinstance(r, str): + rad = pl.col(r) + else: + rad = pl.lit(pl.Series(values=r, dtype=pl.Float64)) + cols = [rad] + cols.extend([str_to_expr(f) for f in features]) + return pl_plugin( + symbol="pl_bt_within_dist_from", + args=cols, + kwargs={"point": pt, "distance_metric": distance_metric.lower(), + "parallel": parallel}, + ) + + +def is_bt_knn_from( + *features, + pt: Iterable[float], + k: int, + distance_metric: Distance = "euclidean", + parallel: bool = False, + epsilon: float = None, +) -> pl.Expr: + """ + Returns a boolean column indicating if the provided point is within the k-nearest neighbors. + + Parameters + ---------- + *features : str | pl.Expr + The feature columns. + pt : Iterable[float] + The point to compare against. + k : int + Number of neighbors to consider. + distance_metric : Distance, optional + The distance metric to use, by default "euclidean". Currently the only available + options are "euclidean" and "haversine". + parallel : bool, optional + """ + if distance_metric.lower() not in ("euclidean", "haversine"): + raise ValueError( + "Invalid distance metric. 
Must be 'euclidean' or 'haversine'.") + if len(features) < 2: + raise ValueError("Must provide at least two feature columns.") + if len(pt) != len(features): + raise ValueError( + "Number of features must match the number of dimensions") + if k < 1: + raise ValueError("k must be positive.") + cols = [str_to_expr(f) for f in features] + return pl_plugin( + symbol="pl_bt_knn_from", + args=cols, + kwargs={"point": pt, "distance_metric": distance_metric.lower( + ), "parallel": parallel, "k": k, "epsilon": epsilon}, + ) + + +def query_bt_nb_cnt( + *features: str | pl.Expr, + r: float | str | Iterable[float], + index: str | pl.Expr, + distance_metric: Distance = "euclidean", + parallel: bool = False +) -> pl.Expr: + """ + Returns the number of neighbors within ( <= ) radius r for each row using the given distance metric. + The point is always a neighbor of itself. + + Parameters + ---------- + r : float | str | Iterable[float] + The radius. Either a scalar float, or a 1d array with len = row_count(X). + index : str | pl.Expr + The index column. + *features : str | pl.Expr + The feature columns. + distance_metric : Distance, optional + The distance metric to use, by default "euclidean". Currently the only available + options are "euclidean" and "haversine". + parallel : bool, optional + Whether to evaluate this in parallel, by default False + """ + if distance_metric.lower() not in ("euclidean", "haversine"): + raise ValueError( + "Invalid distance metric. 
Must be 'euclidean' or 'haversine'.") + if len(features) == 0: + raise ValueError("Must provide at least one feature column.") + elif len(features) != 2 and distance_metric == "haversine": + raise ValueError("Haversine distance requires exactly 2 features.") + + if isinstance(r, (float, int)): + rad = pl.lit(pl.Series(values=[r], dtype=pl.Float64)) + elif isinstance(r, pl.Expr): + rad = r + elif isinstance(r, str): + rad = pl.col(r) + else: + rad = pl.lit(pl.Series(values=r, dtype=pl.Float64)) + + idx = str_to_expr(index).cast(pl.UInt32).rechunk() + cols = [idx, rad] + cols.extend([str_to_expr(f) for f in features]) + return pl_plugin( + symbol="pl_nb_count", + args=cols, + kwargs={ + "distance_metric": distance_metric.lower(), + "parallel": parallel, + # Stubbed values so we can use the + # same kwargs struct in rust + "sort": True, + "r": 0, + "k": 1, + }, + ) + + +def query_bt_radius_ptwise( + *features: str | pl.Expr, + index: str | pl.Expr, + r: float, + distance_metric: Distance = "euclidean", + sort: bool = True, + parallel: bool = False, + k=None, + return_dist: bool = False, +) -> pl.Expr: + """ + Returns a list of neighbors within distance r from each id. + index columns must be convertible to u32 + + Parameters + ---------- + *features : str | pl.Expr + The feature columns. + index : str | pl.Expr + The index column. + r : float + The radius. + distance_metric : Distance, optional + The distance metric to use, by default "euclidean". Currently the only available options are "euclidean" and "haversine". + sort : bool, optional + Whether to sort the output, by default True. + parallel : bool, optional + Whether to evaluate this in parallel, by default False. + k : int, optional + Max number of neighbors to consider. 
+ """ + return query_bt_knn_ptwise(*features, index=index, r=r, distance_metric=distance_metric, sort=sort, parallel=parallel, k=k, return_dist=return_dist) + + +def query_bt_radius_freq_cnt( + *features: str | pl.Expr, + index: str | pl.Expr, + r: float = None, + distance_metric: Distance = "euclidean", + sort: bool = True, + parallel: bool = False, + k: int = None, +) -> pl.Expr: + """ + Returns the frequency count of neighbors within distance r, and within k nearest neighbors for each row using the given distance metric. + The point is always a neighbor of itself. + We need an index column that can be cast to u32. + + Parameters + ---------- + + *features : str | pl.Expr + The feature columns. + index : str | pl.Expr + The index column. + r : float, optional + The radius. If None, it will be set to infinity, by default None. + distance_metric : Distance, optional + The distance metric to use, by default "euclidean". Currently the only available options are "euclidean" and "haversine". + sort : bool, optional + Whether to sort the output, by default True. + parallel : bool, optional + Whether to evaluate this in parallel, by default False. + k : int, optional + The number of nearest neighbors to consider, by default 1. 
+ """ + return query_bt_knn_radius_freq_cnt(*features, index=index, r=r, distance_metric=distance_metric, sort=sort, parallel=parallel, k=k) diff --git a/src/arkadia/leaf.rs b/src/arkadia/leaf.rs index 7334e719..7d0da476 100644 --- a/src/arkadia/leaf.rs +++ b/src/arkadia/leaf.rs @@ -1,6 +1,6 @@ use num::Float; -#[derive(Clone, Copy)] +#[derive(Clone, Copy, Debug)] pub struct Leaf<'a, T: Float, A> { pub item: A, pub row_vec: &'a [T], diff --git a/src/num_ext/ball_tree.rs b/src/num_ext/ball_tree.rs new file mode 100644 index 00000000..84728901 --- /dev/null +++ b/src/num_ext/ball_tree.rs @@ -0,0 +1,1368 @@ +/// Performs KNN related search queries, classification and regression, and +/// other features/entropies that require KNN to be efficiently computed. +use crate::{ + arkadia::Leaf, + utils::{series_to_row_major_slice, split_offsets}, +}; + +use num::Float; +use polars::prelude::*; +use pyo3_polars::{ + derive::{polars_expr, CallerContext}, + export::polars_core::{ + utils::rayon::prelude::{IntoParallelIterator, ParallelIterator}, + POOL, + }, +}; +use serde::{Deserialize, Serialize}; +use std::{collections::BinaryHeap, fmt::Debug, usize}; + +#[derive(Debug, Clone, PartialEq)] +pub enum DistanceMetrics { + Haversine(fn(&[T], &[T]) -> T), + Euclidean(fn(&[T], &[T]) -> T), +} +/// Enum for valid distance metrics +/// More metrics can be added here +impl DistanceMetrics { + pub fn calculate(&self, a: &[T], b: &[T]) -> T { + match self { + DistanceMetrics::Haversine(f) => f(a, b), + DistanceMetrics::Euclidean(f) => f(a, b), + } + } +} +/// Haversine is only valid for 2D points +fn haversine_distance(a: &[T], b: &[T]) -> T { + assert!( + a.len() == 2 && b.len() == 2, + "Haversine distance requires 2D points", + ); + super::haversine_elementwise(a[0], a[1], b[0], b[1]) +} + +fn euclidean_distance(a: &[T], b: &[T]) -> T { + let sum_of_squares = a + .iter() + .zip(b) + .map(|(a, b)| { + let a_f64 = a.to_f64().unwrap(); + let b_f64 = b.to_f64().unwrap(); + (a_f64 - 
b_f64).powi(2) + }) + .sum::(); + + T::from(sum_of_squares.sqrt()).unwrap() +} + +impl<'a, T, A> Leaf<'a, T, A> +where + T: Float, +{ + /// distance between two leaves + pub fn distance(&self, other: &Leaf<'a, T, A>, metric: &DistanceMetrics) -> T { + metric.calculate(self.row_vec, other.row_vec) + } + + /// Since Leaf was used from arcadia + /// We need to use unsafe to get a mutable reference to the slice + /// This is because row_vec is a reference rather than being an owned value + pub unsafe fn move_towards( + &self, + other: &Leaf<'a, T, A>, + d: f64, + metric: &DistanceMetrics, + ) -> Self + where + T: Float, + A: Clone, + { + let distance = self.distance(other, metric).to_f64().unwrap(); + if distance == 0.0 { + return self.clone(); + } + let scale = d / distance; + + let new_self = self.clone(); + + // Use unsafe to get a mutable reference to the slice + let row_vec_ptr = new_self.row_vec.as_ptr() as *mut T; + let row_vec_len = new_self.row_vec.len(); + let row_vec_mut = std::slice::from_raw_parts_mut(row_vec_ptr, row_vec_len); + for i in 0..row_vec_len { + let a_f64 = row_vec_mut[i].to_f64().unwrap(); + let b_f64 = other.row_vec[i].to_f64().unwrap(); + row_vec_mut[i] = T::from(a_f64 + (b_f64 - a_f64) * scale).unwrap(); + } + new_self + } +} + +fn midpoint<'a, T, A>( + a: &Leaf<'a, T, A>, + b: &Leaf<'a, T, A>, + metric: &DistanceMetrics, +) -> Leaf<'a, T, A> +where + T: Float, + A: Clone, +{ + /// Midpoint between two given leaves. 
+ let d = a.distance(b, metric).to_f64().unwrap(); + unsafe { a.move_towards(b, d / 2.0, metric) } +} + +#[derive(Debug, Clone, PartialEq)] +struct Sphere { + /// A sphere is the "balls" this stores a center coordinates, a radius and a distance metric + center: C, + radius: f64, + metric: DistanceMetrics, +} + +/// OrdF64 is a wrapper around f64 that implements Ord and Eq +#[derive(Debug, Clone, PartialEq, PartialOrd)] +struct OrdF64(f64); + +impl OrdF64 { + fn new(f: f64) -> Self { + assert!(!f.is_nan(), "We can not compare NaN values"); + Self(f) + } +} + +impl Eq for OrdF64 {} +impl Ord for OrdF64 { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.partial_cmp(other).unwrap() + } +} + +impl<'a, T, A> Sphere, T> +where + T: Float, +{ + fn nearest_distance(&self, leaf: &Leaf<'a, T, A>) -> f64 { + let d = self.center.distance(leaf, &self.metric).to_f64().unwrap() - self.radius; + d.max(0.0) + } + + fn farthest_distance(&self, leaf: &Leaf<'a, T, A>) -> f64 { + self.center.distance(leaf, &self.metric).to_f64().unwrap() + self.radius + } +} +/// Uses a Bouncing ball algorithm to determine a tight bounding sphere around the given leaves +fn bounding_sphere<'a, T, A>( + leaves: &[Leaf<'a, T, A>], + metric: &DistanceMetrics, +) -> Sphere, T> +where + T: Float, + A: Clone, +{ + assert!( + leaves.len() >= 2, + "Bounding sphere requires at least 2 leaves" + ); + let a = leaves + .iter() + .max_by_key(|a| OrdF64::new(leaves[0].distance(a, metric).to_f64().unwrap())) + .unwrap(); + let b = leaves + .iter() + .max_by_key(|b| OrdF64::new(a.distance(b, metric).to_f64().unwrap())) + .unwrap(); + let mut center: Leaf<'a, T, A> = midpoint(a, b, metric); + let mut radius = center + .distance(b, metric) + .to_f64() + .unwrap() + .max(std::f64::EPSILON); + + loop { + match leaves + .iter() + .filter(|p| center.distance(p, metric).to_f64().unwrap() > radius) + .next() + { + None => { + break Sphere { + center, + radius, + metric: metric.clone(), + } + } + Some(p) => { + let 
c_to_p = center.distance(&p, metric).to_f64().unwrap(); + let d = c_to_p - radius; + center = unsafe { center.move_towards(p, d, metric) }; + radius = radius * 1.01; + } + } + } +} + +/// Partition the leaves into two Left and Right +fn partition<'a, T, A>( + mut leaves: Vec>, + metric: &DistanceMetrics, +) -> (Vec>, Vec>) +where + T: Float, + A: Clone, +{ + assert!(leaves.len() >= 2, "Partition requires at least 2 leaves"); + + let a_i = leaves + .iter() + .enumerate() + .max_by_key(|(_, a)| OrdF64::new(leaves[0].distance(a, metric).to_f64().unwrap())) + .unwrap() + .0; + + let b_i = leaves + .iter() + .enumerate() + .max_by_key(|(_, b)| OrdF64::new(leaves[a_i].distance(b, metric).to_f64().unwrap())) + .unwrap() + .0; + + let (a_i, b_i) = (a_i.max(b_i), a_i.min(b_i)); + + let mut aps = vec![leaves.swap_remove(a_i)]; + let mut bps = vec![leaves.swap_remove(b_i)]; + + for p in leaves.iter() { + if aps[0].distance(p, &metric) < bps[0].distance(p, metric) { + aps.push(p.clone()); + } else { + bps.push(p.clone()); + } + } + (aps, bps) +} + +/// Inner Ball Tree. 
This is the recursive structure of the Ball Tree +/// It can be a Leaf, Branch or Empty +/// The new function recursively calls itself till each input Leaf is a Leaf in the BallTree +#[derive(Clone, Debug)] +enum BallTreeInner<'a, T, A> +where + T: Float, +{ + Empty, + Leaf(Leaf<'a, T, A>), + Branch { + sphere: Sphere, T>, + count: usize, + left: Box>, + right: Box>, + }, +} + +impl<'a, T, A> Default for BallTreeInner<'a, T, A> +where + T: Float, +{ + fn default() -> Self { + BallTreeInner::Empty + } +} + +impl<'a, T, A> BallTreeInner<'a, T, A> +where + T: Float, + A: Clone, +{ + /// Iterates over the Leaves and recursively calls itseld to create a BallTree + /// Only the Branch nodes have children and recursively call BallTreeInner + fn new(mut leaves: Vec>, metric: &DistanceMetrics) -> Self { + if leaves.is_empty() { + return BallTreeInner::Empty; + } else if leaves.iter().all(|p| p.row_vec == leaves[0].row_vec) { + return BallTreeInner::Leaf(leaves.pop().unwrap()); + } else { + let count = leaves.len(); + let sphere = bounding_sphere(&leaves, &metric); + let (left, right) = if leaves.len() > 2 { + partition(leaves, metric) + } else { + (vec![leaves[0].clone()], vec![leaves[1].clone()]) + }; + BallTreeInner::Branch { + sphere, + count, + left: Box::new(BallTreeInner::new(left, metric)), + right: Box::new(BallTreeInner::new(right, metric)), + } + } + } + + fn nearest_distance(&self, leaf: &Leaf<'a, T, A>, metric: &DistanceMetrics) -> f64 { + match self { + BallTreeInner::Empty => std::f64::INFINITY, + BallTreeInner::Leaf(l) => leaf.distance(l, metric).to_f64().unwrap(), + BallTreeInner::Branch { + sphere, + left, + right, + .. 
+ } => { + let d = sphere.nearest_distance(leaf); + let d_left = left.as_ref().nearest_distance(leaf, metric); + let d_right = right.as_ref().nearest_distance(leaf, metric); + d.min(d_left).min(d_right) + } + } + } +} + +#[derive(Debug, Clone, Copy)] +struct Item(f64, T); +impl PartialEq for Item { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for Item {} + +impl PartialOrd for Item { + fn partial_cmp(&self, other: &Self) -> Option { + self.0.partial_cmp(&other.0).map(|o| o.reverse()) + } +} + +impl Ord for Item { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.partial_cmp(other).unwrap() + } +} + +// Iterator over Nearest Neighbors +#[derive(Debug)] +pub struct NNIter<'tree, 'query, T, A> +where + T: Float, + A: Clone, +{ + leaf: &'query Leaf<'query, T, A>, + balls: &'query mut BinaryHeap>>, + i: usize, + max_radius: f64, + metric: DistanceMetrics, +} + +impl<'tree, 'query, T, A> Iterator for NNIter<'tree, 'query, T, A> +where + T: Float, + A: Clone, +{ + type Item = (&'tree Leaf<'tree, T, A>, f64); + + fn next(&mut self) -> Option { + while self.balls.len() > 0 { + let _bb = self.balls.peek().unwrap(); + // convert self.balls into a vec of BallTreeInner. extract distance and BallTreeInner from Item + let _balls = self.balls.clone().into_sorted_vec(); + if let Item(d, BallTreeInner::Leaf(l)) = self.balls.peek().unwrap() { + if *d <= self.max_radius { + // Since the item has been consumed i wold like to remove it from self.balls + let (d, _lx) = { + let item = self.balls.pop().unwrap(); + (item.0, item.1) + }; + return Some((l, d)); + } + } + self.i = 0; + // Extend Branch nodes + if let Item(_, BallTreeInner::Branch { left, right, .. 
}) = self.balls.pop().unwrap() { + let d_a = left.as_ref().nearest_distance(self.leaf, &self.metric); + let d_b = right.as_ref().nearest_distance(self.leaf, &self.metric); + if d_a < self.max_radius { + self.balls.push(Item(d_a, left)); + } + if d_b < self.max_radius { + self.balls.push(Item(d_b, right)); + } + } + } + None + } +} + +#[derive(Debug, Clone)] +pub struct BallTree<'a, T: Float, A: Clone>(BallTreeInner<'a, T, A>, DistanceMetrics); + +impl<'a, T, A> BallTree<'a, T, A> +where + T: Float, + A: Clone, +{ + pub fn new(leaves: Vec>, distance_metric: DistanceMetrics) -> Self { + Self( + BallTreeInner::new(leaves, &distance_metric), + distance_metric.clone(), + ) + } + + pub fn query(&self) -> Query { + Query { + ball_tree: self, + balls: Default::default(), + metric: self.1.clone(), + } + } +} + +/// The query struct provides access to the inner tree. +/// Since the BallTree is expensive to create but relatively cheap to query +/// We use the query struct to query the BallTree +#[derive(Debug, Clone)] +pub struct Query<'a, T: Float, A: Clone> { + ball_tree: &'a BallTree<'a, T, A>, + balls: BinaryHeap>>, + metric: DistanceMetrics, +} + +/// The _lazy functions only return an iterator. 
+/// The code heavily relies on nn_within_lazy since this is the core function to traverse the tree +impl<'a, T, A> Query<'a, T, A> +where + T: Float, + A: Clone, +{ + pub fn nn_lazy<'query>( + &'query mut self, + leaf: &'query Leaf<'query, T, A>, + ) -> NNIter<'a, 'query, T, A> { + self.nn_within_lazy(leaf, f64::INFINITY) + } + + pub fn nn_within_lazy<'query>( + &'query mut self, + leaf: &'query Leaf<'query, T, A>, + max_radius: f64, + ) -> NNIter<'a, 'query, T, A> { + let balls = &mut self.balls; + balls.clear(); + balls.push(Item( + self.ball_tree.0.nearest_distance(leaf, &self.metric), + &self.ball_tree.0, + )); + NNIter { + leaf, + balls, + i: 0, + max_radius, + metric: self.metric.clone(), + } + } + // use nn_within_lazy to implement nn_within + pub fn nn_within<'query>( + &'query mut self, + leaf: &'query Leaf<'query, T, A>, + k: usize, + max_radius: f64, + ) -> Vec<(&'query Leaf<'query, T, A>, f64)> { + let iter = self.nn_within_lazy(leaf, max_radius); + iter.take(k).collect() + } + + // Return the k nearest neighbors with distances + pub fn knn<'query>( + &'query mut self, + leaf: &'query Leaf<'query, T, A>, + k: usize, + ) -> Vec<(&'query Leaf<'query, T, A>, f64)> { + let iter = self.nn_lazy(leaf); + iter.take(k).collect() + } + + pub fn is_knn<'query>( + &'query mut self, + leaf: &'query Leaf<'query, T, A>, + other_point: &[T], + k: usize, + max_radius: Option, + epsilon: Option, + ) -> bool { + let max_radius = max_radius.unwrap_or(f64::INFINITY); + let epsilon = epsilon.unwrap_or_else(T::epsilon); + let iter = self.nn_within_lazy(leaf, max_radius); + // iterate over iter upto k + // And compare each neighbors row vec to other_point elementwise within an epsilon + // Do not use take k since k can be large + // but do ensure we never go over k + let mut count = 0; + for (l, _) in iter { + if l.row_vec + .iter() + .zip(other_point.iter()) + .all(|(a, b)| (*a - *b).abs() < epsilon) + { + return true; + } + if count >= k { + return false; + } + count += 1; 
+ } + false + } + + // Min radius to encompass k points + pub fn min_radius<'query>(&'query mut self, leaf: &'query Leaf<'query, T, A>, k: usize) -> f64 { + let mut total_count = 0; + let balls = &mut self.balls; + balls.clear(); + balls.push(Item( + self.ball_tree.0.nearest_distance(leaf, &self.metric), + &self.ball_tree.0, + )); + + while let Some(Item(distance, node)) = balls.pop() { + match node { + BallTreeInner::Empty => {} + BallTreeInner::Leaf(_) => { + total_count += 1; + if total_count >= k { + return distance; + } + } + BallTreeInner::Branch { + sphere, + left, + right, + count, + } => { + let next_distance = balls.peek().map(|Item(d, _)| *d).unwrap_or(f64::INFINITY); + if total_count + count < k && sphere.farthest_distance(leaf) < next_distance { + total_count += count; + } else { + balls.push(Item(sphere.nearest_distance(leaf), left)); + balls.push(Item(sphere.nearest_distance(leaf), right)); + } + } + } + } + f64::INFINITY + } + + /// Needs some reassessing and validation + /// The tests do seem correct so far. 
+ pub fn count<'query>( + &'query mut self, + leaf: &'query Leaf<'query, T, A>, + max_radius: f64, + ) -> usize { + let mut total_count = 0; + let iter = self.nn_within_lazy(leaf, max_radius); + for _ in iter { + total_count += 1; + } + total_count + } + pub fn allocated_size(&self) -> usize { + self.balls.capacity() * std::mem::size_of::>>() + } + + pub fn deallocate_memory(&mut self) { + self.balls.clear(); + self.balls.shrink_to_fit(); + } +} + +/// So far this is only available on f64 leaves but can be made generic over T +impl<'a> Query<'a, f64, u32> { + pub fn knn_regress<'query>( + &'query mut self, + leaf: &'query Leaf<'query, f64, u32>, + k: usize, + max_radius: f64, + ) -> Option { + let neighbors: Vec<(&Leaf, f64)> = self + .nn_within_lazy(leaf, max_radius) + .into_iter() + .take(k) + .collect(); + let weights = neighbors + .iter() + .map(|(_, d)| (1.0f64 + *d).recip().into()) + .collect::>(); + let sum = weights.iter().copied().sum::(); + Some( + neighbors + .into_iter() + .zip(weights.into_iter()) + .fold(0f64, |acc, (nb, w)| acc + w * nb.0.item as f64) + / sum, + ) + } +} + +fn ball_tree_output(_: &[Field]) -> PolarsResult { + let inner_struct = DataType::Struct(vec![ + Field::new("id", DataType::UInt32), + Field::new("distance", DataType::Float64), + ]); + + Ok(Field::new( + "distance_id", + DataType::List(Box::new(inner_struct.clone())), + )) +} + +/// Converts a list of slices into a single dimension row major slice +/// where all the points are co-located based on the num-columns +pub fn row_major_slice_to_leaves<'a, T: Float + 'static, A: Copy>( + slice: &'a [T], + ncols: usize, + indices: &'a [A], +) -> Vec> { + indices + .iter() + .copied() + .zip(slice.chunks_exact(ncols)) + .map(|(idx, chunk)| (idx, chunk).into()) + .collect() +} + +#[derive(Deserialize, Serialize, Debug)] +pub struct BallTreePtwise { + pub r: f64, + pub sort: bool, + pub parallel: bool, + pub k: usize, + pub distance_metric: String, +} + +fn knn_ptwise_wdist( + inputs: 
&[Series], + id: &[u32], + radius: f64, + k: usize, + can_parallel: bool, + distance_metric: DistanceMetrics, +) -> Result<(ListChunked, ListChunked), PolarsError> { + let ncols = inputs.len(); + let data = series_to_row_major_slice::(inputs)?; + + let leaves = row_major_slice_to_leaves(&data, ncols, id); + let ball_tree = BallTree::new(leaves.clone(), distance_metric); + + if can_parallel { + POOL.install(|| { + let splits = split_offsets(id.len(), POOL.current_num_threads()); + let chunks: (Vec<_>, Vec<_>) = splits + .into_par_iter() + .map(|(offset, len)| { + let mut id_builder = ListPrimitiveChunkedBuilder::::new( + "id", + id.len(), + k + 1, + DataType::UInt32, + ); + let mut distance_builder = ListPrimitiveChunkedBuilder::::new( + "distance", + id.len(), + k + 1, + DataType::Float64, + ); + let mut binding = ball_tree.query(); + for i in offset..offset + len { + let mut id_vec = Vec::with_capacity(k + 1); + let mut dist_vec = Vec::with_capacity(k + 1); + binding + .nn_within(&leaves[i], 5, radius) + .into_iter() + .for_each(|(l, d)| { + id_vec.push(l.item); + dist_vec.push(d); + }); + id_builder.append_slice(&id_vec); + distance_builder.append_slice(&dist_vec); + } + let id_nb = id_builder.finish(); + let distance_nb = distance_builder.finish(); + ( + id_nb.downcast_iter().cloned().collect::>(), + distance_nb.downcast_iter().cloned().collect::>(), + ) + }) + .collect(); + let id_nb = ListChunked::from_chunk_iter("ids", chunks.0.into_iter().flatten()); + let distance_nb = + ListChunked::from_chunk_iter("distances", chunks.1.into_iter().flatten()); + Ok((id_nb, distance_nb)) + }) + } else { + let mut id_builder = + ListPrimitiveChunkedBuilder::::new("id", id.len(), k + 1, DataType::UInt32); + let mut distance_builder = ListPrimitiveChunkedBuilder::::new( + "distance", + id.len(), + k + 1, + DataType::Float64, + ); + let mut binding = ball_tree.query(); + // For each id create a list of tuples with the id and distance using nn_within into res + for i in 
0..id.len() { + let mut id_vec = Vec::with_capacity(k + 1); + let mut dist_vec = Vec::with_capacity(k + 1); + binding + .nn_within(&leaves[i], k, radius) + .into_iter() + .for_each(|(l, d)| { + id_vec.push(l.item); + dist_vec.push(d); + }); + id_builder.append_slice(&id_vec); + distance_builder.append_slice(&dist_vec); + } + Ok((id_builder.finish(), distance_builder.finish())) + } +} + +// Function to validate we can create a ball tree from the series +#[polars_expr(output_type_func=ball_tree_output)] +fn pl_query_knn_ptwise_wdist( + inputs: &[Series], + context: CallerContext, + kwargs: BallTreePtwise, +) -> PolarsResult { + let radius = kwargs.r; + let can_parallel = kwargs.parallel && !context.parallel(); + let k = kwargs.k + 1; + let distance_metric = kwargs.distance_metric; + let distance_metric = match distance_metric.as_str() { + "haversine" => DistanceMetrics::Haversine(haversine_distance), + "euclidean" => DistanceMetrics::Euclidean(euclidean_distance), + _ => { + return Err(PolarsError::InvalidOperation( + "Invalid distance metric".into(), + )) + } + }; + + let id = inputs[0].u32()?; + let id = id.cont_slice()?; + + let (id_nb, distance_nb) = + knn_ptwise_wdist(&inputs[1..], id, radius, k, can_parallel, distance_metric)?; + let out = StructChunked::new( + "knn_dist", + &[id_nb.into_series(), distance_nb.into_series()], + )?; + Ok(out.into_series()) +} + +// Without Distance metrics + +fn knn_ptwise( + inputs: &[Series], + id: &[u32], + radius: f64, + k: usize, + can_parallel: bool, + distance_metric: DistanceMetrics, +) -> Result { + let ncols = inputs.len(); + let data = series_to_row_major_slice::(inputs)?; + + let leaves = row_major_slice_to_leaves(&data, ncols, id); + let ball_tree = BallTree::new(leaves.clone(), distance_metric); + + if can_parallel { + POOL.install(|| { + let splits = split_offsets(id.len(), POOL.current_num_threads()); + let chunks: Vec<_> = splits + .into_par_iter() + .map(|(offset, len)| { + let mut id_builder = 
ListPrimitiveChunkedBuilder::::new( + "id", + id.len(), + k + 1, + DataType::UInt32, + ); + let mut binding = ball_tree.query(); + for i in offset..offset + len { + let mut id_vec = Vec::with_capacity(k + 1); + binding + .nn_within(&leaves[i], 5, radius) + .into_iter() + .for_each(|(l, _)| { + id_vec.push(l.item); + }); + id_builder.append_slice(&id_vec); + } + let id_nb = id_builder.finish(); + id_nb.downcast_iter().cloned().collect::>() + }) + .collect(); + let id_nb = ListChunked::from_chunk_iter("ids", chunks.into_iter().flatten()); + Ok(id_nb) + }) + } else { + let mut id_builder = + ListPrimitiveChunkedBuilder::::new("id", id.len(), k + 1, DataType::UInt32); + let mut binding = ball_tree.query(); + // For each id create a list of tuples with the id and distance using nn_within into res + for i in 0..id.len() { + let mut id_vec = Vec::with_capacity(k + 1); + binding + .nn_within(&leaves[i], k, radius) + .into_iter() + .for_each(|(l, _)| { + id_vec.push(l.item); + }); + id_builder.append_slice(&id_vec); + } + Ok(id_builder.finish()) + } +} + +// Function to validate we can create a ball tree from the series +#[polars_expr(output_type_func=ball_tree_output)] +fn pl_query_knn_ptwise( + inputs: &[Series], + context: CallerContext, + kwargs: BallTreePtwise, +) -> PolarsResult { + let radius = kwargs.r; + let can_parallel = kwargs.parallel && !context.parallel(); + let k = kwargs.k + 1; + let distance_metric = kwargs.distance_metric; + let distance_metric = match distance_metric.as_str() { + "haversine" => DistanceMetrics::Haversine(haversine_distance), + "euclidean" => DistanceMetrics::Euclidean(euclidean_distance), + _ => { + return Err(PolarsError::InvalidOperation( + "Invalid distance metric".into(), + )) + } + }; + + let id = inputs[0].u32()?; + let id = id.cont_slice()?; + + let id_nb = knn_ptwise(&inputs[1..], id, radius, k, can_parallel, distance_metric)?; + Ok(id_nb.into_series()) +} + +#[polars_expr(output_type=Float64)] +fn pl_ball_tree_knn_avg( + inputs: 
&[Series], + context: CallerContext, + kwargs: BallTreePtwise, +) -> PolarsResult { + let radius = kwargs.r; + let can_parallel = kwargs.parallel && !context.parallel(); + let k = kwargs.k + 1; + let distance_metric = kwargs.distance_metric; + let distance_metric = match distance_metric.as_str() { + "haversine" => DistanceMetrics::Haversine(haversine_distance), + "euclidean" => DistanceMetrics::Euclidean(euclidean_distance), + _ => { + return Err(PolarsError::InvalidOperation( + "Invalid distance metric".into(), + )) + } + }; + + let id = inputs[0].u32()?; + let id = id.cont_slice()?; + let data = series_to_row_major_slice::(&inputs[1..])?; + let leaves = row_major_slice_to_leaves(&data, inputs.len() - 1, id); + let ball_tree = BallTree::new(leaves.clone(), distance_metric); + /* + let mut binding = ball_tree.query(); + let res = id.iter().zip(leaves).map( + |(_, leaf)| { + if let Some(res) = binding.knn_regress(&leaf, k, radius) { + res + } else { + 0.0 + } + }).collect::>(); + */ + + if can_parallel { + let splits = split_offsets(id.len(), POOL.current_num_threads()); + let chunks = splits.into_par_iter().map(|(offset, len)| { + let mut id_builder = PrimitiveChunkedBuilder::::new("avg", len); + let mut binding = ball_tree.query(); + for i in offset..offset + len { + // unrwap the option so res is f64 and the default is 0,0 + let res = binding.knn_regress(&leaves[i], k, radius); + id_builder.append_option(res); + } + let ca = id_builder.finish(); + ca.downcast_iter().cloned().collect::>() + }); + let chunks = POOL.install(|| chunks.collect::>()); + let ca = Float64Chunked::from_chunk_iter("avg", chunks.into_iter().flatten()); + Ok(ca.into_series()) + } else { + let mut id_builder = PrimitiveChunkedBuilder::::new("avg", id.len()); + let mut binding = ball_tree.query(); + for i in 0..id.len() { + let res = binding.knn_regress(&leaves[i], k, radius); + id_builder.append_option(res); + } + Ok(id_builder.finish().into_series()) + } +} + 
+#[polars_expr(output_type=UInt32)] +fn pl_nb_count( + inputs: &[Series], + context: CallerContext, + kwargs: BallTreePtwise, +) -> PolarsResult { + let can_parallel = kwargs.parallel && !context.parallel(); + let distance_metric = kwargs.distance_metric; + let distance_metric = match distance_metric.as_str() { + "haversine" => DistanceMetrics::Haversine(haversine_distance), + "euclidean" => DistanceMetrics::Euclidean(euclidean_distance), + _ => { + return Err(PolarsError::InvalidOperation( + "Invalid distance metric".into(), + )) + } + }; + + let id = inputs[0].u32()?; + let id = id.cont_slice()?; + + let radius = inputs[1].f64()?; + let radius = radius.cont_slice()?; + + let data = series_to_row_major_slice::(&inputs[2..])?; + let leaves = row_major_slice_to_leaves(&data, inputs.len() - 2, id); + let ball_tree = BallTree::new(leaves.clone(), distance_metric); + + if can_parallel { + let splits = split_offsets(id.len(), POOL.current_num_threads()); + let chunks = splits.into_par_iter().map(|(offset, len)| { + let mut id_builder = PrimitiveChunkedBuilder::::new("count", len); + let mut binding = ball_tree.query(); + for i in offset..offset + len { + let _leaf = &leaves[i]; + let r = if radius.len() == 1 { + radius[0] + } else { + radius[i] + }; + let count = binding.count(&leaves[i], r); + id_builder.append_value(count as u32); + } + let ca = id_builder.finish(); + ca.downcast_iter().cloned().collect::>() + }); + let chunks = POOL.install(|| chunks.collect::>()); + let ca = UInt32Chunked::from_chunk_iter("counts", chunks.into_iter().flatten()); + Ok(ca.into_series()) + } else { + let mut id_builder = PrimitiveChunkedBuilder::::new("count", id.len()); + let mut binding = ball_tree.query(); + for i in 0..id.len() { + let count = binding.count(&leaves[i], radius[i]); + id_builder.append_value(count as u32); + } + Ok(id_builder.finish().into_series()) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BtWithinDist { + pub parallel: bool, + pub 
distance_metric: String, + pub point: Vec, +} + +#[polars_expr(output_type=Boolean)] +fn pl_bt_within_dist_from( + inputs: &[Series], + context: CallerContext, + kwargs: BtWithinDist, +) -> PolarsResult { + let can_parallel = kwargs.parallel && !context.parallel(); + let distance_metric = kwargs.distance_metric; + let distance_metric = match distance_metric.as_str() { + "haversine" => DistanceMetrics::Haversine(haversine_distance), + "euclidean" => DistanceMetrics::Euclidean(euclidean_distance), + _ => { + return Err(PolarsError::InvalidOperation( + "Invalid distance metric".into(), + )) + } + }; + let point = &kwargs.point; + + // create a incrementing uint32 for each row + let id = (0..inputs[1].len() as u32).collect::>(); + let radius = inputs[0].f64()?; + let radius = radius.cont_slice()?; + let data = series_to_row_major_slice::(&inputs[1..])?; + let leaves = row_major_slice_to_leaves(&data, inputs[1..].len(), &id); + if can_parallel { + let splits = split_offsets(id.len(), POOL.current_num_threads()); + let chunks = splits.into_par_iter().map(|(offset, len)| { + let mut id_builder = BooleanChunkedBuilder::new("within", len); + for i in offset..offset + len { + let leaf = leaves[i]; + let dist = leaf.distance( + &Leaf { + row_vec: point, + item: 0, + }, + &distance_metric, + ); + let r = if radius.len() == 1 { + radius[0] + } else { + radius[i] + }; + let within = dist <= r; + id_builder.append_value(within); + } + let ca = id_builder.finish(); + ca.downcast_iter().cloned().collect::>() + }); + let chunks = POOL.install(|| chunks.collect::>()); + let ca = BooleanChunked::from_chunk_iter("within", chunks.into_iter().flatten()); + Ok(ca.into_series()) + } else { + let mut id_builder = BooleanChunkedBuilder::new("within", id.len()); + for i in 0..id.len() { + let leaf = leaves[i]; + let dist = leaf.distance( + &Leaf { + row_vec: point, + item: 0, + }, + &distance_metric, + ); + let r = if radius.len() == 1 { + radius[0] + } else { + radius[i] + }; + let within = 
dist <= r; + id_builder.append_value(within); + } + Ok(id_builder.finish().into_series()) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct BtKnnFrom { + pub parallel: bool, + pub distance_metric: String, + pub point: Vec, + k: usize, + epsilon: Option, + max_radius: Option, +} + +#[polars_expr(output_type=Boolean)] +fn pl_bt_knn_from( + inputs: &[Series], + context: CallerContext, + kwargs: BtKnnFrom, +) -> PolarsResult { + let can_parallel = kwargs.parallel && !context.parallel(); + + let distance_metric = kwargs.distance_metric; + let distance_metric = match distance_metric.as_str() { + "haversine" => DistanceMetrics::Haversine(haversine_distance), + "euclidean" => DistanceMetrics::Euclidean(euclidean_distance), + _ => { + return Err(PolarsError::InvalidOperation( + "Invalid distance metric".into(), + )) + } + }; + let point = &kwargs.point; + + let id = (0..inputs[0].len() as u32).collect::>(); + let data = series_to_row_major_slice::(&inputs[..])?; + let leaves = row_major_slice_to_leaves(&data, inputs.len(), &id); + let k = kwargs.k.min(leaves.len()); + let ball_tree = BallTree::new(leaves.clone(), distance_metric); + + if can_parallel { + let splits = split_offsets(id.len(), POOL.current_num_threads()); + let chunks = splits.into_par_iter().map(|(offset, len)| { + let mut id_builder = BooleanChunkedBuilder::new("within", len); + let mut binding = ball_tree.query(); + for i in offset..offset + len { + let leaf = leaves[i]; + let within = binding.is_knn(&leaf, point, k, kwargs.max_radius, kwargs.epsilon); + id_builder.append_value(within); + } + let ca = id_builder.finish(); + ca.downcast_iter().cloned().collect::>() + }); + let chunks = POOL.install(|| chunks.collect::>()); + let ca = BooleanChunked::from_chunk_iter("within", chunks.into_iter().flatten()); + Ok(ca.into_series()) + } else { + let mut id_builder = BooleanChunkedBuilder::new("within", id.len()); + let mut binding = ball_tree.query(); + for i in 0..id.len() { + let leaf = 
leaves[i]; + let within = binding.is_knn(&leaf, point, k, kwargs.max_radius, kwargs.epsilon); + id_builder.append_value(within); + } + Ok(id_builder.finish().into_series()) + } +} + +#[cfg(test)] +mod tests { + use core::f64; + use std::collections::HashSet; + + use super::*; + use rand::{Rng, SeedableRng}; + use rand_chacha::ChaChaRng; + use rapidfuzz::distance; + + macro_rules! generate_leaves { + ($size:expr, $rng:ident, $random_leaves:ident, $leaves:ident) => { + let mut $leaves = vec![]; + let mut $random_leaves = vec![]; // Store the generated leaves here + for i in 0..$size { + let leaf_count: usize = if i < 100 { + $rng.gen_range(1..=3) + } else if i < 500 { + $rng.gen_range(1..=10) + } else { + $rng.gen_range(1..=100) + }; + + for _ in 0..leaf_count { + let random_leaf = random_3d_leaf!(); + $random_leaves.push(random_leaf); // Store the generated leaf + } + } + for (i, random_leaf) in $random_leaves.iter().enumerate() { + let leaf: Leaf = Leaf { + row_vec: random_leaf, // Reference the stored leaf + item: i as u32, + }; + $leaves.push(leaf); + } + }; + } + + macro_rules! setup_rng_and_macros { + ($rng:ident) => { + let mut $rng: ChaChaRng = SeedableRng::seed_from_u64(0xcb42c94d23346e96); + + macro_rules! random_small_f64 { + () => { + $rng.gen_range(-100.0..=100.0) + }; + } + + macro_rules! random_3d_leaf { + () => { + [random_small_f64!(), random_small_f64!()] + }; + } + }; + } + + macro_rules! 
generate_knn_inputs { + ($id:ident, $input1:ident, $input2:ident, $size:expr, $rng:ident) => { + let mut $id = vec![]; + let mut $input1 = vec![]; + let mut $input2 = vec![]; + + for i in 0..$size { + let random_value1 = $rng.gen_range(-100.0..=100.0); + let random_value2 = $rng.gen_range(-100.0..=100.0); + $id.push(i as u32); + $input1.push(random_value1); + $input2.push(random_value2); + } + }; + } + + #[test] + fn test_knn_within() { + setup_rng_and_macros!(rng); + generate_leaves!(100, rng, random_leaves, leaves); + let distance_metric = DistanceMetrics::Haversine(haversine_distance); + let tree = BallTree::new(leaves.clone(), distance_metric); + let mut binding = tree.query(); + let point = leaves[2].row_vec; + let k = 399; + let within = binding.is_knn( + &leaves[2], + point, + k, + Some(f64::INFINITY), + Some(f64::EPSILON), + ); + assert_eq!(within, true); + } + + #[test] + fn test_2d_leaves() { + setup_rng_and_macros!(rng); + generate_leaves!(10, rng, random_leaves, leaves); + let distance_metric = DistanceMetrics::Haversine(haversine_distance); + let tree = BallTree::new(leaves.clone(), distance_metric); + let mut binding = tree.query(); + let mut nnw = binding.nn_within_lazy(&leaves[3], 399.99); + + println!("nnw {:?}", nnw); + println!("tree {:?}", tree); + + println!("nnw_nest {:?}", nnw.next()); + let res = binding.nn_within(&leaves[6], 5, 399.99); + assert!(res.len() <= 5); + } + + #[test] + fn test_bounding_leaves() { + setup_rng_and_macros!(rng); + generate_leaves!(10, rng, random_leaves, leaves); + let distance_metric = DistanceMetrics::Haversine(haversine_distance); + let tree = BallTree::new(leaves.clone(), distance_metric); + let mut binding = tree.query(); + let nnw = binding.nn_within_lazy(&leaves[9], 999.99); + let mut total = 0; + nnw.for_each(|x| { + total += 1; + }); + assert!(total > 1); + } + + #[test] + fn test_knn_regress() { + setup_rng_and_macros!(rng); + generate_leaves!(100, rng, random_leaves, leaves); + let distance_metric = 
DistanceMetrics::Haversine(haversine_distance); + let tree = BallTree::new(leaves.clone(), distance_metric); + let mut binding = tree.query(); + let res = binding.knn_regress(&leaves[9], 18, 99999.99); + println!("{:?}", res); + } + + #[test] + fn test_nb_count() { + setup_rng_and_macros!(rng); + generate_leaves!(100, rng, random_leaves, leaves); + let distance_metric = DistanceMetrics::Haversine(haversine_distance); + let tree = BallTree::new(leaves.clone(), distance_metric); + let mut binding = tree.query(); + let leaf = leaves[9].clone(); + let res = binding.count(&leaf, 9999999.0f64); + println!("count for leaf: {:?} is {:?}", leaf, res); + } + + #[test] + fn test_overall_tree() { + setup_rng_and_macros!(rng); + generate_leaves!(100, rng, random_leaves, leaves); + let distance_metric = DistanceMetrics::Haversine(euclidean_distance); + let tree = BallTree::new(leaves.clone(), distance_metric.clone()); + let mut binding = tree.query(); + + for _ in 0..100 { + let point = random_3d_leaf!(); + let max_radius = rng.gen_range(0.0..=100.0); + + let expected_values = leaves + .iter() + .filter(|leaf| { + leaf.distance( + &Leaf { + row_vec: &point, + item: 0, + }, + &distance_metric, + ) <= max_radius + }) + .map(|leaf| leaf.item) + .collect::>(); + let mut found_values = HashSet::new(); + + let mut previous_d = 0.0; + let leaf = Leaf { + row_vec: &point, + item: 0, + }; + for (leaf, d) in binding.nn_within_lazy(&leaf, max_radius) { + assert_eq!( + leaf.distance( + &Leaf { + row_vec: &point, + item: 0 + }, + &distance_metric + ), + d + ); + + assert!(d >= previous_d); + + assert!(d <= max_radius); + previous_d = d; + found_values.insert(leaf.item); + } + assert_eq!(expected_values, found_values); + let binding_count = binding.count(&leaf, max_radius); + assert_eq!(found_values.len(), binding_count); + + let radius = binding.min_radius(&leaf, expected_values.len()); + let should_be_fewer = binding.count(&leaf, radius * 0.3); + assert!( + expected_values.is_empty() || 
should_be_fewer < expected_values.len(), + "{} < {}", + should_be_fewer, + expected_values.len() + ); + } + + assert!(binding.allocated_size() > 0); + + assert!(binding.allocated_size() <= 2 * 8 * leaves.len().next_power_of_two().max(4)); + binding.deallocate_memory(); + assert_eq!(binding.allocated_size(), 0); + } + + #[test] + fn test_leaf_impls() { + let row_vecs = vec![vec![1.0, 2.0], vec![2.0, 3.0], vec![3.0, 4.0]]; + let leaf1 = Leaf { + row_vec: &row_vecs[0], + item: 0, + }; + let leaf2 = Leaf { + row_vec: &row_vecs[1], + item: 0, + }; + let distance_metric = DistanceMetrics::Euclidean(euclidean_distance); + + assert_eq!(leaf1.distance(&leaf2, &distance_metric), 1.4142135623730951); + + let leaf3 = Leaf { + row_vec: &row_vecs[2], + item: 1, + }; + + assert_eq!(leaf1.distance(&leaf3, &distance_metric), 2.8284271247461903); + + unsafe { + let lv = leaf1.move_towards(&leaf3, 1.3, &distance_metric).row_vec; + assert_eq!(lv, [1.3788582233137676, 1.8384776310850235]); + } + } +} diff --git a/src/num_ext/mod.rs b/src/num_ext/mod.rs index 00a96ce8..5c65d2ee 100644 --- a/src/num_ext/mod.rs +++ b/src/num_ext/mod.rs @@ -1,5 +1,6 @@ use num::Float; +mod ball_tree; mod benford; mod cond_entropy; mod convolve;