Skip to content

Commit 993263c

Browse files
committed
Extract Python bits from Rust
1 parent 935a582 commit 993263c

File tree

5 files changed

+300
-291
lines changed

5 files changed

+300
-291
lines changed

Cargo.toml

+5-2
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,8 @@ crate-type = ["cdylib"]
1616
approx = "^0.5.1"
1717
ndarray = "^0.15.6"
1818
num-traits = "^0.2.19"
19-
numpy = "^0.21.0"
20-
pyo3 = { version = "^0.21.1", features = ["extension-module"] }
19+
pyo3 = { version = "^0.21.1", features = ["extension-module", "abi3-py38"], optional = true }
20+
numpy = { version = "^0.21.0", optional = true }
21+
22+
[features]
23+
python = ["dep:pyo3", "dep:numpy"]

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ exclude_also = [
7474
]
7575

7676
[tool.maturin]
77+
features = ["python"]
7778
python-source = "python"
7879

7980
[tool.mypy]

src/counting.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ use std::ops::AddAssign;
33
use ndarray::{Array1, ArrayView1, Axis, ErrorKind, ShapeError};
44
use num_traits::{Float, Num};
55

6-
use crate::utils::{matrices, nan_mean, nan_to_num};
76
use crate::{check_lengths, check_total, Winner};
7+
use crate::utils::{matrices, nan_mean, nan_to_num};
88

