Skip to content

Commit

Permalink
Fix Python API
Browse files Browse the repository at this point in the history
  • Loading branch information
cschwan committed Feb 10, 2025
1 parent 52b96f4 commit 3c7da40
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 86 deletions.
78 changes: 0 additions & 78 deletions pineappl_py/src/import_subgrid.rs

This file was deleted.

2 changes: 0 additions & 2 deletions pineappl_py/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ pub mod convolutions;
pub mod evolution;
pub mod fk_table;
pub mod grid;
pub mod import_subgrid;
pub mod interpolation;
pub mod pids;
pub mod subgrid;
Expand All @@ -23,7 +22,6 @@ fn pineappl(m: &Bound<'_, PyModule>) -> PyResult<()> {
fk_table::register(m)?;
grid::register(m)?;
interpolation::register(m)?;
import_subgrid::register(m)?;
pids::register(m)?;
subgrid::register(m)?;
m.add("version", env!("CARGO_PKG_VERSION"))?;
Expand Down
58 changes: 55 additions & 3 deletions pineappl_py/src/subgrid.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,61 @@
//! Subgrid interface.
use ndarray::ArrayD;
use numpy::{IntoPyArray, PyArrayDyn};
use pineappl::subgrid::{Subgrid, SubgridEnum};
use ndarray::{ArrayD, Dimension};
use numpy::{IntoPyArray, PyArrayDyn, PyReadonlyArrayDyn};
use pineappl::packed_array::PackedArray;
use pineappl::subgrid::{ImportSubgridV1, Subgrid, SubgridEnum};
use pyo3::prelude::*;

/// PyO3 wrapper to :rustdoc:`pineappl::subgrid::ImportSubgridV1 <subgrid/struct.ImportSubgridV1.html>`.
#[pyclass(name = "ImportSubgridV1")]
#[derive(Clone)]
#[repr(transparent)]
pub struct PyImportSubgridV1 {
pub(crate) import_subgrid: ImportSubgridV1,
}

#[pymethods]
impl PyImportSubgridV1 {
/// Constructor.
///
/// # Panics
/// TODO
///
/// Parameters
/// ----------
/// array : numpy.ndarray(float)
/// `N`-dimensional array with all weights
/// node_values: list(list(float))
/// list containing the arrays of energy scales {q1, ..., qn} and momentum fractions
/// {x1, ..., xn}.
#[new]
#[must_use]
pub fn new(array: PyReadonlyArrayDyn<f64>, node_values: Vec<Vec<f64>>) -> Self {
let mut sparse_array: PackedArray<f64> =
PackedArray::new(node_values.iter().map(Vec::len).collect());

for (index, value) in array
.as_array()
.indexed_iter()
.filter(|(_, value)| **value != 0.0)
{
sparse_array[index.as_array_view().to_slice().unwrap()] = *value;
}

Self {
import_subgrid: ImportSubgridV1::new(sparse_array, node_values),
}
}

/// Ensures that the subgrid has type `PySubgridEnum`.
#[must_use]
pub fn into(&self) -> PySubgridEnum {
PySubgridEnum {
subgrid_enum: self.import_subgrid.clone().into(),
}
}
}

/// PyO3 wrapper to :rustdoc:`pineappl::subgrid::SubgridEnum <subgrid/struct.SubgridEnum.html>`
#[pyclass(name = "SubgridEnum")]
#[derive(Clone)]
Expand Down Expand Up @@ -72,6 +123,7 @@ pub fn register(parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
m,
"import sys; sys.modules['pineappl.subgrid'] = m"
);
m.add_class::<PyImportSubgridV1>()?;
m.add_class::<PySubgridEnum>()?;
parent_module.add_submodule(&m)
}
2 changes: 1 addition & 1 deletion pineappl_py/tests/test_fk_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pineappl.boc import Channel, Order
from pineappl.convolutions import Conv, ConvType
from pineappl.fk_table import FkAssumptions, FkTable
from pineappl.import_subgrid import ImportSubgridV1
from pineappl.subgrid import ImportSubgridV1


class TestFkTable:
Expand Down
2 changes: 1 addition & 1 deletion pineappl_py/tests/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pineappl.evolution import OperatorSliceInfo
from pineappl.fk_table import FkTable
from pineappl.grid import Grid
from pineappl.import_subgrid import ImportSubgridV1
from pineappl.subgrid import ImportSubgridV1
from pineappl.pids import PidBasis

# Construct the type of convolutions and the convolution object
Expand Down
2 changes: 1 addition & 1 deletion pineappl_py/tests/test_subgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from pineappl.boc import Channel, Order
from pineappl.convolutions import Conv, ConvType
from pineappl.grid import Grid
from pineappl.import_subgrid import ImportSubgridV1
from pineappl.subgrid import ImportSubgridV1
from pineappl.subgrid import SubgridEnum

# Define some default for the minimum value of `Q2`
Expand Down

0 comments on commit 3c7da40

Please sign in to comment.