Skip to content

Commit 54af0d2

Browse files
committed
merge ours with no_namespace
2 parents 283187f + 1a1d153 commit 54af0d2

26 files changed

+3519
-1374
lines changed

examples/pattern_match.rs

+9
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,15 @@ fn main() {
1212

1313
println!("> Matching pattern {} to {}:", pat_expr, expr.as_view());
1414

15+
// simple match
16+
for m in expr.pattern_match(&pattern, None, None) {
17+
for (wc, v) in m {
18+
println!("\t{} = {}", wc, v);
19+
}
20+
println!();
21+
}
22+
23+
// advanced match
1524
let mut it = expr.pattern_match(&pattern, None, None);
1625
while let Some(m) = it.next_detailed() {
1726
println!(

examples/tree_replace.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ fn main() {
1313
println!("> Matching pattern {} to {}:", pat_expr, expr);
1414

1515
for x in expr.pattern_match(&pattern, None, None) {
16-
println!("\t x_ = {}", x.get(&symb!("x_")).unwrap().to_atom());
16+
println!("\t x_ = {}", x.get(&symb!("x_")).unwrap());
1717
}
1818

1919
println!("> Matching pattern {} to {}:", pat_expr, expr);

src/api/python.rs

+29-41
Original file line numberDiff line numberDiff line change
@@ -1166,10 +1166,11 @@ impl PythonTransformer {
11661166
py.allow_threads(|| {
11671167
Workspace::get_local()
11681168
.with(|workspace| {
1169-
self.expr.substitute_wildcards(
1169+
self.expr.replace_wildcards_with_matches_impl(
11701170
workspace,
11711171
&mut out,
1172-
&MatchStack::new(&Condition::default(), &MatchSettings::default()),
1172+
&MatchStack::new(),
1173+
true,
11731174
None,
11741175
)
11751176
})
@@ -2067,7 +2068,7 @@ impl TryFrom<Relation> for PatternRestriction {
20672068
WildcardRestriction::Filter(Box::new(move |m| {
20682069
m.to_atom()
20692070
.pattern_match(&pattern, Some(&cond), Some(&settings))
2070-
.next()
2071+
.next_detailed()
20712072
.is_some()
20722073
})),
20732074
)))
@@ -4072,6 +4073,7 @@ impl PythonExpression {
40724073
/// x^2 1
40734074
/// 1 5
40744075
/// ```
4076+
#[pyo3(signature = (*x,))]
40754077
pub fn coefficient_list(
40764078
&self,
40774079
x: Bound<'_, PyTuple>,
@@ -5261,11 +5263,12 @@ impl PythonExpression {
52615263
}
52625264

52635265
/// Canonize (products of) tensors in the expression by relabeling repeated indices.
5264-
/// The tensors must be written as functions, with its indices are the arguments.
5265-
/// The repeated indices should be provided in `contracted_indices`.
5266+
/// The tensors must be written as functions, with its indices as the arguments.
5267+
/// Subexpressions, constants and open indices are supported.
52665268
///
52675269
/// If the contracted indices are distinguishable (for example in their dimension),
5268-
/// you can provide an optional group marker for each index using `index_group`.
5270+
/// you can provide a group marker as the second element in the tuple of the index
5271+
/// specification.
52695272
/// This makes sure that an index will not be renamed to an index from a different group.
52705273
///
52715274
/// Examples
@@ -5276,38 +5279,20 @@ impl PythonExpression {
52765279
/// >>>
52775280
/// >>> e = g(mu2, mu3)*fc(mu4, mu2, k1, mu4, k1, mu3)
52785281
/// >>>
5279-
/// >>> print(e.canonize_tensors([mu1, mu2, mu3, mu4]))
5282+
/// >>> print(e.canonize_tensors([(mu1, 0), (mu2, 0), (mu3, 0), (mu4, 0)]))
52805283
/// yields `g(mu1,mu2)*fc(mu1,mu3,mu2,k1,mu3,k1)`.
5281-
#[pyo3(signature = (contracted_indices, index_group=None))]
52825284
fn canonize_tensors(
52835285
&self,
5284-
contracted_indices: Vec<ConvertibleToExpression>,
5285-
index_group: Option<Vec<ConvertibleToExpression>>,
5286+
contracted_indices: Vec<(ConvertibleToExpression, ConvertibleToExpression)>,
52865287
) -> PyResult<Self> {
52875288
let contracted_indices = contracted_indices
52885289
.into_iter()
5289-
.map(|x| x.to_expression().expr)
5290+
.map(|x| (x.0.to_expression().expr, x.1.to_expression().expr))
52905291
.collect::<Vec<_>>();
5291-
let contracted_indices = contracted_indices
5292-
.iter()
5293-
.map(|x| x.as_view())
5294-
.collect::<Vec<_>>();
5295-
5296-
let index_group = index_group.map(|x| {
5297-
x.into_iter()
5298-
.map(|x| x.to_expression().expr)
5299-
.collect::<Vec<_>>()
5300-
});
5301-
let index_group = index_group
5302-
.as_ref()
5303-
.map(|x| x.iter().map(|x| x.as_view()).collect::<Vec<_>>());
53045292

53055293
let r = self
53065294
.expr
5307-
.canonize_tensors(
5308-
&contracted_indices,
5309-
index_group.as_ref().map(|x| x.as_slice()),
5310-
)
5295+
.canonize_tensors(&contracted_indices)
53115296
.map_err(|e| {
53125297
exceptions::PyValueError::new_err(format!("Could not canonize tensors: {}", e))
53135298
})?;
@@ -5891,7 +5876,7 @@ impl PythonMatchIterator {
58915876
self.with_dependent_mut(|_, i| {
58925877
i.next().map(|m| {
58935878
m.into_iter()
5894-
.map(|(k, v)| (Atom::new_var(k).into(), { v.to_atom().into() }))
5879+
.map(|(k, v)| (Atom::new_var(k).into(), { v.into() }))
58955880
.collect()
58965881
})
58975882
})
@@ -10354,7 +10339,7 @@ impl PythonExpressionEvaluator {
1035410339
(function_name,
1035510340
filename,
1035610341
library_name,
10357-
inline_asm = true,
10342+
inline_asm = "default",
1035810343
optimization_level = 3,
1035910344
compiler_path = None,
1036010345
))]
@@ -10363,7 +10348,7 @@ impl PythonExpressionEvaluator {
1036310348
function_name: &str,
1036410349
filename: &str,
1036510350
library_name: &str,
10366-
inline_asm: bool,
10351+
inline_asm: &str,
1036710352
optimization_level: u8,
1036810353
compiler_path: Option<&str>,
1036910354
) -> PyResult<PythonCompiledExpressionEvaluator> {
@@ -10373,19 +10358,22 @@ impl PythonExpressionEvaluator {
1037310358
options.compiler = compiler_path.to_string();
1037410359
}
1037510360

10361+
let inline_asm = match inline_asm.to_lowercase().as_str() {
10362+
"default" => InlineASM::default(),
10363+
"x64" => InlineASM::X64,
10364+
"aarch64" => InlineASM::AArch64,
10365+
"none" => InlineASM::None,
10366+
_ => {
10367+
return Err(exceptions::PyValueError::new_err(
10368+
"Invalid inline assembly type specified.",
10369+
))
10370+
}
10371+
};
10372+
1037610373
Ok(PythonCompiledExpressionEvaluator {
1037710374
eval: self
1037810375
.eval
10379-
.export_cpp(
10380-
filename,
10381-
function_name,
10382-
true,
10383-
if inline_asm {
10384-
InlineASM::X64
10385-
} else {
10386-
InlineASM::None
10387-
},
10388-
)
10376+
.export_cpp(filename, function_name, true, inline_asm)
1038910377
.map_err(|e| exceptions::PyValueError::new_err(format!("Export error: {}", e)))?
1039010378
.compile(library_name, options)
1039110379
.map_err(|e| {

src/atom/core.rs

+23-20
Original file line numberDiff line numberDiff line change
@@ -1025,11 +1025,12 @@ pub trait AtomCore {
10251025
}
10261026

10271027
/// Canonize (products of) tensors in the expression by relabeling repeated indices.
1028-
/// The tensors must be written as functions, with its indices are the arguments.
1029-
/// The repeated indices should be provided in `contracted_indices`.
1028+
/// The tensors must be written as functions, with its indices as the arguments.
1029+
/// Subexpressions, constants and open indices are supported.
10301030
///
10311031
/// If the contracted indices are distinguishable (for example in their dimension),
1032-
/// you can provide an optional group marker for each index using `index_group`.
1032+
/// you can provide a group marker as the second element in the tuple of the index
1033+
/// specification.
10331034
/// This makes sure that an index will not be renamed to an index from a different group.
10341035
///
10351036
/// Example
@@ -1042,23 +1043,25 @@ pub trait AtomCore {
10421043
/// let _ = Symbol::new_with_attributes("fc", &[FunctionAttribute::Cyclesymmetric]).unwrap();
10431044
/// let a = Atom::parse("fs(mu2,mu3)*fc(mu4,mu2,k1,mu4,k1,mu3)").unwrap();
10441045
///
1045-
/// let mu1 = Atom::parse("mu1").unwrap();
1046-
/// let mu2 = Atom::parse("mu2").unwrap();
1047-
/// let mu3 = Atom::parse("mu3").unwrap();
1048-
/// let mu4 = Atom::parse("mu4").unwrap();
1046+
/// let mu1 = (Atom::parse("mu1").unwrap(), 0);
1047+
/// let mu2 = (Atom::parse("mu2").unwrap(), 0);
1048+
/// let mu3 = (Atom::parse("mu3").unwrap(), 0);
1049+
/// let mu4 = (Atom::parse("mu4").unwrap(), 0);
10491050
///
1050-
/// let r = a.canonize_tensors(&[mu1.as_view(), mu2.as_view(), mu3.as_view(), mu4.as_view()], None).unwrap();
1051+
/// let r = a.canonize_tensors(&[mu1, mu2, mu3 ,mu4]).unwrap();
10511052
/// println!("{}", r);
10521053
/// # }
10531054
/// ```
10541055
/// yields `fs(mu1,mu2)*fc(mu1,k1,mu3,k1,mu2,mu3)`.
1055-
fn canonize_tensors(
1056+
fn canonize_tensors<T: AtomCore, G: Ord + std::hash::Hash>(
10561057
&self,
1057-
contracted_indices: &[AtomView],
1058-
index_group: Option<&[AtomView]>,
1058+
indices: &[(T, G)],
10591059
) -> Result<Atom, String> {
1060-
self.as_atom_view()
1061-
.canonize_tensors(contracted_indices, index_group)
1060+
let indices = indices
1061+
.iter()
1062+
.map(|(i, g)| (i.as_atom_view(), g))
1063+
.collect::<Vec<_>>();
1064+
self.as_atom_view().canonize_tensors(&indices)
10621065
}
10631066

10641067
fn to_pattern(&self) -> Pattern {
@@ -1318,16 +1321,16 @@ pub trait AtomCore {
13181321
/// let mut iter = expr.pattern_match(&pattern, None, None);
13191322
/// let result = iter.next().unwrap();
13201323
/// assert_eq!(
1321-
/// result.get(&Symbol::new("x_")).unwrap().to_atom(),
1322-
/// Atom::new_num(1)
1324+
/// result.get(&Symbol::new("x_")).unwrap(),
1325+
/// &Atom::new_num(1)
13231326
/// );
13241327
/// ```
1325-
fn pattern_match<'a>(
1328+
fn pattern_match<'a: 'b, 'b>(
13261329
&'a self,
1327-
pattern: &'a Pattern,
1328-
conditions: Option<&'a Condition<PatternRestriction>>,
1329-
settings: Option<&'a MatchSettings>,
1330-
) -> PatternAtomTreeIterator<'a, 'a> {
1330+
pattern: &'b Pattern,
1331+
conditions: Option<&'b Condition<PatternRestriction>>,
1332+
settings: Option<&'b MatchSettings>,
1333+
) -> PatternAtomTreeIterator<'a, 'b> {
13311334
PatternAtomTreeIterator::new(pattern, self.as_atom_view(), conditions, settings)
13321335
}
13331336
}

src/coefficient.rs

+9
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use crate::{
2727
atom::AtomField,
2828
finite_field::{
2929
FiniteField, FiniteFieldCore, FiniteFieldElement, FiniteFieldWorkspace, ToFiniteField,
30+
Zp64,
3031
},
3132
float::{Float, NumericalFloatLike, Real, SingleFloat},
3233
integer::{Integer, IntegerRing, Z},
@@ -63,6 +64,14 @@ pub enum Coefficient {
6364
RationalPolynomial(RationalPolynomial<IntegerRing, u16>),
6465
}
6566

67+
impl Coefficient {
68+
/// Construct a coefficient from a finite field element.
69+
pub fn from_finite_field(field: Zp64, element: FiniteFieldElement<u64>) -> Self {
70+
let index = State::get_or_insert_finite_field(field);
71+
Coefficient::FiniteField(element, index)
72+
}
73+
}
74+
6675
impl From<i64> for Coefficient {
6776
fn from(value: i64) -> Self {
6877
Coefficient::Rational(value.into())

src/domains.rs

+13-2
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ pub trait Ring: Clone + PartialEq + Eq + Hash + Debug + Display {
141141
fn zero(&self) -> Self::Element;
142142
fn one(&self) -> Self::Element;
143143
/// Return the nth element by computing `n * 1`.
144-
fn nth(&self, n: u64) -> Self::Element;
144+
fn nth(&self, n: Integer) -> Self::Element;
145145
fn pow(&self, b: &Self::Element, e: u64) -> Self::Element;
146146
fn is_zero(a: &Self::Element) -> bool;
147147
fn is_one(&self, a: &Self::Element) -> bool;
@@ -151,6 +151,10 @@ pub trait Ring: Clone + PartialEq + Eq + Hash + Debug + Display {
151151
/// The number of elements in the ring. 0 is used for infinite rings.
152152
fn size(&self) -> Integer;
153153

154+
/// Return the result of dividing `a` by `b`, if possible and if the result is unique.
155+
/// For example, in [Z](type@integer::Z), `4/2` is possible but `3/2` is not.
156+
fn try_div(&self, a: &Self::Element, b: &Self::Element) -> Option<Self::Element>;
157+
154158
fn sample(&self, rng: &mut impl rand::RngCore, range: (i64, i64)) -> Self::Element;
155159
/// Format a ring element with custom [PrintOptions] and [PrintState].
156160
fn format<W: std::fmt::Write>(
@@ -299,7 +303,7 @@ impl<R: Ring, C: Clone + Borrow<R>> Ring for WrappedRingElement<R, C> {
299303
}
300304
}
301305

302-
fn nth(&self, n: u64) -> Self::Element {
306+
fn nth(&self, n: Integer) -> Self::Element {
303307
WrappedRingElement {
304308
ring: self.ring.clone(),
305309
element: self.ring().nth(n),
@@ -349,6 +353,13 @@ impl<R: Ring, C: Clone + Borrow<R>> Ring for WrappedRingElement<R, C> {
349353
) -> Result<bool, Error> {
350354
self.ring().format(&element.element, opts, state, f)
351355
}
356+
357+
fn try_div(&self, a: &Self::Element, b: &Self::Element) -> Option<Self::Element> {
358+
Some(WrappedRingElement {
359+
ring: self.ring.clone(),
360+
element: self.ring().try_div(&a.element, &b.element)?,
361+
})
362+
}
352363
}
353364

354365
impl<R: Ring, C: Clone + Borrow<R>> Debug for WrappedRingElement<R, C> {

0 commit comments

Comments
 (0)