99
pub fn counting<A: Num + Copy + AddAssign>(
1010
xs: &ArrayView1<usize>,

src/lib.rs

+5-288
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,15 @@
1-
use numpy::{Element, IntoPyArray, PyArray1, PyArray2, PyArrayDescr, PyArrayLike1};
2-
use pyo3::create_exception;
3-
use pyo3::prelude::*;
4-
5-
use crate::bradley_terry::{bradley_terry, newman};
6-
use crate::counting::{average_win_rate, counting};
7-
use crate::elo::elo;
8-
use crate::linalg::{eigen, pagerank};
9-
use crate::utils::matrices;
1+
#[cfg(feature = "python")]
2+
use pyo3::prelude::pyclass;
103

114
mod bradley_terry;
125
mod counting;
136
mod elo;
147
mod linalg;
8+
#[cfg(feature = "python")]
9+
mod python;
1510
mod utils;
1611

17-
#[pyclass(module = "evalica")]
12+
#[cfg_attr(feature = "python", pyclass(module = "evalica"))]
1813
#[repr(u8)]
1914
#[derive(Clone, Debug, PartialEq, Hash)]
2015
pub enum Winner {
@@ -45,281 +40,3 @@ impl Into<u8> for Winner {
4540
}
4641
}
4742
}
48-
49-
#[pymethods]
50-
impl Winner {
51-
#[new]
52-
fn new() -> Self {
53-
Winner::Ignore
54-
}
55-
56-
fn __getstate__(&self) -> PyResult<u8> {
57-
Ok(self.clone().into())
58-
}
59-
60-
fn __setstate__(&mut self, state: u8) -> PyResult<()> {
61-
*self = Winner::from(state);
62-
Ok(())
63-
}
64-
}
65-
66-
unsafe impl Element for Winner {
67-
const IS_COPY: bool = true;
68-
69-
fn get_dtype_bound(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
70-
numpy::dtype_bound::<u8>(py)
71-
}
72-
}
73-
74-
create_exception!(evalica, LengthMismatchError, pyo3::exceptions::PyValueError);
75-
76-
#[pyfunction]
77-
fn matrices_pyo3<'py>(
78-
py: Python<'py>,
79-
xs: PyArrayLike1<'py, usize>,
80-
ys: PyArrayLike1<'py, usize>,
81-
ws: PyArrayLike1<'py, Winner>,
82-
total: usize,
83-
) -> PyResult<(Py<PyArray2<i64>>, Py<PyArray2<i64>>)> {
84-
match matrices(&xs.as_array(), &ys.as_array(), &ws.as_array(), total, 1, 1) {
85-
Ok((wins, ties)) => Ok((
86-
wins.into_pyarray_bound(py).unbind(),
87-
ties.into_pyarray_bound(py).unbind(),
88-
)),
89-
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
90-
}
91-
}
92-
93-
#[pyfunction]
94-
fn counting_pyo3<'py>(
95-
py: Python,
96-
xs: PyArrayLike1<'py, usize>,
97-
ys: PyArrayLike1<'py, usize>,
98-
ws: PyArrayLike1<'py, Winner>,
99-
total: usize,
100-
win_weight: f64,
101-
tie_weight: f64,
102-
) -> PyResult<Py<PyArray1<f64>>> {
103-
match counting(
104-
&xs.as_array(),
105-
&ys.as_array(),
106-
&ws.as_array(),
107-
total,
108-
win_weight,
109-
tie_weight,
110-
) {
111-
Ok(scores) => Ok(scores.into_pyarray_bound(py).unbind()),
112-
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
113-
}
114-
}
115-
116-
#[pyfunction]
117-
fn average_win_rate_pyo3<'py>(
118-
py: Python,
119-
xs: PyArrayLike1<'py, usize>,
120-
ys: PyArrayLike1<'py, usize>,
121-
ws: PyArrayLike1<'py, Winner>,
122-
total: usize,
123-
win_weight: f64,
124-
tie_weight: f64,
125-
) -> PyResult<Py<PyArray1<f64>>> {
126-
match average_win_rate(
127-
&xs.as_array(),
128-
&ys.as_array(),
129-
&ws.as_array(),
130-
total,
131-
win_weight,
132-
tie_weight,
133-
) {
134-
Ok(scores) => Ok(scores.into_pyarray_bound(py).unbind()),
135-
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
136-
}
137-
}
138-
139-
#[pyfunction]
140-
fn bradley_terry_pyo3<'py>(
141-
py: Python,
142-
xs: PyArrayLike1<'py, usize>,
143-
ys: PyArrayLike1<'py, usize>,
144-
ws: PyArrayLike1<'py, Winner>,
145-
total: usize,
146-
win_weight: f64,
147-
tie_weight: f64,
148-
tolerance: f64,
149-
limit: usize,
150-
) -> PyResult<(Py<PyArray1<f64>>, usize)> {
151-
match matrices(
152-
&xs.as_array(),
153-
&ys.as_array(),
154-
&ws.as_array(),
155-
total,
156-
win_weight,
157-
tie_weight,
158-
) {
159-
Ok((win_matrix, tie_matrix)) => {
160-
let matrix = &win_matrix + &tie_matrix;
161-
162-
match bradley_terry(&matrix.view(), tolerance, limit) {
163-
Ok((scores, iterations)) => {
164-
Ok((scores.into_pyarray_bound(py).unbind(), iterations))
165-
}
166-
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
167-
}
168-
}
169-
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
170-
}
171-
}
172-
173-
#[pyfunction]
174-
fn newman_pyo3<'py>(
175-
py: Python,
176-
xs: PyArrayLike1<'py, usize>,
177-
ys: PyArrayLike1<'py, usize>,
178-
ws: PyArrayLike1<'py, Winner>,
179-
total: usize,
180-
v_init: f64,
181-
win_weight: f64,
182-
tie_weight: f64,
183-
tolerance: f64,
184-
limit: usize,
185-
) -> PyResult<(Py<PyArray1<f64>>, f64, usize)> {
186-
match matrices(
187-
&xs.as_array(),
188-
&ys.as_array(),
189-
&ws.as_array(),
190-
total,
191-
win_weight,
192-
tie_weight,
193-
) {
194-
Ok((win_matrix, tie_matrix)) => {
195-
match newman(
196-
&win_matrix.view(),
197-
&tie_matrix.view(),
198-
v_init,
199-
tolerance,
200-
limit,
201-
) {
202-
Ok((scores, v, iterations)) => {
203-
Ok((scores.into_pyarray_bound(py).unbind(), v, iterations))
204-
}
205-
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
206-
}
207-
}
208-
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
209-
}
210-
}
211-
212-
#[pyfunction]
213-
fn elo_pyo3<'py>(
214-
py: Python,
215-
xs: PyArrayLike1<'py, usize>,
216-
ys: PyArrayLike1<'py, usize>,
217-
ws: PyArrayLike1<'py, Winner>,
218-
total: usize,
219-
initial: f64,
220-
base: f64,
221-
scale: f64,
222-
k: f64,
223-
) -> PyResult<Py<PyArray1<f64>>> {
224-
match elo(
225-
&xs.as_array(),
226-
&ys.as_array(),
227-
&ws.as_array(),
228-
total,
229-
initial,
230-
base,
231-
scale,
232-
k,
233-
) {
234-
Ok(scores) => Ok(scores.into_pyarray_bound(py).unbind()),
235-
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
236-
}
237-
}
238-
239-
#[pyfunction]
240-
fn eigen_pyo3<'py>(
241-
py: Python<'py>,
242-
xs: PyArrayLike1<'py, usize>,
243-
ys: PyArrayLike1<'py, usize>,
244-
ws: PyArrayLike1<'py, Winner>,
245-
total: usize,
246-
win_weight: f64,
247-
tie_weight: f64,
248-
tolerance: f64,
249-
limit: usize,
250-
) -> PyResult<(Py<PyArray1<f64>>, usize)> {
251-
match matrices(
252-
&xs.as_array(),
253-
&ys.as_array(),
254-
&ws.as_array(),
255-
total,
256-
win_weight,
257-
tie_weight,
258-
) {
259-
Ok((win_matrix, tie_matrix)) => {
260-
let matrix = &win_matrix + &tie_matrix;
261-
262-
match eigen(&matrix.view(), tolerance, limit) {
263-
Ok((scores, iterations)) => {
264-
Ok((scores.into_pyarray_bound(py).unbind(), iterations))
265-
}
266-
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
267-
}
268-
}
269-
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
270-
}
271-
}
272-
273-
#[pyfunction]
274-
fn pagerank_pyo3<'py>(
275-
py: Python,
276-
xs: PyArrayLike1<'py, usize>,
277-
ys: PyArrayLike1<'py, usize>,
278-
ws: PyArrayLike1<'py, Winner>,
279-
total: usize,
280-
damping: f64,
281-
win_weight: f64,
282-
tie_weight: f64,
283-
tolerance: f64,
284-
limit: usize,
285-
) -> PyResult<(Py<PyArray1<f64>>, usize)> {
286-
match matrices(
287-
&xs.as_array(),
288-
&ys.as_array(),
289-
&ws.as_array(),
290-
total,
291-
win_weight,
292-
tie_weight,
293-
) {
294-
Ok((win_matrix, tie_matrix)) => {
295-
let matrix = &win_matrix + &tie_matrix;
296-
297-
match pagerank(&matrix.view(), damping, tolerance, limit) {
298-
Ok((scores, iterations)) => {
299-
Ok((scores.into_pyarray_bound(py).unbind(), iterations))
300-
}
301-
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
302-
}
303-
}
304-
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
305-
}
306-
}
307-
308-
#[pymodule]
309-
fn evalica(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
310-
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
311-
m.add(
312-
"LengthMismatchError",
313-
py.get_type_bound::<LengthMismatchError>(),
314-
)?;
315-
m.add_function(wrap_pyfunction!(matrices_pyo3, m)?)?;
316-
m.add_function(wrap_pyfunction!(counting_pyo3, m)?)?;
317-
m.add_function(wrap_pyfunction!(average_win_rate_pyo3, m)?)?;
318-
m.add_function(wrap_pyfunction!(bradley_terry_pyo3, m)?)?;
319-
m.add_function(wrap_pyfunction!(newman_pyo3, m)?)?;
320-
m.add_function(wrap_pyfunction!(elo_pyo3, m)?)?;
321-
m.add_function(wrap_pyfunction!(eigen_pyo3, m)?)?;
322-
m.add_function(wrap_pyfunction!(pagerank_pyo3, m)?)?;
323-
m.add_class::<Winner>()?;
324-
Ok(())
325-
}

0 commit comments

Comments
 (0)