From cc5169cad653a9984ba49c4927eff616686c5a1c Mon Sep 17 00:00:00 2001 From: Ben Ruijl Date: Sun, 24 Nov 2024 10:38:38 +0100 Subject: [PATCH 1/6] Allow negative exponents in polynomials - Improve speed of polynomial multiplication - Support LaTeX output for series --- src/api/python.rs | 69 ++ src/domains/algebraic_number.rs | 8 +- src/domains/factorized_rational_polynomial.rs | 54 +- src/domains/rational_polynomial.rs | 67 +- src/parser.rs | 4 +- src/poly.rs | 543 ++++++++-- src/poly/evaluate.rs | 10 +- src/poly/factor.rs | 24 +- src/poly/gcd.rs | 31 +- src/poly/polynomial.rs | 997 ++++++++++-------- src/poly/series.rs | 78 +- src/poly/univariate.rs | 17 +- src/solve.rs | 6 +- symbolica.pyi | 26 + 14 files changed, 1280 insertions(+), 654 deletions(-) diff --git a/src/api/python.rs b/src/api/python.rs index 829bb760..36bceb74 100644 --- a/src/api/python.rs +++ b/src/api/python.rs @@ -4941,6 +4941,75 @@ impl PythonSeries { Ok(format!("{}", self.series)) } + /// Convert the series into a LaTeX string. + pub fn to_latex(&self) -> PyResult { + Ok(format!( + "$${}$$", + self.series + .format_string(&PrintOptions::latex(), PrintState::new()) + )) + } + + /// Convert the expression into a human-readable string, with tunable settings. + /// + /// Examples + /// -------- + /// >>> a = Expression.parse('128378127123 z^(2/3)*w^2/x/y + y^4 + z^34 + x^(x+2)+3/5+f(x,x^2)') + /// >>> print(a.format(number_thousands_separator='_', multiplication_operator=' ')) + #[pyo3(signature = + (terms_on_new_line = false, + color_top_level_sum = true, + color_builtin_symbols = true, + print_finite_field = true, + symmetric_representation_for_finite_field = false, + explicit_rational_polynomial = false, + number_thousands_separator = None, + multiplication_operator = '*', + double_star_for_exponentiation = false, + square_brackets_for_function = false, + num_exp_as_superscript = true, + latex = false, + precision = None) + )] + pub fn format( + &self, + terms_on_new_line: bool, + color_top_level_sum: bool, + color_builtin_symbols: bool, + print_finite_field: bool, + symmetric_representation_for_finite_field: bool, + explicit_rational_polynomial: bool, + number_thousands_separator: Option, + multiplication_operator: char, + double_star_for_exponentiation: bool, + square_brackets_for_function: bool, + num_exp_as_superscript: bool, + latex: bool, + precision: Option, + ) -> PyResult { + Ok(format!( + "{}", + self.series.format_string( + &PrintOptions { + terms_on_new_line, + color_top_level_sum, + color_builtin_symbols, + print_finite_field, + symmetric_representation_for_finite_field, + explicit_rational_polynomial, + number_thousands_separator, + multiplication_operator, + double_star_for_exponentiation, + square_brackets_for_function, + num_exp_as_superscript, + latex, + precision, + }, + PrintState::new() + ) + )) + } + pub fn sin(&self) -> PyResult { Ok(Self { series: self diff --git a/src/domains/algebraic_number.rs b/src/domains/algebraic_number.rs index 00e913d5..a13116ce 100644 --- a/src/domains/algebraic_number.rs +++ b/src/domains/algebraic_number.rs @@ -6,8 +6,8 @@ use crate::{ coefficient::ConvertToRing, combinatorics::CombinationIterator, poly::{ - factor::Factorize, gcd::PolynomialGCD, polynomial::MultivariatePolynomial, Exponent, - Variable, + factor::Factorize, gcd::PolynomialGCD, polynomial::MultivariatePolynomial, + PositiveExponent, Variable, }, }; @@ -577,7 +577,9 @@ impl> AlgebraicExtension { } } -impl, E: Exponent> MultivariatePolynomial, E> { +impl, E: PositiveExponent> + MultivariatePolynomial, E> +{ /// Get the norm of a non-constant square-free polynomial `f` in the algebraic number field. pub fn norm(&self) -> MultivariatePolynomial { self.norm_impl().3 diff --git a/src/domains/factorized_rational_polynomial.rs b/src/domains/factorized_rational_polynomial.rs index 5a4e04aa..102fb5a1 100644 --- a/src/domains/factorized_rational_polynomial.rs +++ b/src/domains/factorized_rational_polynomial.rs @@ -9,8 +9,8 @@ use std::{ use crate::{ poly::{ - factor::Factorize, gcd::PolynomialGCD, polynomial::MultivariatePolynomial, Exponent, - Variable, + factor::Factorize, gcd::PolynomialGCD, polynomial::MultivariatePolynomial, + PositiveExponent, Variable, }, printer::{PrintOptions, PrintState}, }; @@ -23,13 +23,13 @@ use super::{ }; #[derive(Clone, PartialEq, Eq, Hash, Debug)] -pub struct FactorizedRationalPolynomialField { +pub struct FactorizedRationalPolynomialField { ring: R, var_map: Arc>, _phantom_exp: PhantomData, } -impl FactorizedRationalPolynomialField { +impl FactorizedRationalPolynomialField { pub fn new( coeff_ring: R, var_map: Arc>, @@ -52,7 +52,7 @@ impl FactorizedRationalPolynomialField { } } -pub trait FromNumeratorAndFactorizedDenominator { +pub trait FromNumeratorAndFactorizedDenominator { /// Construct a rational polynomial from a numerator and a factorized denominator. /// An empty denominator means a denominator of 1. fn from_num_den( @@ -64,21 +64,21 @@ pub trait FromNumeratorAndFactorizedDenominator } #[derive(Clone, PartialEq, Eq, Hash, Debug)] -pub struct FactorizedRationalPolynomial { +pub struct FactorizedRationalPolynomial { pub numerator: MultivariatePolynomial, pub numer_coeff: R::Element, pub denom_coeff: R::Element, pub denominators: Vec<(MultivariatePolynomial, usize)>, // TODO: sort factors? } -impl InternalOrdering for FactorizedRationalPolynomial { +impl InternalOrdering for FactorizedRationalPolynomial { /// An ordering of rational polynomials that has no intuitive meaning. fn internal_cmp(&self, _other: &Self) -> Ordering { todo!() } } -impl FactorizedRationalPolynomial { +impl FactorizedRationalPolynomial { pub fn new(field: &R, var_map: Arc>) -> FactorizedRationalPolynomial { let num = MultivariatePolynomial::new(field, None, var_map); @@ -147,7 +147,7 @@ impl FactorizedRationalPolynomial { } } -impl SelfRing for FactorizedRationalPolynomial { +impl SelfRing for FactorizedRationalPolynomial { fn is_zero(&self) -> bool { self.is_zero() } @@ -385,7 +385,7 @@ impl SelfRing for FactorizedRationalPolynomial { } } -impl FromNumeratorAndFactorizedDenominator +impl FromNumeratorAndFactorizedDenominator for FactorizedRationalPolynomial { fn from_num_den( @@ -428,7 +428,7 @@ impl FromNumeratorAndFactorizedDenominator FromNumeratorAndFactorizedDenominator +impl FromNumeratorAndFactorizedDenominator for FactorizedRationalPolynomial { fn from_num_den( @@ -534,7 +534,7 @@ impl FromNumeratorAndFactorizedDenominator +impl FromNumeratorAndFactorizedDenominator, FiniteField, E> for FactorizedRationalPolynomial, E> where @@ -618,7 +618,7 @@ where } } -impl, E: Exponent> FactorizedRationalPolynomial +impl, E: PositiveExponent> FactorizedRationalPolynomial where Self: FromNumeratorAndFactorizedDenominator, MultivariatePolynomial: Factorize, @@ -644,7 +644,7 @@ where } } -impl, E: Exponent> FactorizedRationalPolynomial +impl, E: PositiveExponent> FactorizedRationalPolynomial where Self: FromNumeratorAndFactorizedDenominator, { @@ -727,20 +727,20 @@ where } } -impl Display for FactorizedRationalPolynomial { +impl Display for FactorizedRationalPolynomial { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.format(&PrintOptions::from_fmt(f), PrintState::from_fmt(f), f) .map(|_| ()) } } -impl Display for FactorizedRationalPolynomialField { +impl Display for FactorizedRationalPolynomialField { fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { Ok(()) } } -impl, E: Exponent> Ring +impl, E: PositiveExponent> Ring for FactorizedRationalPolynomialField where FactorizedRationalPolynomial: FromNumeratorAndFactorizedDenominator, @@ -864,7 +864,7 @@ where } } -impl, E: Exponent> EuclideanDomain +impl, E: PositiveExponent> EuclideanDomain for FactorizedRationalPolynomialField where FactorizedRationalPolynomial: FromNumeratorAndFactorizedDenominator, @@ -888,7 +888,7 @@ where } } -impl, E: Exponent> Field +impl, E: PositiveExponent> Field for FactorizedRationalPolynomialField where FactorizedRationalPolynomial: FromNumeratorAndFactorizedDenominator, @@ -907,7 +907,7 @@ where } } -impl<'a, 'b, R: EuclideanDomain + PolynomialGCD + PolynomialGCD, E: Exponent> +impl<'a, 'b, R: EuclideanDomain + PolynomialGCD + PolynomialGCD, E: PositiveExponent> Add<&'a FactorizedRationalPolynomial> for &'b FactorizedRationalPolynomial where FactorizedRationalPolynomial: FromNumeratorAndFactorizedDenominator, @@ -1014,7 +1014,8 @@ where } } -impl, E: Exponent> Sub for FactorizedRationalPolynomial +impl, E: PositiveExponent> Sub + for FactorizedRationalPolynomial where FactorizedRationalPolynomial: FromNumeratorAndFactorizedDenominator, { @@ -1025,7 +1026,7 @@ where } } -impl<'a, 'b, R: EuclideanDomain + PolynomialGCD, E: Exponent> +impl<'a, 'b, R: EuclideanDomain + PolynomialGCD, E: PositiveExponent> Sub<&'a FactorizedRationalPolynomial> for &'b FactorizedRationalPolynomial where FactorizedRationalPolynomial: FromNumeratorAndFactorizedDenominator, @@ -1037,7 +1038,8 @@ where } } -impl, E: Exponent> Neg for FactorizedRationalPolynomial +impl, E: PositiveExponent> Neg + for FactorizedRationalPolynomial where FactorizedRationalPolynomial: FromNumeratorAndFactorizedDenominator, { @@ -1052,7 +1054,7 @@ where } } -impl<'a, 'b, R: EuclideanDomain + PolynomialGCD, E: Exponent> +impl<'a, 'b, R: EuclideanDomain + PolynomialGCD, E: PositiveExponent> Mul<&'a FactorizedRationalPolynomial> for &'b FactorizedRationalPolynomial { type Output = FactorizedRationalPolynomial; @@ -1145,7 +1147,7 @@ impl<'a, 'b, R: EuclideanDomain + PolynomialGCD, E: Exponent> } } -impl<'a, 'b, R: EuclideanDomain + PolynomialGCD, E: Exponent> +impl<'a, 'b, R: EuclideanDomain + PolynomialGCD, E: PositiveExponent> Div<&'a FactorizedRationalPolynomial> for &'b FactorizedRationalPolynomial where FactorizedRationalPolynomial: FromNumeratorAndFactorizedDenominator, @@ -1256,7 +1258,7 @@ where } } -impl, E: Exponent> FactorizedRationalPolynomial +impl, E: PositiveExponent> FactorizedRationalPolynomial where FactorizedRationalPolynomial: FromNumeratorAndFactorizedDenominator, MultivariatePolynomial: Factorize, diff --git a/src/domains/rational_polynomial.rs b/src/domains/rational_polynomial.rs index ee95ee1a..2c7e91bc 100644 --- a/src/domains/rational_polynomial.rs +++ b/src/domains/rational_polynomial.rs @@ -12,7 +12,7 @@ use ahash::HashMap; use crate::{ poly::{ factor::Factorize, gcd::PolynomialGCD, polynomial::MultivariatePolynomial, - univariate::UnivariatePolynomial, Exponent, Variable, + univariate::UnivariatePolynomial, PositiveExponent, Variable, }, printer::{PrintOptions, PrintState}, }; @@ -25,12 +25,12 @@ use super::{ }; #[derive(Clone, PartialEq, Eq, Hash, Debug)] -pub struct RationalPolynomialField { +pub struct RationalPolynomialField { ring: R, _phantom_exp: PhantomData, } -impl RationalPolynomialField { +impl RationalPolynomialField { pub fn new(coeff_ring: R) -> RationalPolynomialField { RationalPolynomialField { ring: coeff_ring, @@ -46,7 +46,7 @@ impl RationalPolynomialField { } } -pub trait FromNumeratorAndDenominator { +pub trait FromNumeratorAndDenominator { fn from_num_den( num: MultivariatePolynomial, den: MultivariatePolynomial, @@ -56,12 +56,12 @@ pub trait FromNumeratorAndDenominator { } #[derive(Clone, PartialEq, Eq, Hash, Debug)] -pub struct RationalPolynomial { +pub struct RationalPolynomial { pub numerator: MultivariatePolynomial, pub denominator: MultivariatePolynomial, } -impl InternalOrdering for RationalPolynomial { +impl InternalOrdering for RationalPolynomial { /// An ordering of rational polynomials that has no intuitive meaning. fn internal_cmp(&self, other: &Self) -> Ordering { self.numerator @@ -81,7 +81,7 @@ impl InternalOrdering for RationalPolynomial { } } -impl From> for RationalPolynomial +impl From> for RationalPolynomial where Self: FromNumeratorAndDenominator, { @@ -92,7 +92,7 @@ where } } -impl RationalPolynomial +impl RationalPolynomial where Self: FromNumeratorAndDenominator, { @@ -120,7 +120,7 @@ where } } -impl RationalPolynomial { +impl RationalPolynomial { pub fn new(field: &R, var_map: Arc>) -> RationalPolynomial { let num = MultivariatePolynomial::new(field, None, var_map); let den = num.one(); @@ -165,7 +165,7 @@ impl RationalPolynomial { } } -impl SelfRing for RationalPolynomial { +impl SelfRing for RationalPolynomial { fn is_zero(&self) -> bool { self.is_zero() } @@ -244,7 +244,7 @@ impl SelfRing for RationalPolynomial { } } -impl FromNumeratorAndDenominator +impl FromNumeratorAndDenominator for RationalPolynomial { fn from_num_den( @@ -293,7 +293,7 @@ impl FromNumeratorAndDenominator } } -impl FromNumeratorAndDenominator +impl FromNumeratorAndDenominator for RationalPolynomial { fn from_num_den( @@ -333,7 +333,7 @@ impl FromNumeratorAndDenominator } } -impl +impl FromNumeratorAndDenominator, FiniteField, E> for RationalPolynomial, E> where @@ -378,7 +378,7 @@ where } } -impl, E: Exponent> RationalPolynomial +impl, E: PositiveExponent> RationalPolynomial where Self: FromNumeratorAndDenominator, { @@ -550,20 +550,21 @@ where } } -impl Display for RationalPolynomial { +impl Display for RationalPolynomial { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.format(&PrintOptions::from_fmt(f), PrintState::from_fmt(f), f) .map(|_| ()) } } -impl Display for RationalPolynomialField { +impl Display for RationalPolynomialField { fn fmt(&self, _: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { Ok(()) } } -impl, E: Exponent> Ring for RationalPolynomialField +impl, E: PositiveExponent> Ring + for RationalPolynomialField where RationalPolynomial: FromNumeratorAndDenominator, { @@ -684,7 +685,7 @@ where } } -impl, E: Exponent> EuclideanDomain +impl, E: PositiveExponent> EuclideanDomain for RationalPolynomialField where RationalPolynomial: FromNumeratorAndDenominator, @@ -705,7 +706,8 @@ where } } -impl, E: Exponent> Field for RationalPolynomialField +impl, E: PositiveExponent> Field + for RationalPolynomialField where RationalPolynomial: FromNumeratorAndDenominator, { @@ -722,7 +724,7 @@ where } } -impl<'a, 'b, R: EuclideanDomain + PolynomialGCD + PolynomialGCD, E: Exponent> +impl<'a, 'b, R: EuclideanDomain + PolynomialGCD + PolynomialGCD, E: PositiveExponent> Add<&'a RationalPolynomial> for &'b RationalPolynomial where RationalPolynomial: FromNumeratorAndDenominator, @@ -774,7 +776,7 @@ where } } -impl, E: Exponent> Sub for RationalPolynomial +impl, E: PositiveExponent> Sub for RationalPolynomial where RationalPolynomial: FromNumeratorAndDenominator, { @@ -785,8 +787,8 @@ where } } -impl<'a, 'b, R: EuclideanDomain + PolynomialGCD, E: Exponent> Sub<&'a RationalPolynomial> - for &'b RationalPolynomial +impl<'a, 'b, R: EuclideanDomain + PolynomialGCD, E: PositiveExponent> + Sub<&'a RationalPolynomial> for &'b RationalPolynomial where RationalPolynomial: FromNumeratorAndDenominator, { @@ -797,7 +799,7 @@ where } } -impl, E: Exponent> Neg for RationalPolynomial { +impl, E: PositiveExponent> Neg for RationalPolynomial { type Output = Self; fn neg(self) -> Self::Output { RationalPolynomial { @@ -807,8 +809,8 @@ impl, E: Exponent> Neg for RationalPolynom } } -impl<'a, 'b, R: EuclideanDomain + PolynomialGCD, E: Exponent> Mul<&'a RationalPolynomial> - for &'b RationalPolynomial +impl<'a, 'b, R: EuclideanDomain + PolynomialGCD, E: PositiveExponent> + Mul<&'a RationalPolynomial> for &'b RationalPolynomial where RationalPolynomial: FromNumeratorAndDenominator, { @@ -851,8 +853,8 @@ where } } -impl<'a, 'b, R: EuclideanDomain + PolynomialGCD, E: Exponent> Div<&'a RationalPolynomial> - for &'b RationalPolynomial +impl<'a, 'b, R: EuclideanDomain + PolynomialGCD, E: PositiveExponent> + Div<&'a RationalPolynomial> for &'b RationalPolynomial where RationalPolynomial: FromNumeratorAndDenominator, { @@ -864,7 +866,7 @@ where } } -impl, E: Exponent> RationalPolynomial +impl, E: PositiveExponent> RationalPolynomial where RationalPolynomial: FromNumeratorAndDenominator, { @@ -897,7 +899,8 @@ where } } -impl, E: Exponent> Derivable for RationalPolynomialField +impl, E: PositiveExponent> Derivable + for RationalPolynomialField where RationalPolynomial: FromNumeratorAndDenominator, { @@ -910,7 +913,7 @@ where } } -impl, E: Exponent> RationalPolynomial +impl, E: PositiveExponent> RationalPolynomial where RationalPolynomial: FromNumeratorAndDenominator, MultivariatePolynomial: Factorize, @@ -999,7 +1002,7 @@ where } } -impl, E: Exponent> RationalPolynomial +impl, E: PositiveExponent> RationalPolynomial where RationalPolynomial: FromNumeratorAndDenominator, MultivariatePolynomial: Factorize, diff --git a/src/parser.rs b/src/parser.rs index 63b53a36..322e223e 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -10,7 +10,7 @@ use crate::{ atom::Atom, coefficient::{Coefficient, ConvertToRing}, domains::{float::Float, integer::Integer, Ring}, - poly::{polynomial::MultivariatePolynomial, Exponent, Variable}, + poly::{polynomial::MultivariatePolynomial, PositiveExponent, Variable}, state::{State, Workspace}, LicenseManager, }; @@ -1085,7 +1085,7 @@ impl Token { /// A special routine that can parse a polynomial written in expanded form, /// where the coefficient comes first. - pub fn parse_polynomial<'a, R: Ring + ConvertToRing, E: Exponent>( + pub fn parse_polynomial<'a, R: Ring + ConvertToRing, E: PositiveExponent>( mut input: &'a [u8], var_map: &Arc>, var_name_map: &[SmartString], diff --git a/src/poly.rs b/src/poly.rs index 1e8968e5..6ba3a0aa 100644 --- a/src/poly.rs +++ b/src/poly.rs @@ -27,7 +27,7 @@ use crate::domains::factorized_rational_polynomial::{ }; use crate::domains::integer::Integer; use crate::domains::rational_polynomial::{FromNumeratorAndDenominator, RationalPolynomial}; -use crate::domains::{EuclideanDomain, Ring, RingPrinter}; +use crate::domains::{EuclideanDomain, Ring, SelfRing}; use crate::parser::{Operator, Token}; use crate::printer::{PrintOptions, PrintState}; use crate::state::{State, Workspace}; @@ -55,13 +55,14 @@ pub trait Exponent: + Copy + PartialEq + Eq + + TryFrom { fn zero() -> Self; fn one() -> Self; - /// Convert the exponent to `u32`. This is always possible, as `u32` is the largest supported exponent type. - fn to_u32(&self) -> u32; - /// Convert from `u32`. This function may panic if the exponent is too large. - fn from_u32(n: u32) -> Self; + /// Convert the exponent to `i32`. This is always possible, as `i32` is the largest supported exponent type. + fn to_i32(&self) -> i32; + /// Convert from `i32`. This function may panic if the exponent is too large. + fn from_i32(n: i32) -> Self; fn is_zero(&self) -> bool; fn checked_add(&self, other: &Self) -> Option; fn gcd(&self, other: &Self) -> Self; @@ -93,13 +94,16 @@ impl Exponent for u32 { } #[inline] - fn to_u32(&self) -> u32 { - *self + fn to_i32(&self) -> i32 { + *self as i32 } #[inline] - fn from_u32(n: u32) -> Self { - n + fn from_i32(n: i32) -> Self { + if n < 0 { + panic!("Exponent {} is negative", n); + } + n as u32 } #[inline] @@ -109,7 +113,7 @@ impl Exponent for u32 { #[inline] fn checked_add(&self, other: &Self) -> Option { - u32::checked_add(*self, *other) + i32::checked_add(*self as i32, *other as i32).map(|x| x as u32) } #[inline] @@ -150,6 +154,77 @@ impl Exponent for u32 { } } +impl Exponent for i32 { + #[inline] + fn zero() -> Self { + 0 + } + + #[inline] + fn one() -> Self { + 1 + } + + #[inline] + fn to_i32(&self) -> i32 { + *self + } + + #[inline] + fn from_i32(n: i32) -> Self { + n + } + + #[inline] + fn is_zero(&self) -> bool { + *self == 0 + } + + #[inline] + fn checked_add(&self, other: &Self) -> Option { + i32::checked_add(*self, *other) + } + + #[inline] + fn gcd(&self, other: &Self) -> Self { + utils::gcd_signed(*self as i64, *other as i64) as Self + } + + // Pack a list of positive exponents. + fn pack(list: &[Self]) -> u64 { + let mut num: u64 = 0; + for x in list.iter().rev() { + num = (num << 8) + (*x as u8 as u64); + } + num.swap_bytes() + } + + fn unpack(mut n: u64, out: &mut [Self]) { + n = n.swap_bytes(); + let s = unsafe { std::slice::from_raw_parts(&n as *const u64 as *const u8, out.len()) }; + for (o, ss) in out.iter_mut().zip(s) { + *o = *ss as i32; + } + } + + // Pack a list of positive exponents. + fn pack_u16(list: &[Self]) -> u64 { + let mut num: u64 = 0; + for x in list.iter().rev() { + num = (num << 16) + ((*x as u16).to_be() as u64); + } + num.swap_bytes() + } + + fn unpack_u16(mut n: u64, out: &mut [Self]) { + n = n.swap_bytes(); + let s = unsafe { std::slice::from_raw_parts(&n as *const u64 as *const u16, out.len()) }; + for (o, ss) in out.iter_mut().zip(s) { + *o = ss.swap_bytes() as i32; + } + } +} + impl Exponent for u16 { #[inline] fn zero() -> Self { @@ -162,13 +237,13 @@ impl Exponent for u16 { } #[inline] - fn to_u32(&self) -> u32 { - *self as u32 + fn to_i32(&self) -> i32 { + *self as i32 } #[inline] - fn from_u32(n: u32) -> Self { - if n <= u16::MAX as u32 { + fn from_i32(n: i32) -> Self { + if n >= 0 && n <= u16::MAX as i32 { n as u16 } else { panic!("Exponent {} too large for u16", n); @@ -223,6 +298,81 @@ impl Exponent for u16 { } } +impl Exponent for i16 { + #[inline] + fn zero() -> Self { + 0 + } + + #[inline] + fn one() -> Self { + 1 + } + + #[inline] + fn to_i32(&self) -> i32 { + *self as i32 + } + + #[inline] + fn from_i32(n: i32) -> Self { + if n >= i16::MIN as i32 && n <= i16::MAX as i32 { + n as i16 + } else { + panic!("Exponent {} too large for i16", n); + } + } + + #[inline] + fn is_zero(&self) -> bool { + *self == 0 + } + + #[inline] + fn checked_add(&self, other: &Self) -> Option { + i16::checked_add(*self, *other) + } + + #[inline] + fn gcd(&self, other: &Self) -> Self { + utils::gcd_signed(*self as i64, *other as i64) as Self + } + + // Pack a list of positive exponents. + fn pack(list: &[Self]) -> u64 { + let mut num: u64 = 0; + for x in list.iter().rev() { + num = (num << 8) + (*x as u8 as u64); + } + num.swap_bytes() + } + + fn unpack(mut n: u64, out: &mut [Self]) { + n = n.swap_bytes(); + let s = unsafe { std::slice::from_raw_parts(&n as *const u64 as *const u8, out.len()) }; + for (o, ss) in out.iter_mut().zip(s) { + *o = *ss as i16; + } + } + + // Pack a list of positive exponents. + fn pack_u16(list: &[Self]) -> u64 { + let mut num: u64 = 0; + for x in list.iter().rev() { + num = (num << 16) + ((*x as u16).to_be() as u64); + } + num.swap_bytes() + } + + fn unpack_u16(mut n: u64, out: &mut [Self]) { + n = n.swap_bytes(); + let s = unsafe { std::slice::from_raw_parts(&n as *const u64 as *const u16, out.len()) }; + for (o, ss) in out.iter_mut().zip(s) { + *o = ss.swap_bytes() as i16; + } + } +} + /// An exponent limited to 255 for efficiency impl Exponent for u8 { #[inline] @@ -236,13 +386,13 @@ impl Exponent for u8 { } #[inline] - fn to_u32(&self) -> u32 { - *self as u32 + fn to_i32(&self) -> i32 { + *self as i32 } #[inline] - fn from_u32(n: u32) -> Self { - if n <= u8::MAX as u32 { + fn from_i32(n: i32) -> Self { + if n >= 0 && n <= u8::MAX as i32 { n as u8 } else { panic!("Exponent {} too large for u8", n); @@ -295,6 +445,150 @@ impl Exponent for u8 { } } +impl Exponent for i8 { + #[inline] + fn zero() -> Self { + 0 + } + + #[inline] + fn one() -> Self { + 1 + } + + #[inline] + fn to_i32(&self) -> i32 { + *self as i32 + } + + #[inline] + fn from_i32(n: i32) -> Self { + if n >= i8::MIN as i32 && n <= i8::MAX as i32 { + n as i8 + } else { + panic!("Exponent {} too large for i8", n); + } + } + + #[inline] + fn is_zero(&self) -> bool { + *self == 0 + } + + #[inline] + fn checked_add(&self, other: &Self) -> Option { + i8::checked_add(*self, *other) + } + + #[inline] + fn gcd(&self, other: &Self) -> Self { + utils::gcd_signed(*self as i64, *other as i64) as Self + } + + // Pack a list of positive exponents. + fn pack(list: &[Self]) -> u64 { + let mut num: u64 = 0; + for x in list.iter().rev() { + num = (num << 8) + (*x as u8 as u64); + } + num.swap_bytes() + } + + fn unpack(mut n: u64, out: &mut [Self]) { + n = n.swap_bytes(); + let s = unsafe { std::slice::from_raw_parts(&n as *const u64 as *const u8, out.len()) }; + for (o, ss) in out.iter_mut().zip(s) { + *o = *ss as i8; + } + } + + // Pack a list of positive exponents. + fn pack_u16(list: &[Self]) -> u64 { + let mut num: u64 = 0; + for x in list.iter().rev() { + num = (num << 16) + ((*x as u16).to_be() as u64); + } + num.swap_bytes() + } + + fn unpack_u16(mut n: u64, out: &mut [Self]) { + n = n.swap_bytes(); + let s = unsafe { std::slice::from_raw_parts(&n as *const u64 as *const u16, out.len()) }; + for (o, ss) in out.iter_mut().zip(s) { + *o = ss.swap_bytes() as i8; + } + } +} + +pub trait PositiveExponent: Exponent { + fn from_u32(n: u32) -> Self { + if n > i32::MAX as u32 { + panic!("Exponent {} too large for i32", n); + } + Self::from_i32(n as i32) + } + fn to_u32(&self) -> u32; +} + +impl PositiveExponent for u8 { + #[inline] + fn to_u32(&self) -> u32 { + *self as u32 + } +} +impl PositiveExponent for u16 { + #[inline] + fn to_u32(&self) -> u32 { + *self as u32 + } +} +impl PositiveExponent for u32 { + #[inline] + fn to_u32(&self) -> u32 { + *self + } +} + +macro_rules! to_positive { + ($neg: ty, $pos: ty) => { + impl MultivariatePolynomial { + /// Convert a polynomial with positive exponents to its unsigned type equivalent + /// by a safe and almost zero-cost cast. + /// + /// Panics if the polynomial has negative exponents. + pub fn to_positive(self) -> MultivariatePolynomial { + if !self.is_polynomial() { + panic!("Polynomial has negative exponent"); + } + + unsafe { std::mem::transmute_copy(&std::mem::ManuallyDrop::new(self)) } + } + } + + impl MultivariatePolynomial { + /// Convert a polynomial with positive exponents to its signed type equivalent + /// by a safe and almost zero-cost cast. + /// + /// Panics if the polynomial has exponents that are too large. + pub fn to_signed(self) -> MultivariatePolynomial { + if self + .exponents + .iter() + .any(|x| x.to_i32() > <$neg>::MAX as i32) + { + panic!("Polynomial has exponents that are too large"); + } + + unsafe { std::mem::transmute_copy(&std::mem::ManuallyDrop::new(self)) } + } + } + }; +} + +to_positive!(i8, u8); +to_positive!(i16, u16); +to_positive!(i32, u32); + /// A well-order of monomials. pub trait MonomialOrder: Clone { fn cmp(a: &[E], b: &[E]) -> Ordering; @@ -344,7 +638,7 @@ impl MonomialOrder for LexOrder { /// A polynomial variable. It is either a (global) symbol /// a temporary variable (for internal use), an array entry, /// a function or any other non-polynomial part. -#[derive(Clone, Hash, PartialEq, Eq, Debug)] +#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Debug)] pub enum Variable { Symbol(Symbol), Temporary(usize), // a temporary variable, for internal use @@ -394,19 +688,11 @@ impl Variable { } } - pub fn to_string_with_state(&self, state: PrintState) -> String { + fn format_string(&self, opts: &PrintOptions, state: PrintState) -> String { match self { Variable::Symbol(v) => State::get_name(*v).to_string(), Variable::Temporary(t) => format!("_TMP_{}", *t), - Variable::Function(_, a) | Variable::Other(a) => format!( - "{}", - RingPrinter { - element: a.as_ref(), - ring: &AtomField::new(), - opts: PrintOptions::default(), - state, - } - ), + Variable::Function(_, a) | Variable::Other(a) => a.format_string(opts, state), } } @@ -462,6 +748,18 @@ impl Atom { self.as_view().to_polynomial(field, var_map) } + /// Convert the atom to a polynomial in specific variables. + /// All other parts will be collected into the coefficient, which + /// is a general expression. + /// + /// This routine does not perform expansions. + pub fn to_polynomial_in_vars( + &self, + var_map: &Arc>, + ) -> MultivariatePolynomial { + self.as_view().to_polynomial_in_vars(var_map) + } + /// Convert the atom to a rational polynomial, optionally in the variable ordering /// specified by `var_map`. If new variables are encountered, they are /// added to the variable map. Similarly, non-rational polynomial parts are automatically @@ -469,7 +767,7 @@ impl Atom { pub fn to_rational_polynomial< R: EuclideanDomain + ConvertToRing, RO: EuclideanDomain + PolynomialGCD, - E: Exponent, + E: PositiveExponent, >( &self, field: &R, @@ -491,7 +789,7 @@ impl Atom { pub fn to_factorized_rational_polynomial< R: EuclideanDomain + ConvertToRing, RO: EuclideanDomain + PolynomialGCD, - E: Exponent, + E: PositiveExponent, >( &self, field: &R, @@ -660,11 +958,11 @@ impl<'a> AtomView<'a> { match exp { AtomView::Num(n) => match n.get_coeff_view() { CoefficientView::Natural(r, _) => { - exponents[var_index] += E::from_u32(r as u32) + exponents[var_index] += E::from_i32(r as i32) } CoefficientView::Large(r) => { exponents[var_index] += - E::from_u32(r.to_rat().numerator_ref().to_i64().unwrap() as u32) + E::from_i32(r.to_rat().numerator_ref().to_i64().unwrap() as i32) } _ => unreachable!(), }, @@ -749,8 +1047,39 @@ impl<'a> AtomView<'a> { if let AtomView::Num(n) = exp { let num_n = n.get_coeff_view(); if let CoefficientView::Natural(nn, nd) = num_n { - if nd == 1 && nn > 0 && nn < u32::MAX as i64 { - return base.to_polynomial_impl(field, var_map).pow(nn as usize); + if nd == 1 { + if nn > 0 && nn < i32::MAX as i64 { + return base.to_polynomial_impl(field, var_map).pow(nn as usize); + } else if nn < 0 && nn > i32::MIN as i64 { + // allow x^-2 as a term if supported by the exponent + if let Ok(e) = (nn as i32).try_into() { + if let AtomView::Var(v) = base { + let s = Variable::Symbol(v.get_symbol()); + if let Some(id) = var_map.iter().position(|v| v == &s) { + let mut exp = vec![E::zero(); var_map.len()]; + exp[id] = e; + return MultivariatePolynomial::new( + field, + None, + var_map.clone(), + ) + .monomial(field.one(), exp); + } else { + let mut var_map = var_map.as_ref().clone(); + var_map.push(s); + let mut exp = vec![E::zero(); var_map.len()]; + exp[var_map.len() - 1] = e; + + return MultivariatePolynomial::new( + field, + None, + Arc::new(var_map), + ) + .monomial(field.one(), exp); + } + } + } + } } } } @@ -829,6 +1158,20 @@ impl<'a> AtomView<'a> { pub fn to_polynomial_in_vars( &self, var_map: &Arc>, + ) -> MultivariatePolynomial { + let poly = MultivariatePolynomial::<_, E>::new(&AtomField::new(), None, var_map.clone()); + self.to_polynomial_in_vars_impl(var_map, &poly) + } + + /// Convert the atom to a polynomial in specific variables. + /// All other parts will be collected into the coefficient, which + /// is a general expression. + /// + /// This routine does not perform expansions. + fn to_polynomial_in_vars_impl( + &self, + var_map: &Arc>, + poly: &MultivariatePolynomial, ) -> MultivariatePolynomial { let field = AtomField::new(); // see if the current term can be cast into a polynomial using a fast routine @@ -837,9 +1180,7 @@ impl<'a> AtomView<'a> { } match self { - AtomView::Num(_) | AtomView::Var(_) => { - MultivariatePolynomial::new(&field, None, var_map.clone()).constant(self.to_owned()) - } + AtomView::Num(_) | AtomView::Var(_) => poly.constant(self.to_owned()), AtomView::Pow(p) => { let (base, exp) = p.get_base_exp(); @@ -847,7 +1188,8 @@ impl<'a> AtomView<'a> { let num_n = n.get_coeff_view(); if let CoefficientView::Natural(nn, nd) = num_n { if nd == 1 && nn > 0 && nn < u32::MAX as i64 { - return base.to_polynomial_in_vars(var_map).pow(nn as usize); + let b = base.to_polynomial_in_vars_impl(var_map, poly); + return b.pow(nn as usize); } } } @@ -858,11 +1200,9 @@ impl<'a> AtomView<'a> { }) { let mut exp = vec![E::zero(); var_map.len()]; exp[id] = E::one(); - MultivariatePolynomial::new(&field, None, var_map.clone()) - .monomial(field.one(), exp) + poly.monomial(field.one(), exp) } else { - MultivariatePolynomial::new(&field, None, var_map.clone()) - .constant(self.to_owned()) + poly.constant(self.to_owned()) } } AtomView::Fun(_) => { @@ -872,28 +1212,23 @@ impl<'a> AtomView<'a> { }) { let mut exp = vec![E::zero(); var_map.len()]; exp[id] = E::one(); - MultivariatePolynomial::new(&field, None, var_map.clone()) - .monomial(field.one(), exp) + poly.monomial(field.one(), exp) } else { - MultivariatePolynomial::new(&field, None, var_map.clone()) - .constant(self.to_owned()) + poly.constant(self.to_owned()) } } AtomView::Mul(m) => { - let mut r = MultivariatePolynomial::new(&field, None, var_map.clone()) - .constant(field.one()); + let mut r = poly.one(); for arg in m { - let mut arg_r = arg.to_polynomial_in_vars(&r.variables); - r.unify_variables(&mut arg_r); + let arg_r = arg.to_polynomial_in_vars_impl(&r.variables, poly); r = &r * &arg_r; } r } AtomView::Add(a) => { - let mut r = MultivariatePolynomial::new(&field, None, var_map.clone()); + let mut r = poly.zero(); for arg in a { - let mut arg_r = arg.to_polynomial_in_vars(&r.variables); - r.unify_variables(&mut arg_r); + let arg_r = arg.to_polynomial_in_vars_impl(&r.variables, poly); r = &r + &arg_r; } r @@ -908,7 +1243,7 @@ impl<'a> AtomView<'a> { pub fn to_rational_polynomial< R: EuclideanDomain + ConvertToRing, RO: EuclideanDomain + PolynomialGCD, - E: Exponent, + E: PositiveExponent, >( &self, field: &R, @@ -929,7 +1264,7 @@ impl<'a> AtomView<'a> { fn to_rational_polynomial_impl< R: EuclideanDomain + ConvertToRing, RO: EuclideanDomain + PolynomialGCD, - E: Exponent, + E: PositiveExponent, >( &self, field: &R, @@ -1051,7 +1386,7 @@ impl<'a> AtomView<'a> { pub fn to_factorized_rational_polynomial< R: EuclideanDomain + ConvertToRing, RO: EuclideanDomain + PolynomialGCD, - E: Exponent, + E: PositiveExponent, >( &self, field: &R, @@ -1073,7 +1408,7 @@ impl<'a> AtomView<'a> { pub fn to_factorized_rational_polynomial_impl< R: EuclideanDomain + ConvertToRing, RO: EuclideanDomain + PolynomialGCD, - E: Exponent, + E: PositiveExponent, >( &self, field: &R, @@ -1210,33 +1545,27 @@ impl MultivariatePolynomial { let add = out.to_add(); let mut mul_h = workspace.new_atom(); - let mut var_h = workspace.new_atom(); let mut num_h = workspace.new_atom(); let mut pow_h = workspace.new_atom(); + let vars: Vec<_> = self.variables.iter().map(|v| v.to_atom()).collect(); + + let mut sorted_vars = (0..vars.len()).collect::>(); + sorted_vars.sort_by_key(|&i| vars[i].clone()); + for monomial in self { let mul = mul_h.to_mul(); - for (var_id, &pow) in self.variables.iter().zip(monomial.exponents) { - if pow > E::zero() { - match var_id { - Variable::Symbol(v) => { - var_h.to_var(*v); - } - Variable::Temporary(_) => { - unreachable!("Temporary variable in expression") - } - Variable::Function(_, a) | Variable::Other(a) => { - var_h.set_from_view(&a.as_view()); - } - } - - if pow > E::one() { - num_h.to_num((pow.to_u32() as i64).into()); - pow_h.to_pow(var_h.as_view(), num_h.as_view()); + for i in &sorted_vars { + let var = &vars[*i]; + let pow = monomial.exponents[*i]; + if pow != E::zero() { + if pow != E::one() { + num_h.to_num((pow.to_i32() as i64).into()); + pow_h.to_pow(var.as_view(), num_h.as_view()); mul.extend(pow_h.as_view()); } else { - mul.extend(var_h.as_view()); + mul.extend(var.as_view()); } } } @@ -1284,34 +1613,38 @@ impl MultivariatePolynomial { let add = out.to_add(); let mut mul_h = workspace.new_atom(); - let mut var_h = workspace.new_atom(); let mut num_h = workspace.new_atom(); let mut pow_h = workspace.new_atom(); + let vars: Vec<_> = self + .variables + .iter() + .map(|v| { + if let Variable::Temporary(_) = v { + let a = map.get(v).expect("Variable missing from map"); + a.to_owned() + } else { + v.to_atom() + } + }) + .collect(); + + let mut sorted_vars = (0..vars.len()).collect::>(); + sorted_vars.sort_by_key(|&i| vars[i].clone()); + for monomial in self { let mul = mul_h.to_mul(); - for (var_id, &pow) in self.variables.iter().zip(monomial.exponents) { - if pow > E::zero() { - match var_id { - Variable::Symbol(v) => { - var_h.to_var(*v); - } - Variable::Temporary(_) => { - let a = map.get(var_id).expect("Variable missing from map"); - var_h.set_from_view(a); - } - Variable::Function(_, a) | Variable::Other(a) => { - var_h.set_from_view(&a.as_view()); - } - } - - if pow > E::one() { - num_h.to_num((pow.to_u32() as i64).into()); - pow_h.to_pow(var_h.as_view(), num_h.as_view()); + for i in &sorted_vars { + let var = &vars[*i]; + let pow = monomial.exponents[*i]; + if pow != E::zero() { + if pow != E::one() { + num_h.to_num((pow.to_i32() as i64).into()); + pow_h.to_pow(var.as_view(), num_h.as_view()); mul.extend(pow_h.as_view()); } else { - mul.extend(var_h.as_view()); + mul.extend(var.as_view()); } } } @@ -1364,7 +1697,7 @@ impl MultivariatePolynomial { let mul = mul_h.to_mul(); for (var_id, &pow) in self.variables.iter().zip(monomial.exponents) { - if pow > E::zero() { + if pow != E::zero() { match var_id { Variable::Symbol(v) => { var_h.to_var(*v); @@ -1377,8 +1710,8 @@ impl MultivariatePolynomial { } } - if pow > E::one() { - num_h.to_num((pow.to_u32() as i64).into()); + if pow != E::one() { + num_h.to_num((pow.to_i32() as i64).into()); pow_h.to_pow(var_h.as_view(), num_h.as_view()); mul.extend(pow_h.as_view()); } else { @@ -1398,7 +1731,7 @@ impl MultivariatePolynomial { } } -impl RationalPolynomial { +impl RationalPolynomial { pub fn to_expression(&self) -> Atom where R::Element: Into, @@ -1500,8 +1833,8 @@ impl Token { match &args[1] { Token::Number(n) => { - if let Ok(x) = n.parse::() { - exponents[var_index] += E::from_u32(x); + if let Ok(x) = n.parse::() { + exponents[var_index] += E::from_i32(x); } else { Err("Invalid exponent")? }; @@ -1596,7 +1929,7 @@ impl Token { pub fn to_rational_polynomial< R: EuclideanDomain + ConvertToRing, RO: EuclideanDomain + ConvertToRing + PolynomialGCD, - E: Exponent, + E: PositiveExponent, >( &self, field: &R, @@ -1704,7 +2037,7 @@ impl Token { pub fn to_factorized_rational_polynomial< R: EuclideanDomain + ConvertToRing, RO: EuclideanDomain + ConvertToRing + PolynomialGCD, - E: Exponent, + E: PositiveExponent, >( &self, field: &R, diff --git a/src/poly/evaluate.rs b/src/poly/evaluate.rs index 4ba813a8..dd7f15d4 100644 --- a/src/poly/evaluate.rs +++ b/src/poly/evaluate.rs @@ -23,7 +23,7 @@ use crate::{ state::State, }; -use super::{polynomial::MultivariatePolynomial, Exponent}; +use super::{polynomial::MultivariatePolynomial, PositiveExponent}; /// A borrowed version of a Horner node, suitable as a key in a /// hashmap. It uses precomputed hashes for the complete node @@ -315,7 +315,7 @@ where } } -impl MultivariatePolynomial { +impl MultivariatePolynomial { /// Write the polynomial in a Horner scheme with the variable ordering /// defined in `order`. pub fn to_horner_scheme(&self, order: &[usize]) -> HornerScheme { @@ -464,7 +464,7 @@ impl MultivariatePolynomial { let mut h = AHasher::default(); h.write_u8(0); var.hash(&mut h); - (min_pow.to_u32() as usize).hash(&mut h); + (min_pow.to_i32() as usize).hash(&mut h); let pow_hash = h.finish(); // hash var^pow @@ -513,7 +513,7 @@ impl MultivariatePolynomial { HornerScheme::Node(HornerNode { var, - pow: min_pow.to_u32() as usize, + pow: min_pow.to_i32() as usize, gcd, hash: (pow_hash, pow_content_hash, full_hash), content_rest: boxed_children, @@ -545,7 +545,7 @@ impl MultivariatePolynomial { } impl HornerScheme { - pub fn optimize_multiple( + pub fn optimize_multiple( polys: &[&MultivariatePolynomial], num_tries: usize, ) -> (Vec>, usize, Vec) { diff --git a/src/poly/factor.rs b/src/poly/factor.rs index 3edbdd54..aaf9de88 100644 --- a/src/poly/factor.rs +++ b/src/poly/factor.rs @@ -20,7 +20,7 @@ use crate::{ utils, }; -use super::{gcd::PolynomialGCD, polynomial::MultivariatePolynomial, Exponent, LexOrder}; +use super::{gcd::PolynomialGCD, polynomial::MultivariatePolynomial, LexOrder, PositiveExponent}; pub trait Factorize: Sized { /// Perform a square-free factorization. @@ -32,7 +32,9 @@ pub trait Factorize: Sized { fn is_irreducible(&self) -> bool; } -impl, E: Exponent> MultivariatePolynomial { +impl, E: PositiveExponent> + MultivariatePolynomial +{ /// Find factors that do not contain all variables. pub fn factor_separable(&self) -> Vec { let mut stripped = self.clone(); @@ -205,7 +207,7 @@ impl, E: Exponent> MultivariatePolynomial< } } -impl Factorize for MultivariatePolynomial { +impl Factorize for MultivariatePolynomial { fn square_free_factorization(&self) -> Vec<(Self, usize)> { if self.is_zero() { return vec![]; @@ -343,7 +345,7 @@ impl Factorize for MultivariatePolynomial } } -impl Factorize for MultivariatePolynomial { +impl Factorize for MultivariatePolynomial { fn square_free_factorization(&self) -> Vec<(Self, usize)> { let c = self.content(); @@ -411,7 +413,7 @@ impl Factorize for MultivariatePolynomial Factorize +impl Factorize for MultivariatePolynomial, E, LexOrder> { fn square_free_factorization(&self) -> Vec<(Self, usize)> { @@ -493,7 +495,7 @@ impl Factorize impl< UField: FiniteFieldWorkspace, F: GaloisField> + PolynomialGCD, - E: Exponent, + E: PositiveExponent, > Factorize for MultivariatePolynomial where FiniteField: Field + FiniteFieldCore + PolynomialGCD, @@ -648,7 +650,7 @@ where impl< UField: FiniteFieldWorkspace, F: GaloisField> + PolynomialGCD, - E: Exponent, + E: PositiveExponent, > MultivariatePolynomial where FiniteField: Field + FiniteFieldCore + PolynomialGCD, @@ -1650,7 +1652,7 @@ where } } -impl MultivariatePolynomial { +impl MultivariatePolynomial { fn multivariate_diophantine( univariate_deltas: &[Self], univariate_factors: &mut [Self], @@ -1987,7 +1989,7 @@ impl MultivariatePolynomial { } } -impl MultivariatePolynomial { +impl MultivariatePolynomial { /// Hensel lift a solution of `self = u * w mod p` to `self = u * w mod max_p` /// where `max_p` is a power of `p`. /// @@ -2159,7 +2161,7 @@ impl MultivariatePolynomial { } let bound = self.coefficient_bound(); - let p: Integer = (field.get_prime().to_u32() as i64).into(); + let p: Integer = (field.get_prime() as i64).into(); let mut max_p = p.clone(); while max_p < bound { max_p = &max_p * &p; @@ -3361,7 +3363,7 @@ impl MultivariatePolynomial { } } -impl MultivariatePolynomial, E, LexOrder> { +impl MultivariatePolynomial, E, LexOrder> { /// Compute a univariate diophantine equation in `Z_p^k` by Newton iteration. fn get_univariate_factors_and_deltas( factors: &[Self], diff --git a/src/poly/gcd.rs b/src/poly/gcd.rs index b57bbcb2..acd156ca 100755 --- a/src/poly/gcd.rs +++ b/src/poly/gcd.rs @@ -18,7 +18,7 @@ use crate::poly::INLINED_EXPONENTS; use crate::tensors::matrix::{Matrix, MatrixError}; use super::polynomial::MultivariatePolynomial; -use super::Exponent; +use super::PositiveExponent; // 100 large u32 primes starting from the 203213901st prime number pub const LARGE_U32_PRIMES: [u32; 100] = [ @@ -122,7 +122,7 @@ enum GCDError { BadCurrentImage, } -impl MultivariatePolynomial { +impl MultivariatePolynomial { /// Evaluation of the exponents by filling in the variables #[inline(always)] fn evaluate_exponents( @@ -184,7 +184,7 @@ impl MultivariatePolynomial { } } -impl MultivariatePolynomial { +impl MultivariatePolynomial { /// Compute the univariate GCD using Euclid's algorithm. The result is normalized to 1. pub fn univariate_gcd(&self, b: &Self) -> Self { if self.is_zero() { @@ -1077,7 +1077,7 @@ impl MultivariatePolynomial { } } -impl, E: Exponent> MultivariatePolynomial { +impl, E: PositiveExponent> MultivariatePolynomial { /// Compute the gcd shape of two polynomials in a finite field by filling in random /// numbers. #[instrument(level = "debug", skip_all)] @@ -1376,7 +1376,7 @@ impl, E: Exponent> MultivariatePolynomial { } } -impl, E: Exponent> MultivariatePolynomial { +impl, E: PositiveExponent> MultivariatePolynomial { /// Get the content of a multivariate polynomial viewed as a /// univariate polynomial in `x`. pub fn univariate_content(&self, x: usize) -> MultivariatePolynomial { @@ -1628,7 +1628,7 @@ impl, E: Exponent> MultivariatePolynomial< /// Undo simplifications made to the input polynomials and normalize the gcd. #[inline(always)] - fn rescale_gcd, E: Exponent>( + fn rescale_gcd, E: PositiveExponent>( mut g: MultivariatePolynomial, shared_degree: &[E], base_degree: &[Option], @@ -1852,11 +1852,11 @@ pub enum HeuristicGCDError { BadReconstruction, } -impl MultivariatePolynomial { +impl MultivariatePolynomial { /// Perform a heuristic GCD algorithm. #[instrument(level = "debug", skip_all)] pub fn heuristic_gcd(&self, b: &Self) -> Result<(Self, Self, Self), HeuristicGCDError> { - fn interpolate( + fn interpolate( mut gamma: MultivariatePolynomial, var: usize, xi: &Integer, @@ -2391,7 +2391,7 @@ impl MultivariatePolynomial { } /// Polynomial GCD functions for a certain coefficient type `Self`. -pub trait PolynomialGCD: Ring { +pub trait PolynomialGCD: Ring { fn heuristic_gcd( a: &MultivariatePolynomial, b: &MultivariatePolynomial, @@ -2417,7 +2417,7 @@ pub trait PolynomialGCD: Ring { fn normalize(a: MultivariatePolynomial) -> MultivariatePolynomial; } -impl PolynomialGCD for IntegerRing { +impl PolynomialGCD for IntegerRing { fn heuristic_gcd( a: &MultivariatePolynomial, b: &MultivariatePolynomial, @@ -2530,7 +2530,7 @@ impl PolynomialGCD for IntegerRing { } } -impl PolynomialGCD for RationalField { +impl PolynomialGCD for RationalField { fn heuristic_gcd( _a: &MultivariatePolynomial, _b: &MultivariatePolynomial, @@ -2605,8 +2605,11 @@ impl PolynomialGCD for RationalField { } } -impl>, E: Exponent> - PolynomialGCD for F +impl< + UField: FiniteFieldWorkspace, + F: GaloisField>, + E: PositiveExponent, + > PolynomialGCD for F where FiniteField: FiniteFieldCore, as Ring>::Element: Copy, @@ -2669,7 +2672,7 @@ where } } -impl PolynomialGCD for AlgebraicExtension { +impl PolynomialGCD for AlgebraicExtension { fn heuristic_gcd( _a: &MultivariatePolynomial, _b: &MultivariatePolynomial, diff --git a/src/poly/polynomial.rs b/src/poly/polynomial.rs index 275a6b3e..4af53e42 100755 --- a/src/poly/polynomial.rs +++ b/src/poly/polynomial.rs @@ -1,5 +1,5 @@ use ahash::{HashMap, HashMapExt}; -use std::cell::Cell; +use std::cell::{Cell, UnsafeCell}; use std::cmp::{Ordering, Reverse}; use std::collections::{BTreeMap, BinaryHeap}; use std::fmt::Display; @@ -16,7 +16,7 @@ use crate::printer::{PrintOptions, PrintState}; use super::gcd::PolynomialGCD; use super::univariate::UnivariatePolynomial; -use super::{Exponent, LexOrder, MonomialOrder, Variable, INLINED_EXPONENTS}; +use super::{Exponent, LexOrder, MonomialOrder, PositiveExponent, Variable, INLINED_EXPONENTS}; use smallvec::{smallvec, SmallVec}; const MAX_DENSE_MUL_BUFFER_SIZE: usize = 1 << 24; @@ -160,7 +160,9 @@ impl Ring for PolynomialRing { } } -impl, E: Exponent> EuclideanDomain for PolynomialRing { +impl, E: PositiveExponent> EuclideanDomain + for PolynomialRing +{ fn rem(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { a.rem(b) } @@ -175,6 +177,7 @@ impl, E: Exponent> EuclideanDomain for Pol } /// Multivariate polynomial with a sparse degree and dense variable representation. +/// Negative exponents are supported, if they are allowed by the exponent type. #[derive(Clone)] pub struct MultivariatePolynomial { // Data format: the i-th monomial is stored as coefficients[i] and @@ -429,7 +432,7 @@ impl MultivariatePolynomial { self.exponents.chunks_mut(nvars) } - /// Returns the number of variables in the polynomial. + /// Reset the polynomial to 0. #[inline] pub fn clear(&mut self) { self.coefficients.clear(); @@ -446,7 +449,7 @@ impl MultivariatePolynomial { self.variables.as_ref() } - /// Renaname a variable. + /// Rename a variable. pub fn rename_variable(&mut self, old: &Variable, new: &Variable) { if let Some(pos) = self.variables.iter().position(|v| v == old) { let mut new_vars = self.variables.as_ref().clone(); @@ -493,6 +496,26 @@ impl MultivariatePolynomial { self.variables = Arc::new(new_var_map); self.exponents = newexp; + // check if term ordering remains unchanged + if new_var_pos_other.windows(2).all(|w| w[0] <= w[1]) { + let mut newexp = vec![E::zero(); self.nvars() * other.nterms()]; + + if other.nvars() > 0 { + for (d, t) in newexp + .chunks_mut(self.nvars()) + .zip(other.exponents.chunks(other.nvars())) + { + for (var, e) in t.iter().enumerate() { + d[new_var_pos_other[var]] = *e; + } + } + } + + other.variables = self.variables.clone(); + other.exponents = newexp; + return; + } + // reconstruct 'other' with correct monomial ordering let mut newother = Self::new(&other.ring, other.nterms().into(), self.variables.clone()); let mut newexp = vec![E::zero(); self.nvars()]; @@ -688,25 +711,6 @@ impl MultivariatePolynomial { let i = l * self.nvars(); self.exponents.splice(i..i, exponents.iter().cloned()); } - - /// Take the derivative of the polynomial w.r.t the variable `var`. - pub fn derivative(&self, var: usize) -> Self { - debug_assert!(var < self.nvars()); - - let mut res = self.zero_with_capacity(self.nterms()); - - let mut exp = vec![E::zero(); self.nvars()]; - for x in self { - if x.exponents[var] > E::zero() { - exp.copy_from_slice(x.exponents); - let pow = exp[var].to_u32() as u64; - exp[var] = exp[var] - E::one(); - res.append_monomial(self.ring.mul(x.coefficient, &self.ring.nth(pow)), &exp); - } - } - - res - } } impl SelfRing for MultivariatePolynomial { @@ -761,10 +765,13 @@ impl SelfRing for MultivariatePolynomial .as_ref() .iter() .map(|v| { - v.to_string_with_state(PrintState { - in_exp: true, - ..state - }) + v.format_string( + opts, + PrintState { + in_exp: true, + ..state + }, + ) }) .collect(); @@ -787,7 +794,7 @@ impl SelfRing for MultivariatePolynomial f.write_str(var_id)?; - if e.to_u32() != 1 { + if e.to_i32() != 1 { if opts.latex { write!(f, "^{{{}}}", e)?; } else if opts.double_star_for_exponentiation { @@ -1158,8 +1165,8 @@ impl<'a, F: Ring, E: Exponent> Mul<&'a MultivariatePolynomial> } } -impl<'a, 'b, F: EuclideanDomain, E: Exponent> Div<&'a MultivariatePolynomial> - for &'b MultivariatePolynomial +impl<'a, 'b, F: EuclideanDomain, E: PositiveExponent> + Div<&'a MultivariatePolynomial> for &'b MultivariatePolynomial { type Output = MultivariatePolynomial; @@ -1169,7 +1176,7 @@ impl<'a, 'b, F: EuclideanDomain, E: Exponent> Div<&'a MultivariatePolynomial Div<&'a MultivariatePolynomial> +impl<'a, F: EuclideanDomain, E: PositiveExponent> Div<&'a MultivariatePolynomial> for MultivariatePolynomial { type Output = MultivariatePolynomial; @@ -1345,7 +1352,335 @@ impl MultivariatePolynomial { } } +impl MultivariatePolynomial { + /// Take the derivative of the polynomial w.r.t the variable `var`. + pub fn derivative(&self, var: usize) -> Self { + debug_assert!(var < self.nvars()); + + let mut res = self.zero_with_capacity(self.nterms()); + + let mut exp = vec![E::zero(); self.nvars()]; + for x in self { + if x.exponents[var] > E::zero() { + exp.copy_from_slice(x.exponents); + let pow = exp[var].to_i32() as u64; + exp[var] = exp[var] - E::one(); + res.append_monomial(self.ring.mul(x.coefficient, &self.ring.nth(pow)), &exp); + } + } + + res + } + + /// Replace a variable `n` in the polynomial by an element from + /// the ring `v`. + pub fn replace(&self, n: usize, v: &F::Element) -> MultivariatePolynomial { + if (n + 1..self.nvars()).all(|i| self.degree(i) == E::zero()) { + return self.replace_last(n, v); + } + + let mut res = self.zero_with_capacity(self.nterms()); + let mut e: SmallVec<[E; INLINED_EXPONENTS]> = smallvec![E::zero(); self.nvars()]; + + // TODO: cache power taking? + for t in self { + if t.exponents[n] == E::zero() { + res.append_monomial(t.coefficient.clone(), t.exponents); + continue; + } + + let c = self.ring.mul( + t.coefficient, + &self.ring.pow(v, t.exponents[n].to_i32() as u64), + ); + + e.copy_from_slice(t.exponents); + e[n] = E::zero(); + res.append_monomial(c, &e); + } + + res + } + + /// Replace the last variable `n` in the polynomial by an element from + /// the ring `v`. + pub fn replace_last(&self, n: usize, v: &F::Element) -> MultivariatePolynomial { + let mut res = self.zero_with_capacity(self.nterms()); + let mut e: SmallVec<[E; INLINED_EXPONENTS]> = smallvec![E::zero(); self.nvars()]; + + // TODO: cache power taking? + for t in self { + if t.exponents[n] == E::zero() { + res.append_monomial(t.coefficient.clone(), t.exponents); + continue; + } + + let c = self.ring.mul( + t.coefficient, + &self.ring.pow(v, t.exponents[n].to_i32() as u64), + ); + + if F::is_zero(&c) { + continue; + } + + e.copy_from_slice(t.exponents); + e[n] = E::zero(); + + if res.is_zero() || res.last_exponents() != e.as_slice() { + res.coefficients.push(c); + res.exponents.extend_from_slice(&e); + } else { + let l = res.coefficients.last_mut().unwrap(); + self.ring.add_assign(l, &c); + + if F::is_zero(l) { + res.coefficients.pop(); + res.exponents.truncate(res.exponents.len() - self.nvars()); + } + } + } + + res + } + + /// Replace a variable `n` in the polynomial by an element from + /// the ring `v`. + pub fn replace_all(&self, r: &[F::Element]) -> F::Element { + let mut res = self.ring.zero(); + + // TODO: cache power taking? + for t in self { + let mut c = t.coefficient.clone(); + + for (i, v) in r.iter().zip(t.exponents) { + if v != &E::zero() { + self.ring + .mul_assign(&mut c, &self.ring.pow(i, v.to_i32() as u64)); + } + } + + self.ring.add_assign(&mut res, &c); + } + + res + } + + /// Replace a variable `n` in the polynomial by a polynomial `v`. + pub fn replace_with_poly(&self, n: usize, v: &Self) -> Self { + assert_eq!(self.variables, v.variables); + + if v.is_constant() { + return self.replace(n, &v.lcoeff()); + } + + let mut res = self.zero_with_capacity(self.nterms()); + let mut exp = vec![E::zero(); self.nvars()]; + for t in self { + if t.exponents[n] == E::zero() { + res.append_monomial(t.coefficient.clone(), &t.exponents[..self.nvars()]); + continue; + } + + exp.copy_from_slice(t.exponents); + exp[n] = E::zero(); + + // TODO: cache v^e + res = res + + (&v.pow(t.exponents[n].to_i32() as usize) + * &self.monomial(t.coefficient.clone(), exp.clone())); + } + res + } + + /// Replace all variables except `v` in the polynomial by elements from + /// the ring. + pub fn replace_all_except( + &self, + v: usize, + r: &[(usize, F::Element)], + cache: &mut [Vec], + ) -> MultivariatePolynomial { + let mut tm: HashMap = HashMap::new(); + + for t in self { + let mut c = t.coefficient.clone(); + for (n, vv) in r { + let p = t.exponents[*n].to_i32() as usize; + if p > 0 { + if p < cache[*n].len() { + if F::is_zero(&cache[*n][p]) { + cache[*n][p] = self.ring.pow(vv, p as u64); + } + + self.ring.mul_assign(&mut c, &cache[*n][p]); + } else { + self.ring.mul_assign(&mut c, &self.ring.pow(vv, p as u64)); + } + } + } + + tm.entry(t.exponents[v]) + .and_modify(|e| self.ring.add_assign(e, &c)) + .or_insert(c); + } + + let mut res = self.zero(); + let mut e = vec![E::zero(); self.nvars()]; + for (k, c) in tm { + e[v] = k; + res.append_monomial(c, &e); + e[v] = E::zero(); + } + + res + } + + /// Shift a variable `var` to `var+shift`. + pub fn shift_var(&self, var: usize, shift: &F::Element) -> Self { + let d = self.degree(var).to_i32() as usize; + + let y_poly = self.to_univariate_polynomial_list(var); + + let mut v = vec![self.zero(); d + 1]; + for (x_poly, p) in y_poly { + v[p.to_i32() as usize] = x_poly; + } + + for k in 0..d { + for j in (k..d).rev() { + v[j] = &v[j] + &v[j + 1].clone().mul_coeff(shift.clone()); + } + } + + let mut poly = self.zero(); + for (i, mut v) in v.into_iter().enumerate() { + for x in v.exponents.chunks_mut(self.nvars()) { + x[var] = E::from_i32(i as i32); + } + + for m in &v { + poly.append_monomial(m.coefficient.clone(), m.exponents); + } + } + + poly + } + + /// Synthetic division for univariate polynomials, where `div` is monic. + // TODO: create UnivariatePolynomial? + pub fn quot_rem_univariate_monic( + &self, + div: &MultivariatePolynomial, + ) -> ( + MultivariatePolynomial, + MultivariatePolynomial, + ) { + debug_assert_eq!(div.lcoeff(), self.ring.one()); + if self.is_zero() { + return (self.clone(), self.clone()); + } + + let mut dividendpos = self.nterms() - 1; // work from the back + + let mut q = self.zero_with_capacity(self.nterms()); + let mut r = self.zero(); + + // determine the variable + let mut var = 0; + for (i, x) in self.last_exponents().iter().enumerate() { + if !x.is_zero() { + var = i; + break; + } + } + + let m = div.ldegree_max(); + let mut pow = self.ldegree_max(); + + loop { + // find the power in the dividend if it exists + let mut coeff = loop { + if self.exponents(dividendpos)[var] == pow { + break self.coefficients[dividendpos].clone(); + } + if dividendpos == 0 || self.exponents(dividendpos)[var] < pow { + break self.ring.zero(); + } + dividendpos -= 1; + }; + + let mut qindex = 0; // starting from highest + let mut bindex = 0; // starting from lowest + while bindex < div.nterms() && qindex < q.nterms() { + while bindex + 1 < div.nterms() + && div.exponents(bindex)[var] + q.exponents(qindex)[var] < pow + { + bindex += 1; + } + + if div.exponents(bindex)[var] + q.exponents(qindex)[var] == pow { + self.ring.sub_mul_assign( + &mut coeff, + &div.coefficients[bindex], + &q.coefficients[qindex], + ); + } + + qindex += 1; + } + + if !F::is_zero(&coeff) { + // can the division be performed? if not, add to rest + // TODO: refactor + let (quot, div) = if pow >= m { + (coeff, true) + } else { + (coeff, false) + }; + + if div { + let nterms = q.nterms(); + let nvars = q.nvars(); + q.coefficients.push(quot); + q.exponents.resize((nterms + 1) * nvars, E::zero()); + q.exponents[nterms * nvars + var] = pow - m; + } else { + let nterms = r.nterms(); + let nvars = r.nvars(); + r.coefficients.push(quot); + r.exponents.resize((nterms + 1) * nvars, E::zero()); + r.exponents[nterms * nvars + var] = pow; + } + } + + if pow.is_zero() { + break; + } + + pow = pow - E::one(); + } + + q.reverse(); + r.reverse(); + + #[cfg(debug_assertions)] + { + if !(&q * div + r.clone() - self.clone()).is_zero() { + panic!("Division failed: ({})/({}): q={}, r={}", self, div, q, r); + } + } + + (q, r) + } +} + impl MultivariatePolynomial { + /// Check if all exponents are positive. + pub fn is_polynomial(&self) -> bool { + self.is_zero() || self.exponents.iter().all(|e| *e >= E::zero()) + } + /// Get the leading coefficient under a given variable ordering. /// This operation is O(n) if the variables are out of order. pub fn lcoeff_varorder(&self, vars: &[usize]) -> F::Element { @@ -1532,205 +1867,41 @@ impl MultivariatePolynomial { /// The map can also be reversed, by setting `inverse` to `true`. pub fn rearrange( &self, - order: &[usize], - inverse: bool, - ) -> MultivariatePolynomial { - self.rearrange_impl(order, inverse, true) - } - - /// Change the order of the variables in the polynomial, using `order`. - /// The order may contain `None`, to signal unmapped indices. This operation - /// allows the polynomial to grow in size. - /// - /// Note that the polynomial `var_map` is not updated. - pub fn rearrange_with_growth( - &self, - order: &[Option], - ) -> MultivariatePolynomial { - let mut new_exp = vec![E::zero(); self.nterms() * order.len()]; - for (e, er) in new_exp.chunks_mut(order.len()).zip(self.exponents_iter()) { - for x in 0..order.len() { - if let Some(v) = order[x] { - e[x] = er[v]; - } - } - } - - let mut indices: Vec = (0..self.nterms()).collect(); - indices.sort_unstable_by_key(|&i| &new_exp[i * order.len()..(i + 1) * order.len()]); - - let mut res = - MultivariatePolynomial::new(&self.ring, self.nterms().into(), self.variables.clone()); - - for i in indices { - res.append_monomial( - self.coefficients[i].clone(), - &new_exp[i * order.len()..(i + 1) * order.len()], - ); - } - - res - } - - /// Replace a variable `n` in the polynomial by an element from - /// the ring `v`. - pub fn replace(&self, n: usize, v: &F::Element) -> MultivariatePolynomial { - if (n + 1..self.nvars()).all(|i| self.degree(i) == E::zero()) { - return self.replace_last(n, v); - } - - let mut res = self.zero_with_capacity(self.nterms()); - let mut e: SmallVec<[E; INLINED_EXPONENTS]> = smallvec![E::zero(); self.nvars()]; - - // TODO: cache power taking? - for t in self { - if t.exponents[n] == E::zero() { - res.append_monomial(t.coefficient.clone(), t.exponents); - continue; - } - - let c = self.ring.mul( - t.coefficient, - &self.ring.pow(v, t.exponents[n].to_u32() as u64), - ); - - e.copy_from_slice(t.exponents); - e[n] = E::zero(); - res.append_monomial(c, &e); - } - - res - } - - /// Replace the last variable `n` in the polynomial by an element from - /// the ring `v`. - pub fn replace_last(&self, n: usize, v: &F::Element) -> MultivariatePolynomial { - let mut res = self.zero_with_capacity(self.nterms()); - let mut e: SmallVec<[E; INLINED_EXPONENTS]> = smallvec![E::zero(); self.nvars()]; - - // TODO: cache power taking? - for t in self { - if t.exponents[n] == E::zero() { - res.append_monomial(t.coefficient.clone(), t.exponents); - continue; - } - - let c = self.ring.mul( - t.coefficient, - &self.ring.pow(v, t.exponents[n].to_u32() as u64), - ); - - if F::is_zero(&c) { - continue; - } - - e.copy_from_slice(t.exponents); - e[n] = E::zero(); - - if res.is_zero() || res.last_exponents() != e.as_slice() { - res.coefficients.push(c); - res.exponents.extend_from_slice(&e); - } else { - let l = res.coefficients.last_mut().unwrap(); - self.ring.add_assign(l, &c); - - if F::is_zero(l) { - res.coefficients.pop(); - res.exponents.truncate(res.exponents.len() - self.nvars()); - } - } - } - - res - } - - /// Replace a variable `n` in the polynomial by an element from - /// the ring `v`. - pub fn replace_all(&self, r: &[F::Element]) -> F::Element { - let mut res = self.ring.zero(); - - // TODO: cache power taking? - for t in self { - let mut c = t.coefficient.clone(); - - for (i, v) in r.iter().zip(t.exponents) { - if v != &E::zero() { - self.ring - .mul_assign(&mut c, &self.ring.pow(i, v.to_u32() as u64)); - } - } - - self.ring.add_assign(&mut res, &c); - } - - res - } - - /// Replace a variable `n` in the polynomial by a polynomial `v`. - pub fn replace_with_poly(&self, n: usize, v: &Self) -> Self { - assert_eq!(self.variables, v.variables); - - if v.is_constant() { - return self.replace(n, &v.lcoeff()); - } - - let mut res = self.zero_with_capacity(self.nterms()); - let mut exp = vec![E::zero(); self.nvars()]; - for t in self { - if t.exponents[n] == E::zero() { - res.append_monomial(t.coefficient.clone(), &t.exponents[..self.nvars()]); - continue; - } - - exp.copy_from_slice(t.exponents); - exp[n] = E::zero(); - - // TODO: cache v^e - res = res - + (&v.pow(t.exponents[n].to_u32() as usize) - * &self.monomial(t.coefficient.clone(), exp.clone())); - } - res + order: &[usize], + inverse: bool, + ) -> MultivariatePolynomial { + self.rearrange_impl(order, inverse, true) } - /// Replace all variables except `v` in the polynomial by elements from - /// the ring. - pub fn replace_all_except( + /// Change the order of the variables in the polynomial, using `order`. + /// The order may contain `None`, to signal unmapped indices. This operation + /// allows the polynomial to grow in size. + /// + /// Note that the polynomial `var_map` is not updated. + pub fn rearrange_with_growth( &self, - v: usize, - r: &[(usize, F::Element)], - cache: &mut [Vec], + order: &[Option], ) -> MultivariatePolynomial { - let mut tm: HashMap = HashMap::new(); - - for t in self { - let mut c = t.coefficient.clone(); - for (n, vv) in r { - let p = t.exponents[*n].to_u32() as usize; - if p > 0 { - if p < cache[*n].len() { - if F::is_zero(&cache[*n][p]) { - cache[*n][p] = self.ring.pow(vv, p as u64); - } - - self.ring.mul_assign(&mut c, &cache[*n][p]); - } else { - self.ring.mul_assign(&mut c, &self.ring.pow(vv, p as u64)); - } + let mut new_exp = vec![E::zero(); self.nterms() * order.len()]; + for (e, er) in new_exp.chunks_mut(order.len()).zip(self.exponents_iter()) { + for x in 0..order.len() { + if let Some(v) = order[x] { + e[x] = er[v]; } } - - tm.entry(t.exponents[v]) - .and_modify(|e| self.ring.add_assign(e, &c)) - .or_insert(c); } - let mut res = self.zero(); - let mut e = vec![E::zero(); self.nvars()]; - for (k, c) in tm { - e[v] = k; - res.append_monomial(c, &e); - e[v] = E::zero(); + let mut indices: Vec = (0..self.nterms()).collect(); + indices.sort_unstable_by_key(|&i| &new_exp[i * order.len()..(i + 1) * order.len()]); + + let mut res = + MultivariatePolynomial::new(&self.ring, self.nterms().into(), self.variables.clone()); + + for i in indices { + res.append_monomial( + self.coefficients[i].clone(), + &new_exp[i * order.len()..(i + 1) * order.len()], + ); } res @@ -1774,9 +1945,13 @@ impl MultivariatePolynomial { return p; } - p.coefficients = vec![self.zero(); c.last().unwrap().1.to_u32() as usize + 1]; + p.coefficients = vec![self.zero(); c.last().unwrap().1.to_i32() as usize + 1]; for (q, e) in c { - p.coefficients[e.to_u32() as usize] = q; + if e < E::zero() { + panic!("Negative exponent in univariate conversion"); + } + + p.coefficients[e.to_i32() as usize] = q; } p @@ -1790,9 +1965,13 @@ impl MultivariatePolynomial { return p; } - p.coefficients = vec![p.ring.zero(); self.degree(var).to_u32() as usize + 1]; + p.coefficients = vec![p.ring.zero(); self.degree(var).to_i32() as usize + 1]; for (q, e) in self.coefficients.iter().zip(self.exponents_iter()) { - p.coefficients[e[var].to_u32() as usize] = q.clone(); + if e[var] < E::zero() { + panic!("Negative exponent in univariate conversion"); + } + + p.coefficients[e[var].to_i32() as usize] = q.clone(); } p @@ -1809,22 +1988,26 @@ impl MultivariatePolynomial { } // get maximum degree for variable x + let mut mindeg = E::zero(); let mut maxdeg = E::zero(); for t in 0..self.nterms() { let d = self.exponents(t)[x]; if d > maxdeg { maxdeg = d; } + if d < mindeg { + mindeg = d; + } } // construct the coefficient per power of x let mut result = vec![]; let mut e: SmallVec<[E; INLINED_EXPONENTS]> = smallvec![E::zero(); self.nvars()]; - for d in 0..maxdeg.to_u32() + 1 { + for d in mindeg.to_i32()..maxdeg.to_i32() + 1 { // TODO: add bounds estimate let mut a = self.zero(); for t in 0..self.nterms() { - if self.exponents(t)[x].to_u32() == d { + if self.exponents(t)[x].to_i32() == d { for (i, ee) in self.exponents(t).iter().enumerate() { e[i] = *ee; } @@ -1834,7 +2017,7 @@ impl MultivariatePolynomial { } if !a.is_zero() { - result.push((a, E::from_u32(d))); + result.push((a, E::from_i32(d))); } } @@ -1923,7 +2106,7 @@ impl MultivariatePolynomial { c2.exponents = c .exponents_iter() - .map(|x| x[var_index].to_u32() as u16) + .map(|x| x[var_index].to_i32() as u16) .collect(); c2.coefficients = c.coefficients; @@ -1936,9 +2119,9 @@ impl MultivariatePolynomial { if self.is_constant() { if let Some(m) = max_pow { if let Some(var) = rhs.last_exponents().iter().position(|e| *e != E::zero()) { - if rhs.degree(var).to_u32() > m as u32 { + if rhs.degree(var).to_i32() > m as i32 { return rhs - .mod_var(var, E::from_u32(m as u32 + 1)) + .mod_var(var, E::from_i32(m as i32 + 1)) .mul_coeff(self.lcoeff()); } } @@ -1949,9 +2132,9 @@ impl MultivariatePolynomial { if rhs.is_constant() { if let Some(m) = max_pow { if let Some(var) = self.last_exponents().iter().position(|e| *e != E::zero()) { - if self.degree(var).to_u32() > m as u32 { + if self.degree(var).to_i32() > m as i32 { return self - .mod_var(var, E::from_u32(m as u32 + 1)) + .mod_var(var, E::from_i32(m as i32 + 1)) .mul_coeff(rhs.lcoeff()); } } @@ -1967,7 +2150,7 @@ impl MultivariatePolynomial { let d1 = self.degree(var); let d2 = rhs.degree(var); - let mut max = (d1.to_u32() + d2.to_u32()) as usize; + let mut max = (d1.to_i32() + d2.to_i32()) as usize; if let Some(m) = max_pow { max = max.min(m); } @@ -1976,7 +2159,7 @@ impl MultivariatePolynomial { for x in self { for y in rhs { - let pos = x.exponents[var].to_u32() + y.exponents[var].to_u32(); + let pos = x.exponents[var].to_i32() + y.exponents[var].to_i32(); if pos as usize > max { continue; } @@ -1990,7 +2173,7 @@ impl MultivariatePolynomial { let mut res = self.zero_with_capacity(coeffs.len()); for (p, c) in coeffs.into_iter().enumerate() { if !F::is_zero(&c) { - exp[var] = E::from_u32(p as u32); + exp[var] = E::from_i32(p as i32); res.append_monomial(c, &exp); } } @@ -2001,9 +2184,13 @@ impl MultivariatePolynomial { &self, rhs: &MultivariatePolynomial, ) -> Option> { + if !self.is_polynomial() || !rhs.is_polynomial() { + return None; + } + let max_degs_rev = (0..self.nvars()) .rev() - .map(|i| 1 + self.degree(i).to_u32() as usize + rhs.degree(i).to_u32() as usize) + .map(|i| 1 + self.degree(i).to_i32() as usize + rhs.degree(i).to_i32() as usize) .collect::>(); if max_degs_rev.iter().filter(|x| **x > 1).count() == 1 { @@ -2034,10 +2221,10 @@ impl MultivariatePolynomial { #[inline(always)] fn to_uni_var(s: &[E], max_degs_rev: &[usize]) -> u32 { let mut shift = 1; - let mut res = s.last().unwrap().to_u32(); + let mut res = s.last().unwrap().to_i32() as u32; for (ee, &x) in s.iter().rev().skip(1).zip(max_degs_rev) { - shift = shift.to_u32() * x as u32; - res += ee.to_u32() * shift; + shift = shift * x as u32; + res += ee.to_i32() as u32 * shift; } res } @@ -2045,7 +2232,7 @@ impl MultivariatePolynomial { #[inline(always)] fn from_uni_var(mut p: u32, max_degs_rev: &[usize], exp: &mut [E]) { for (ee, &x) in exp.iter_mut().rev().zip(max_degs_rev) { - *ee = E::from_u32(p % x as u32); + *ee = E::from_i32((p % x as u32) as i32); p /= x as u32; } } @@ -2069,7 +2256,7 @@ impl MultivariatePolynomial { for (c1, e1) in self.coefficients.iter().zip(&uni_exp_self) { for (c2, e2) in rhs.coefficients.iter().zip(&uni_exp_rhs) { - let pos = e1.to_u32() as usize + e2.to_u32() as usize; + let pos = *e1 as usize + *e2 as usize; self.ring.add_mul_assign(&mut coeffs[pos], c1, c2); } } @@ -2093,7 +2280,7 @@ impl MultivariatePolynomial { for (c1, e1) in self.coefficients.iter().zip(&uni_exp_self) { for (c2, e2) in rhs.coefficients.iter().zip(&uni_exp_rhs) { - let pos = e1.to_u32() as usize + e2.to_u32() as usize; + let pos = *e1 as usize + *e2 as usize; if coeff_index[pos] == 0 { coeffs.push(self.ring.mul(c1, c2)); coeff_index[pos] = coeffs.len() as u32; @@ -2124,17 +2311,6 @@ impl MultivariatePolynomial { } } - /// Multiplication for multivariate polynomials using a custom variation of the heap method - /// described in "Sparse polynomial division using a heap" by Monagan, Pearce (2011) and using - /// the sorting described in "Sparse Polynomial Powering Using Heaps". - /// It uses a heap to obtain the next monomial of the result in an ordered fashion. - /// Additionally, this method uses a hashmap with the monomial exponent as a key and a vector of all pairs - /// of indices in `self` and `other` that have that monomial exponent when multiplied together. - /// When a multiplication of two monomials is considered, its indices are added to the hashmap, - /// but they are only added to the heap if the monomial exponent is new. As a result, the heap - /// only has new monomials, and by taking (and removing) the corresponding entry from the hashmap, all - /// monomials that have that exponent can be summed. Then, new monomials combinations are added that - /// should be considered next as they are smaller than the current monomial. fn heap_mul( &self, rhs: &MultivariatePolynomial, @@ -2146,12 +2322,14 @@ impl MultivariatePolynomial { } let degree_sum: Vec<_> = (0..self.nvars()) - .map(|i| self.degree(i).to_u32() as usize + rhs.degree(i).to_u32() as usize) + .map(|i| self.degree(i).to_i32() as i64 + rhs.degree(i).to_i32() as i64) .collect(); // use a special routine if the exponents can be packed into a u64 let mut pack_u8 = true; if self.nvars() <= 8 + && self.is_polynomial() + && rhs.is_polynomial() && degree_sum.iter().all(|deg| { if *deg > 255 { pack_u8 = false; @@ -2163,24 +2341,85 @@ impl MultivariatePolynomial { return self.heap_mul_packed_exp(rhs, pack_u8); } + let mut monomials = Vec::with_capacity(self.nterms() * self.nvars()); + monomials.extend( + self.exponents(0) + .iter() + .zip(rhs.exponents(0)) + .map(|(e1, e2)| *e1 + *e2), + ); + + let monomials = UnsafeCell::new((self.nvars(), monomials)); + + /// In order to prevent allocations of the exponents, store them in a single + /// append-only vector and use a key to index into it. For performance, + /// we use an unsafe cell. + #[derive(Clone, Copy)] + struct Key<'a, E: Exponent> { + index: usize, + monomials: &'a UnsafeCell<(usize, Vec)>, + } + + impl<'a, E: Exponent> PartialEq for Key<'a, E> { + #[inline(always)] + fn eq(&self, other: &Self) -> bool { + unsafe { + let b1 = &*self.monomials.get(); + b1.1.get_unchecked(self.index..self.index + b1.0) + == b1.1.get_unchecked(other.index..other.index + b1.0) + } + } + } + + impl<'a, E: Exponent> Eq for Key<'a, E> {} + + impl<'a, E: Exponent> PartialOrd for Key<'a, E> { + #[inline(always)] + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } + + impl<'a, E: Exponent> Ord for Key<'a, E> { + #[inline(always)] + fn cmp(&self, other: &Self) -> Ordering { + unsafe { + let b1 = &*self.monomials.get(); + b1.1.get_unchecked(self.index..self.index + b1.0) + .cmp(&b1.1.get_unchecked(other.index..other.index + b1.0)) + } + } + } + + impl<'a, E: Exponent> std::hash::Hash for Key<'_, E> { + #[inline(always)] + fn hash(&self, state: &mut H) { + unsafe { + let b = &*self.monomials.get(); + b.1.get_unchecked(self.index..self.index + b.0).hash(state); + } + } + } + let mut res = self.zero_with_capacity(self.nterms().max(rhs.nterms())); - let mut cache: BTreeMap, Vec<(usize, usize)>> = BTreeMap::new(); + let mut cache: HashMap<_, Vec<(usize, usize)>> = HashMap::new(); let mut q_cache: Vec> = vec![]; // create a min-heap since our polynomials are sorted smallest to largest - let mut h: BinaryHeap>> = BinaryHeap::with_capacity(self.nterms()); - - let monom: Vec = self - .exponents(0) - .iter() - .zip(rhs.exponents(0)) - .map(|(e1, e2)| *e1 + *e2) - .collect(); - cache.insert(monom.clone(), vec![(0, 0)]); - h.push(Reverse(monom)); - - let mut m_cache: Vec = vec![E::zero(); self.nvars()]; + let mut h: BinaryHeap> = BinaryHeap::with_capacity(self.nterms()); + + cache.insert( + Key { + index: 0, + monomials: &monomials, + }, + vec![(0, 0)], + ); + h.push(Reverse(Key { + index: 0, + monomials: &monomials, + })); // i=merged_index[j] signifies that self[i]*other[j] has been merged let mut merged_index = vec![0; rhs.nterms()]; @@ -2205,23 +2444,31 @@ impl MultivariatePolynomial { merged_index[j] = i + 1; if i + 1 < self.nterms() && (j == 0 || merged_index[j - 1] > i + 1) { - for ((m, e1), e2) in m_cache - .iter_mut() - .zip(self.exponents(i + 1)) - .zip(rhs.exponents(j)) - { - *m = *e1 + *e2; - } + let m = unsafe { + let b = &mut *monomials.get(); + let index = b.1.len(); + b.1.extend( + self.exponents(i + 1) + .iter() + .zip(rhs.exponents(j)) + .map(|(e1, e2)| *e1 + *e2), + ); - if let Some(e) = cache.get_mut(&m_cache) { + Key { + index, + monomials: &monomials, + } + }; + + if let Some(e) = cache.get_mut(&m) { e.push((i + 1, j)); } else { - h.push(Reverse(m_cache.clone())); // only add when new + h.push(Reverse(m)); // only add when new if let Some(mut qq) = q_cache.pop() { qq.push((i + 1, j)); - cache.insert(m_cache.clone(), qq); + cache.insert(m, qq); } else { - cache.insert(m_cache.clone(), vec![(i + 1, j)]); + cache.insert(m, vec![(i + 1, j)]); } } } else { @@ -2229,24 +2476,32 @@ impl MultivariatePolynomial { } if j + 1 < rhs.nterms() && !in_heap[j + 1] { - for ((m, e1), e2) in m_cache - .iter_mut() - .zip(self.exponents(i)) - .zip(rhs.exponents(j + 1)) - { - *m = *e1 + *e2; - } + let m = unsafe { + let b = &mut *monomials.get(); + let index = b.1.len(); + b.1.extend( + self.exponents(i) + .iter() + .zip(rhs.exponents(j + 1)) + .map(|(e1, e2)| *e1 + *e2), + ); - if let Some(e) = cache.get_mut(&m_cache) { + Key { + index, + monomials: &monomials, + } + }; + + if let Some(e) = cache.get_mut(&m) { e.push((i, j + 1)); } else { - h.push(Reverse(m_cache.clone())); // only add when new + h.push(Reverse(m)); // only add when new if let Some(mut qq) = q_cache.pop() { qq.push((i, j + 1)); - cache.insert(m_cache.clone(), qq); + cache.insert(m, qq); } else { - cache.insert(m_cache.clone(), vec![(i, j + 1)]); + cache.insert(m, vec![(i, j + 1)]); } } @@ -2258,9 +2513,15 @@ impl MultivariatePolynomial { if !F::is_zero(&coefficient) { res.coefficients.push(coefficient); - res.exponents.extend_from_slice(&cur_mon.0); + + unsafe { + let b = &*monomials.get(); + res.exponents + .extend_from_slice(&b.1[cur_mon.0.index..cur_mon.0.index + b.0]); + } } } + res } @@ -2368,147 +2629,9 @@ impl MultivariatePolynomial { } res } - - /// Synthetic division for univariate polynomials, where `div` is monic. - // TODO: create UnivariatePolynomial? - pub fn quot_rem_univariate_monic( - &self, - div: &MultivariatePolynomial, - ) -> ( - MultivariatePolynomial, - MultivariatePolynomial, - ) { - debug_assert_eq!(div.lcoeff(), self.ring.one()); - if self.is_zero() { - return (self.clone(), self.clone()); - } - - let mut dividendpos = self.nterms() - 1; // work from the back - - let mut q = self.zero_with_capacity(self.nterms()); - let mut r = self.zero(); - - // determine the variable - let mut var = 0; - for (i, x) in self.last_exponents().iter().enumerate() { - if !x.is_zero() { - var = i; - break; - } - } - - let m = div.ldegree_max(); - let mut pow = self.ldegree_max(); - - loop { - // find the power in the dividend if it exists - let mut coeff = loop { - if self.exponents(dividendpos)[var] == pow { - break self.coefficients[dividendpos].clone(); - } - if dividendpos == 0 || self.exponents(dividendpos)[var] < pow { - break self.ring.zero(); - } - dividendpos -= 1; - }; - - let mut qindex = 0; // starting from highest - let mut bindex = 0; // starting from lowest - while bindex < div.nterms() && qindex < q.nterms() { - while bindex + 1 < div.nterms() - && div.exponents(bindex)[var] + q.exponents(qindex)[var] < pow - { - bindex += 1; - } - - if div.exponents(bindex)[var] + q.exponents(qindex)[var] == pow { - self.ring.sub_mul_assign( - &mut coeff, - &div.coefficients[bindex], - &q.coefficients[qindex], - ); - } - - qindex += 1; - } - - if !F::is_zero(&coeff) { - // can the division be performed? if not, add to rest - // TODO: refactor - let (quot, div) = if pow >= m { - (coeff, true) - } else { - (coeff, false) - }; - - if div { - let nterms = q.nterms(); - let nvars = q.nvars(); - q.coefficients.push(quot); - q.exponents.resize((nterms + 1) * nvars, E::zero()); - q.exponents[nterms * nvars + var] = pow - m; - } else { - let nterms = r.nterms(); - let nvars = r.nvars(); - r.coefficients.push(quot); - r.exponents.resize((nterms + 1) * nvars, E::zero()); - r.exponents[nterms * nvars + var] = pow; - } - } - - if pow.is_zero() { - break; - } - - pow = pow - E::one(); - } - - q.reverse(); - r.reverse(); - - #[cfg(debug_assertions)] - { - if !(&q * div + r.clone() - self.clone()).is_zero() { - panic!("Division failed: ({})/({}): q={}, r={}", self, div, q, r); - } - } - - (q, r) - } - - /// Shift a variable `var` to `var+shift`. - pub fn shift_var(&self, var: usize, shift: &F::Element) -> Self { - let d = self.degree(var).to_u32() as usize; - - let y_poly = self.to_univariate_polynomial_list(var); - - let mut v = vec![self.zero(); d + 1]; - for (x_poly, p) in y_poly { - v[p.to_u32() as usize] = x_poly; - } - - for k in 0..d { - for j in (k..d).rev() { - v[j] = &v[j] + &v[j + 1].clone().mul_coeff(shift.clone()); - } - } - - let mut poly = self.zero(); - for (i, mut v) in v.into_iter().enumerate() { - for x in v.exponents.chunks_mut(self.nvars()) { - x[var] = E::from_u32(i as u32); - } - - for m in &v { - poly.append_monomial(m.coefficient.clone(), m.exponents); - } - } - - poly - } } -impl MultivariatePolynomial { +impl MultivariatePolynomial { /// Get the content from the coefficients. pub fn content(&self) -> F::Element { if self.coefficients.is_empty() { @@ -2541,7 +2664,9 @@ impl MultivariatePolynomial { let c = self.content(); self.div_coeff(&c) } +} +impl MultivariatePolynomial { pub fn divides( &self, div: &MultivariatePolynomial, @@ -3254,7 +3379,9 @@ impl MultivariatePolynomial { self } } +} +impl MultivariatePolynomial { /// Integrate the polynomial w.r.t the variable `var`, /// producing the antiderivative with zero constant. pub fn integrate(&self, var: usize) -> Self { @@ -3277,7 +3404,7 @@ impl MultivariatePolynomial { } } -impl MultivariatePolynomial { +impl MultivariatePolynomial { /// Optimized division routine for univariate polynomials over a field, which /// makes the divisor monic first. pub fn quot_rem_univariate( @@ -3502,7 +3629,7 @@ impl MultivariatePolynomial { } } -impl Derivable for PolynomialRing { +impl Derivable for PolynomialRing { fn derivative( &self, p: &MultivariatePolynomial, @@ -3542,7 +3669,7 @@ impl MultivariatePolynomial, E> { for t in self { exp[..self.nvars()].copy_from_slice(t.exponents); for t2 in &t.coefficient.poly { - exp[var_index] = E::from_u32(t2.exponents[0].to_u32()); + exp[var_index] = E::from_i32(t2.exponents[0].to_i32()); poly.append_monomial(t2.coefficient.clone(), &exp); } } diff --git a/src/poly/series.rs b/src/poly/series.rs index 6bbafea7..f23f17de 100644 --- a/src/poly/series.rs +++ b/src/poly/series.rs @@ -12,7 +12,7 @@ use crate::{ atom::AtomField, integer::Integer, rational::{Rational, Q}, - EuclideanDomain, InternalOrdering, Ring, + EuclideanDomain, InternalOrdering, Ring, SelfRing, }, printer::{PrintOptions, PrintState}, state::State, @@ -465,20 +465,40 @@ impl Series { self } +} + +impl SelfRing for Series { + fn is_zero(&self) -> bool { + self.is_zero() + } - pub fn format( + fn is_one(&self) -> bool { + self.is_one() + } + + fn format( &self, opts: &PrintOptions, mut state: PrintState, f: &mut W, ) -> Result { - let v = self.variable.to_string_with_state(PrintState { - in_exp: true, - ..state - }); + let v = self.variable.format_string( + opts, + PrintState { + in_exp: true, + ..state + }, + ); if self.coefficients.is_empty() { - write!(f, "𝒪({}^{})", v, self.absolute_order())?; + let o = self.absolute_order(); + if opts.latex { + write!(f, "\\mathcal{{O}}\\left({}^{{{}}}\\right)", v, o)?; + } else { + write!(f, "𝒪({}^", v)?; + Q.format(&o, opts, state.step(false, false, true), f)?; + f.write_char(')')?; + } return Ok(false); } @@ -493,6 +513,7 @@ impl Series { state.in_exp = false; f.write_str("(")?; } + let in_product = state.in_product; for (e, c) in self.coefficients.iter().enumerate() { if F::is_zero(c) { @@ -501,6 +522,7 @@ impl Series { let e = self.get_exponent(e); + state.in_product = in_product || !e.is_zero(); state.suppress_one = !e.is_zero(); let suppressed_one = self.field.format( c, @@ -517,18 +539,30 @@ impl Series { write!(f, "{}", v)?; } else if !e.is_zero() { write!(f, "{}^", v)?; + state.suppress_one = false; + + if opts.latex { + f.write_char('{')?; + } + Q.format(&e, opts, state.step(false, false, true), f)?; + + if opts.latex { + f.write_char('}')?; + } } state.in_sum = true; - state.in_product = true; } let o = self.absolute_order(); - if o.is_integer() { - write!(f, "+𝒪({}^{})", v, o)?; + + if opts.latex { + write!(f, "+\\mathcal{{O}}\\left({}^{{{}}}\\right)", v, o)?; } else { - write!(f, "+𝒪({}^({}))", v, o)?; + write!(f, "+𝒪({}^", v)?; + Q.format(&o, opts, state.step(false, false, true), f)?; + f.write_char(')')?; } if add_paren { @@ -539,6 +573,28 @@ impl Series { } } +impl InternalOrdering for Series { + fn internal_cmp(&self, other: &Self) -> Ordering { + if self.variable != other.variable { + return self.variable.cmp(&other.variable); + } + + if self.shift != other.shift { + return self.shift.cmp(&other.shift); + } + + if self.ramification != other.ramification { + return self.ramification.cmp(&other.ramification); + } + + if self.order != other.order { + return self.order.cmp(&other.order); + } + + self.coefficients.internal_cmp(&other.coefficients) + } +} + impl PartialEq for Series { #[inline] fn eq(&self, other: &Self) -> bool { diff --git a/src/poly/univariate.rs b/src/poly/univariate.rs index aabd824a..fe49da20 100644 --- a/src/poly/univariate.rs +++ b/src/poly/univariate.rs @@ -17,7 +17,7 @@ use crate::{ use super::{ factor::Factorize, polynomial::{MultivariatePolynomial, PolynomialRing}, - Exponent, Variable, + PositiveExponent, Variable, }; #[derive(Clone, PartialEq, Eq, Hash, Debug)] @@ -457,7 +457,7 @@ impl UnivariatePolynomial { } /// Convert from a univariate polynomial to a multivariate polynomial. - pub fn to_multivariate(self) -> MultivariatePolynomial { + pub fn to_multivariate(self) -> MultivariatePolynomial { let mut res = MultivariatePolynomial::new( &self.ring, self.degree().into(), @@ -538,10 +538,13 @@ impl SelfRing for UnivariatePolynomial { f.write_str("(")?; } - let v = self.variable.to_string_with_state(PrintState { - in_exp: true, - ..state - }); + let v = self.variable.format_string( + opts, + PrintState { + in_exp: true, + ..state + }, + ); for (e, c) in self.coefficients.iter().enumerate() { state.suppress_one = e > 0; @@ -1730,7 +1733,7 @@ impl UnivariatePolynomial { } } -impl UnivariatePolynomial> { +impl UnivariatePolynomial> { /// Convert a univariate polynomial of multivariate polynomials to a multivariate polynomial. pub fn flatten(self) -> MultivariatePolynomial { if self.is_zero() { diff --git a/src/solve.rs b/src/solve.rs index 9b97f817..d31850db 100644 --- a/src/solve.rs +++ b/src/solve.rs @@ -10,7 +10,7 @@ use crate::{ InternalOrdering, }, evaluate::FunctionMap, - poly::{Exponent, Variable}, + poly::{PositiveExponent, Variable}, tensors::matrix::Matrix, }; @@ -41,7 +41,7 @@ impl Atom { /// Solve a system that is linear in `vars`, if possible. /// Each expression in `system` is understood to yield 0. - pub fn solve_linear_system( + pub fn solve_linear_system( system: &[AtomView], vars: &[Symbol], ) -> Result, String> { @@ -189,7 +189,7 @@ impl<'a> AtomView<'a> { /// Solve a system that is linear in `vars`, if possible. /// Each expression in `system` is understood to yield 0. - pub fn solve_linear_system( + pub fn solve_linear_system( system: &[AtomView], vars: &[Symbol], ) -> Result, String> { diff --git a/symbolica.pyi b/symbolica.pyi index 4d1bc546..4472867e 100644 --- a/symbolica.pyi +++ b/symbolica.pyi @@ -2057,6 +2057,32 @@ class Series: >>> print(s) """ + def __str__(self) -> str: + """Print the series in a human-readable format.""" + + def to_latex(self) -> str: + """Convert the series into a LaTeX string.""" + + def format( + self, + terms_on_new_line: bool = False, + color_top_level_sum: bool = True, + color_builtin_symbols: bool = True, + print_finite_field: bool = True, + symmetric_representation_for_finite_field: bool = False, + explicit_rational_polynomial: bool = False, + number_thousands_separator: Optional[str] = None, + multiplication_operator: str = "*", + double_star_for_exponentiation: bool = False, + square_brackets_for_function: bool = False, + num_exp_as_superscript: bool = True, + latex: bool = False, + precision: Optional[int] = None, + ) -> str: + """ + Convert the series into a human-readable string. + """ + def __add__(self, other: Series | Expression) -> Series: """Add another series or expression to this series, returning the result.""" From e7a3a982a02cf0ebaa5c70858403472bf443c3c6 Mon Sep 17 00:00:00 2001 From: Ben Ruijl Date: Tue, 26 Nov 2024 10:23:28 +0100 Subject: [PATCH 2/6] Support collecting in multiple variables or functions - Collect using a conversion to a polynomial (negative powers are supported) - Add option to expand using a conversion to polynomial --- examples/collect.rs | 13 +- src/api/python.rs | 158 +++++++++++++++--------- src/atom.rs | 6 + src/collect.rs | 285 ++++++++++++-------------------------------- src/expand.rs | 163 ++++++++++++++++++++++--- src/id.rs | 5 + src/poly.rs | 63 +++++++--- src/transformer.rs | 68 ++++++----- symbolica.pyi | 33 ++--- 9 files changed, 448 insertions(+), 346 deletions(-) diff --git a/examples/collect.rs b/examples/collect.rs index 99efd6d6..b1c9a2ef 100644 --- a/examples/collect.rs +++ b/examples/collect.rs @@ -2,21 +2,20 @@ use symbolica::{atom::Atom, fun, state::State}; fn main() { let input = Atom::parse("x*(1+a)+x*5*y+f(5,x)+2+y^2+x^2 + x^3").unwrap(); - let x = State::get_symbol("x"); + let x = State::get_symbol("x").into(); let key = State::get_symbol("key"); let coeff = State::get_symbol("val"); - let (r, rest) = input.coefficient_list(x); + let r = input.coefficient_list::(std::slice::from_ref(&x)); println!("> Coefficient list:"); for (key, val) in r { println!("\t{} {}", key, val); } - println!("\t1 {}", rest); println!("> Collect in x:"); - let out = input.collect( - x, + let out = input.collect::( + &x, Some(Box::new(|x, out| { out.set_from_view(&x); })), @@ -25,8 +24,8 @@ fn main() { println!("\t{}", out); println!("> Collect in x with wrapping:"); - let out = input.collect( - x, + let out = input.collect::( + &x, Some(Box::new(move |a, out| { out.set_from_view(&a); *out = fun!(key, out); diff --git a/src/api/python.rs b/src/api/python.rs index 36bceb74..405deaca 100644 --- a/src/api/python.rs +++ b/src/api/python.rs @@ -452,20 +452,26 @@ impl PythonTransformer { /// >>> f = Expression.symbol('f') /// >>> e = f((x+1)**2).replace_all(f(x_), x_.transform().expand()) /// >>> print(e) - #[pyo3(signature = (var = None))] - pub fn expand(&self, var: Option) -> PyResult { + #[pyo3(signature = (var = None, via_poly = None))] + pub fn expand( + &self, + var: Option, + via_poly: Option, + ) -> PyResult { if let Some(var) = var { - let id = if let AtomView::Var(x) = var.to_expression().expr.as_view() { - x.get_symbol() + let e = var.to_expression(); + if matches!(e.expr, Atom::Var(_) | Atom::Fun(_)) { + return append_transformer!( + self, + Transformer::Expand(Some(e.expr), via_poly.unwrap_or(false)) + ); } else { return Err(exceptions::PyValueError::new_err( - "Expansion must be done wrt a variable or function name", + "Expansion must be done wrt an indeterminate", )); - }; - - return append_transformer!(self, Transformer::Expand(Some(id))); + } } else { - return append_transformer!(self, Transformer::Expand(None)); + return append_transformer!(self, Transformer::Expand(None, via_poly.unwrap_or(false))); } } @@ -1001,7 +1007,7 @@ impl PythonTransformer { } /// Create a transformer that collects terms involving the same power of `x`, - /// where `x` is a variable or function name. + /// where `x` is an indeterminate. /// Return the list of key-coefficient pairs and the remainder that matched no key. /// /// Both the key (the quantity collected in) and its coefficient can be mapped using @@ -1033,20 +1039,29 @@ impl PythonTransformer { /// A transformer to be applied to the quantity collected in /// coeff_map: Transformer /// A transformer to be applied to the coefficient - #[pyo3(signature = (x, key_map = None, coeff_map = None))] + #[pyo3(signature = (*x, key_map = None, coeff_map = None))] pub fn collect( &self, - x: ConvertibleToExpression, + x: Bound<'_, PyTuple>, key_map: Option, coeff_map: Option, ) -> PyResult { - let id = if let AtomView::Var(x) = x.to_expression().expr.as_view() { - x.get_symbol() - } else { - return Err(exceptions::PyValueError::new_err( - "Collect must be done wrt a variable or function name", - )); - }; + let mut xs = vec![]; + for a in x { + if let Ok(r) = a.extract::() { + if matches!(r.expr, Atom::Var(_) | Atom::Fun(_)) { + xs.push(r.expr.into()); + } else { + return Err(exceptions::PyValueError::new_err( + "Collect must be done wrt a variable or function", + )); + } + } else { + return Err(exceptions::PyValueError::new_err( + "Collect must be done wrt a variable or function", + )); + } + } let key_map = if let Some(key_map) = key_map { let Pattern::Transformer(p) = key_map.expr else { @@ -1084,7 +1099,7 @@ impl PythonTransformer { vec![] }; - return append_transformer!(self, Transformer::Collect(id, key_map, coeff_map)); + return append_transformer!(self, Transformer::Collect(xs, key_map, coeff_map)); } /// Create a transformer that collects terms involving the literal occurrence of `x`. @@ -3336,18 +3351,33 @@ impl PythonExpression { } /// Expand the expression. Optionally, expand in `var` only. - #[pyo3(signature = (var = None))] - pub fn expand(&self, var: Option) -> PyResult { + #[pyo3(signature = (var = None, via_poly = None))] + pub fn expand( + &self, + var: Option, + via_poly: Option, + ) -> PyResult { if let Some(var) = var { - let id = if let AtomView::Var(x) = var.to_expression().expr.as_view() { - x.get_symbol() + let e = var.to_expression(); + + if matches!(e.expr, Atom::Var(_) | Atom::Fun(_)) { + if via_poly.unwrap_or(false) { + let b = self + .expr + .as_view() + .expand_via_poly::(Some(e.expr.as_view())); + Ok(b.into()) + } else { + let b = self.expr.as_view().expand_in(e.expr.as_view()); + Ok(b.into()) + } } else { return Err(exceptions::PyValueError::new_err( - "Expansion must be done wrt a variable or function name", + "Expansion must be done wrt an indeterminate", )); - }; - - let b = self.expr.as_view().expand_in(id); + } + } else if via_poly.unwrap_or(false) { + let b = self.expr.as_view().expand_via_poly::(None); Ok(b.into()) } else { let b = self.expr.as_view().expand(); @@ -3355,7 +3385,7 @@ impl PythonExpression { } } - /// Collect terms involving the same power of `x`, where `x` is a variable or function name. + /// Collect terms involving the same power of `x`, where `x` is an indeterminate. /// Return the list of key-coefficient pairs and the remainder that matched no key. /// /// Both the *key* (the quantity collected in) and its coefficient can be mapped using @@ -3380,23 +3410,32 @@ impl PythonExpression { /// >>> print(e.collect(x, key_map=lambda x: exp(x), coeff_map=lambda x: coeff(x))) /// /// yields `var(1)*coeff(5)+var(x)*coeff(y+5)+var(x^2)*coeff(1)`. - #[pyo3(signature = (x, key_map = None, coeff_map = None))] + #[pyo3(signature = (*x, key_map = None, coeff_map = None))] pub fn collect( &self, - x: ConvertibleToExpression, + x: &Bound<'_, PyTuple>, key_map: Option, coeff_map: Option, ) -> PyResult { - let id = if let AtomView::Var(x) = x.to_expression().expr.as_view() { - x.get_symbol() - } else { - return Err(exceptions::PyValueError::new_err( - "Collect must be done wrt a variable or function name", - )); - }; + let mut xs = vec![]; + for a in x { + if let Ok(r) = a.extract::() { + if matches!(r.expr, Atom::Var(_) | Atom::Fun(_)) { + xs.push(r.expr.into()); + } else { + return Err(exceptions::PyValueError::new_err( + "Collect must be done wrt a variable or function", + )); + } + } else { + return Err(exceptions::PyValueError::new_err( + "Collect must be done wrt a variable or function", + )); + } + } - let b = self.expr.as_view().collect( - id, + let b = self.expr.collect_multiple::( + &Arc::new(xs), if let Some(key_map) = key_map { Some(Box::new(move |key, out| { Python::with_gil(|py| { @@ -3461,7 +3500,7 @@ impl PythonExpression { r.into() } - /// Collect terms involving the same power of `x`, where `x` is a variable or function name. + /// Collect terms involving the same power of `x`, where `x` is an indeterminate. /// Return the list of key-coefficient pairs and the remainder that matched no key. /// /// Examples @@ -3484,31 +3523,32 @@ impl PythonExpression { /// ``` pub fn coefficient_list( &self, - x: ConvertibleToExpression, + x: Bound<'_, PyTuple>, ) -> PyResult> { - let id = if let AtomView::Var(x) = x.to_expression().expr.as_view() { - x.get_symbol() - } else { - return Err(exceptions::PyValueError::new_err( - "Coefficient list must be done wrt a variable or function name", - )); - }; + let mut xs = vec![]; + for a in x { + if let Ok(r) = a.extract::() { + if matches!(r.expr, Atom::Var(_) | Atom::Fun(_)) { + xs.push(r.expr.into()); + } else { + return Err(exceptions::PyValueError::new_err( + "Collect must be done wrt a variable or function", + )); + } + } else { + return Err(exceptions::PyValueError::new_err( + "Collect must be done wrt a variable or function", + )); + } + } - let (list, rest) = self.expr.coefficient_list(id); + let list = self.expr.coefficient_list::(&xs); - let mut py_list: Vec<_> = list + let py_list: Vec<_> = list .into_iter() .map(|e| (e.0.to_owned().into(), e.1.into())) .collect(); - if let Atom::Num(n) = &rest { - if n.to_num_view().is_zero() { - return Ok(py_list); - } - } - - py_list.push((Atom::new_num(1).into(), rest.into())); - Ok(py_list) } diff --git a/src/atom.rs b/src/atom.rs index 4b50bc01..8e892508 100644 --- a/src/atom.rs +++ b/src/atom.rs @@ -257,6 +257,12 @@ impl Hash for AtomOrView<'_> { } } +impl<'a> From for AtomOrView<'a> { + fn from(s: Symbol) -> AtomOrView<'a> { + AtomOrView::Atom(Atom::new_var(s)) + } +} + impl<'a> From for AtomOrView<'a> { fn from(a: Atom) -> AtomOrView<'a> { AtomOrView::Atom(a) diff --git a/src/collect.rs b/src/collect.rs index d609d10e..c8480b96 100644 --- a/src/collect.rs +++ b/src/collect.rs @@ -1,15 +1,14 @@ -use ahash::HashMap; - use crate::{ - atom::{Add, AsAtomView, Atom, AtomView, Symbol}, + atom::{Add, AsAtomView, Atom, AtomOrView, AtomView, Symbol}, coefficient::CoefficientView, domains::{integer::Z, rational::Q}, - poly::{factor::Factorize, polynomial::MultivariatePolynomial}, + poly::{factor::Factorize, polynomial::MultivariatePolynomial, Exponent}, state::Workspace, }; +use std::sync::Arc; impl Atom { - /// Collect terms involving the same power of `x`, where `x` is a variable or function name, e.g. + /// Collect terms involving the same power of `x`, where `x` is a variable or function, e.g. /// /// ```math /// collect(x + x * y + x^2, x) = x * (1+y) + x^2 @@ -17,16 +16,16 @@ impl Atom { /// /// Both the *key* (the quantity collected in) and its coefficient can be mapped using /// `key_map` and `coeff_map` respectively. - pub fn collect( + pub fn collect( &self, - x: Symbol, + x: &AtomOrView, key_map: Option>, coeff_map: Option>, ) -> Atom { - self.as_view().collect(x, key_map, coeff_map) + self.as_view().collect::(x, key_map, coeff_map) } - /// Collect terms involving the same power of `x`, where `x` is a variable or function name, e.g. + /// Collect terms involving the same power of `x`, where `x` is a variable or function, e.g. /// /// ```math /// collect(x + x * y + x^2, x) = x * (1+y) + x^2 @@ -34,20 +33,19 @@ impl Atom { /// /// Both the *key* (the quantity collected in) and its coefficient can be mapped using /// `key_map` and `coeff_map` respectively. - pub fn collect_into( + pub fn collect_multiple( &self, - x: Symbol, + xs: &[AtomOrView], key_map: Option>, coeff_map: Option>, - out: &mut Atom, - ) { - self.as_view().collect_into(x, key_map, coeff_map, out) + ) -> Atom { + self.as_view().collect_multiple::(xs, key_map, coeff_map) } - /// Collect terms involving the same power of `x`, where `x` is a variable or function name. - /// Return the list of key-coefficient pairs and the remainder that matched no key. - pub fn coefficient_list(&self, x: Symbol) -> (Vec<(Atom, Atom)>, Atom) { - Workspace::get_local().with(|ws| self.as_view().coefficient_list_with_ws(x, ws)) + /// Collect terms involving the same power of `x` in `xs`, where `xs` is a list of indeterminates. + /// Return the list of key-coefficient pairs + pub fn coefficient_list(&self, xs: &[AtomOrView]) -> Vec<(Atom, Atom)> { + self.as_view().coefficient_list::(xs) } /// Collect terms involving the literal occurrence of `x`. @@ -78,7 +76,7 @@ impl Atom { } impl<'a> AtomView<'a> { - /// Collect terms involving the same power of `x`, where `x` is a variable or function name, e.g. + /// Collect terms involving the same power of `x`, where `x` is an indeterminate, e.g. /// /// ```math /// collect(x + x * y + x^2, x) = x * (1+y) + x^2 @@ -86,56 +84,38 @@ impl<'a> AtomView<'a> { /// /// Both the *key* (the quantity collected in) and its coefficient can be mapped using /// `key_map` and `coeff_map` respectively. - pub fn collect( + pub fn collect( &self, - x: Symbol, + x: &AtomOrView, key_map: Option>, coeff_map: Option>, ) -> Atom { - Workspace::get_local().with(|ws| { - let mut out = ws.new_atom(); - self.collect_with_ws_into(x, ws, key_map, coeff_map, &mut out); - out.into_inner() - }) + self.collect_multiple::(std::slice::from_ref(x), key_map, coeff_map) } - /// Collect terms involving the same power of `x`, where `x` is a variable or function name, e.g. - /// - /// ```math - /// collect(x + x * y + x^2, x) = x * (1+y) + x^2 - /// ``` - /// - /// Both the *key* (the quantity collected in) and its coefficient can be mapped using - /// `key_map` and `coeff_map` respectively. - pub fn collect_into( + pub fn collect_multiple( &self, - x: Symbol, + xs: &[AtomOrView], key_map: Option>, coeff_map: Option>, - out: &mut Atom, - ) { - Workspace::get_local().with(|ws| self.collect_with_ws_into(x, ws, key_map, coeff_map, out)) + ) -> Atom { + let mut out = Atom::new(); + Workspace::get_local() + .with(|ws| self.collect_multiple_impl::(xs, ws, key_map, coeff_map, &mut out)); + out } - /// Collect terms involving the same power of `x`, where `x` is a variable or function name, e.g. - /// - /// ```math - /// collect(x + x * y + x^2, x) = x * (1+y) + x^2 - /// ``` - /// - /// Both the *key* (the quantity collected in) and its coefficient can be mapped using - /// `key_map` and `coeff_map` respectively. - pub fn collect_with_ws_into( + pub fn collect_multiple_impl( &self, - x: Symbol, - workspace: &Workspace, + xs: &[AtomOrView], + ws: &Workspace, key_map: Option>, coeff_map: Option>, out: &mut Atom, ) { - let (h, rest) = self.coefficient_list_with_ws(x, workspace); + let r = self.coefficient_list::(xs); - let mut add_h = workspace.new_atom(); + let mut add_h = Atom::new(); let add = add_h.to_add(); fn map_key_coeff( @@ -168,159 +148,37 @@ impl<'a> AtomView<'a> { add.extend(mul_h.as_view()); } - for (key, coeff) in h { - map_key_coeff(key.as_view(), coeff, workspace, &key_map, &coeff_map, add); - } - - if !rest.is_zero() { - if key_map.is_some() { - let key = workspace.new_num(1); - map_key_coeff(key.as_view(), rest, workspace, &key_map, &coeff_map, add); - } else if let Some(coeff_map) = coeff_map { - let mut handle = workspace.new_atom(); - coeff_map(rest.as_view(), &mut handle); - add.extend(handle.as_view()); - } else { - add.extend(rest.as_view()); - } + for (key, coeff) in r { + map_key_coeff(key.as_view(), coeff, ws, &key_map, &coeff_map, add); } - add_h.as_view().normalize(workspace, out); - } - - /// Collect terms involving the same power of `x`, where `x` is a variable or function name. - /// Return the list of key-coefficient pairs and the remainder that matched no key. - pub fn coefficient_list(&self, x: Symbol) -> (Vec<(Atom, Atom)>, Atom) { - Workspace::get_local().with(|ws| self.coefficient_list_with_ws(x, ws)) + add_h.as_view().normalize(ws, out); } - /// Collect terms involving the same power of `x`, where `x` is a variable or function name. - /// Return the list of key-coefficient pairs and the remainder that matched no key. - pub fn coefficient_list_with_ws( - &self, - x: Symbol, - workspace: &Workspace, - ) -> (Vec<(Atom, Atom)>, Atom) { - let mut h = HashMap::default(); - let mut rest = workspace.new_atom(); - let rest_add = rest.to_add(); + /// Collect terms involving the same powers of `x` in `xs`, where `x` is an indeterminate. + /// Return the list of key-coefficient pairs. + pub fn coefficient_list(&self, xs: &[AtomOrView]) -> Vec<(Atom, Atom)> { + let vars = xs + .iter() + .map(|x| x.as_view().to_owned().into()) + .collect::>(); - let mut expanded = workspace.new_atom(); - self.expand_with_ws_into(workspace, Some(x), &mut expanded); + let p = self.to_polynomial_in_vars::(&Arc::new(vars)); - match expanded.as_view() { - AtomView::Add(a) => { - for arg in a { - arg.collect_factor_list(x, workspace, &mut h, rest_add) - } - } - _ => expanded - .as_view() - .collect_factor_list(x, workspace, &mut h, rest_add), - } + let mut coeffs = vec![]; + for t in p.into_iter() { + let mut key = Atom::new_num(1); - let mut rest_norm = Atom::new(); - rest.as_view().normalize(workspace, &mut rest_norm); - - let mut r: Vec<_> = h - .into_iter() - .map(|(k, v)| { - ( - { - let mut a = Atom::new(); - a.set_from_view(&k); - a - }, - { - let mut a = Atom::new(); - v.as_view().normalize(workspace, &mut a); - a - }, - ) - }) - .collect(); - r.sort_unstable_by(|(a, _), (b, _)| a.as_view().cmp(&b.as_view())); - - (r, rest_norm) - } - - /// Check if a factor contains `x` at the ground level. - #[inline] - fn has_key(&self, x: Symbol) -> bool { - match self { - AtomView::Var(v) => v.get_symbol() == x, - AtomView::Fun(f) => f.get_symbol() == x, - AtomView::Pow(p) => { - let (base, _) = p.get_base_exp(); - match base { - AtomView::Var(v) => v.get_symbol() == x, - AtomView::Fun(f) => f.get_symbol() == x, - _ => false, - } + for (p, v) in t.exponents.iter().zip(xs) { + let mut pow = Atom::new(); + pow.to_pow(v.as_view(), Atom::new_num(p.to_i32() as i64).as_view()); + key = key * pow; } - AtomView::Mul(_) => unreachable!("Mul is not a factor"), - _ => false, - } - } - - fn collect_factor_list( - &self, - x: Symbol, - workspace: &Workspace, - h: &mut HashMap, Add>, - rest: &mut Add, - ) { - match self { - AtomView::Add(_) => {} - AtomView::Mul(m) => { - if m.iter().any(|a| a.has_key(x)) { - let mut collected = workspace.new_atom(); - let mul = collected.to_mul(); - // we could have a double match if x*x(..) - // we then only collect on the first hit - let mut bracket = None; - - for a in m { - if bracket.is_none() && a.has_key(x) { - bracket = Some(a); - } else { - mul.extend(a); - } - } - - h.entry(bracket.unwrap()) - .and_modify(|e| { - e.extend(collected.as_view()); - }) - .or_insert({ - let mut a = Add::new(); - a.extend(collected.as_view()); - a - }); - - return; - } - } - _ => { - if self.has_key(x) { - // add the coefficient 1 - let collected = workspace.new_num(1); - h.entry(*self) - .and_modify(|e| { - e.extend(collected.as_view()); - }) - .or_insert({ - let mut a = Add::new(); - a.extend(collected.as_view()); - a - }); - return; - } - } + coeffs.push((key, t.coefficient.clone())); } - rest.extend(*self); + coeffs } /// Collect terms involving the literal occurrence of `x`. @@ -638,9 +496,13 @@ mod test { let input = Atom::parse("v1*(1+v3)+v1*5*v2+f1(5,v1)+2+v2^2+v1^2+v1^3").unwrap(); let x = State::get_symbol("v1"); - let (r, rest) = input.coefficient_list(x); + let r = input.coefficient_list::(&[x.into()]); let res = vec![ + ( + Atom::parse("1").unwrap(), + Atom::parse("v2^2+f1(5,v1)+2").unwrap(), + ), ( Atom::parse("v1").unwrap(), Atom::parse("v3+5*v2+1").unwrap(), @@ -648,15 +510,8 @@ mod test { (Atom::parse("v1^2").unwrap(), Atom::parse("1").unwrap()), (Atom::parse("v1^3").unwrap(), Atom::parse("1").unwrap()), ]; - let res_rest = Atom::parse("v2^2+f1(5,v1)+2").unwrap(); - let res_ref = res - .iter() - .map(|(a, b)| (a.clone(), b.clone())) - .collect::>(); - - assert_eq!(r, res_ref); - assert_eq!(rest, res_rest); + assert_eq!(r, res); } #[test] @@ -664,7 +519,7 @@ mod test { let input = Atom::parse("v1*(1+v3)+v1*5*v2+f1(5,v1)+2+v2^2+v1^2+v1^3").unwrap(); let x = State::get_symbol("v1"); - let out = input.collect(x, None, None); + let out = input.collect::(&x.into(), None, None); let ref_out = Atom::parse("v1^2+v1^3+v2^2+f1(5,v1)+v1*(5*v2+v3+1)+2").unwrap(); assert_eq!(out, ref_out) @@ -675,7 +530,7 @@ mod test { let input = Atom::parse("(1+v1)^2*v1+(1+v2)^100").unwrap(); let x = State::get_symbol("v1"); - let out = input.collect(x, None, None); + let out = input.collect::(&x.into(), None, None); let ref_out = Atom::parse("v1+2*v1^2+v1^3+(v2+1)^100").unwrap(); assert_eq!(out, ref_out) @@ -688,8 +543,8 @@ mod test { let key = State::get_symbol("f3"); let coeff = State::get_symbol("f4"); println!("> Collect in x with wrapping:"); - let out = input.collect( - x, + let out = input.collect::( + &x.into(), Some(Box::new(move |a, out| { out.set_from_view(&a); *out = fun!(key, out); @@ -755,4 +610,20 @@ mod test { assert_eq!(out, ref_out); } + + #[test] + fn coefficient_list_multiple() { + let input = Atom::parse( + "(v1+v2+v3)^2+v1+v1^2+ v2 + 5*v1*v2^2 + v3 + v2*(v4+1)^10 + v1*v5(1,2,3)^2 + v5(1,2)", + ) + .unwrap(); + + let out = input.as_view().coefficient_list::(&[ + State::get_symbol("v1").into(), + State::get_symbol("v2").into(), + Atom::parse("v5(1,2,3)").unwrap().into(), + ]); + + assert_eq!(out.len(), 8); + } } diff --git a/src/expand.rs b/src/expand.rs index 2a57001a..e9ce049e 100644 --- a/src/expand.rs +++ b/src/expand.rs @@ -1,26 +1,40 @@ -use std::ops::DerefMut; +use std::{ops::DerefMut, sync::Arc}; use smallvec::SmallVec; use crate::{ - atom::{Atom, AtomView, Symbol}, + atom::{representation::InlineVar, Atom, AtomView, Symbol}, coefficient::CoefficientView, combinatorics::CombinationWithReplacementIterator, - domains::integer::Integer, + domains::{integer::Integer, rational::Q}, + poly::{Exponent, Variable}, state::{RecycledAtom, Workspace}, }; impl Atom { - /// Expand an expression. + /// Expand an expression. The function [expand_via_poly] may be faster. pub fn expand(&self) -> Atom { self.as_view().expand() } - /// Expand an expression in the variable `var`. - pub fn expand_in(&self, var: Symbol) -> Atom { + /// Expand the expression by converting it to a polynomial, optionally + /// only in the indeterminate `var`. The parameter `E` should be a numerical type + /// that fits the largest exponent in the expanded expression. Often, + /// `u8` or `u16` is sufficient. + pub fn expand_via_poly(&self, var: Option) -> Atom { + self.as_view().expand_via_poly::(var) + } + + /// Expand an expression in the variable `var`. The function [expand_via_poly] may be faster. + pub fn expand_in(&self, var: AtomView) -> Atom { self.as_view().expand_in(var) } + /// Expand an expression in the variable `var`. + pub fn expand_in_symbol(&self, var: Symbol) -> Atom { + self.as_view().expand_in(InlineVar::from(var).as_view()) + } + /// Expand an expression, returning `true` iff the expression changed. pub fn expand_into(&self, out: &mut Atom) -> bool { self.as_view().expand_into(None, out) @@ -28,7 +42,7 @@ impl Atom { } impl<'a> AtomView<'a> { - /// Expand an expression. + /// Expand an expression. The function [expand_via_poly] may be faster. pub fn expand(&self) -> Atom { Workspace::get_local().with(|ws| { let mut a = ws.new_atom(); @@ -37,8 +51,8 @@ impl<'a> AtomView<'a> { }) } - /// Expand an expression. - pub fn expand_in(&self, var: Symbol) -> Atom { + /// Expand an expression. The function [expand_via_poly] may be faster. + pub fn expand_in(&self, var: AtomView) -> Atom { Workspace::get_local().with(|ws| { let mut a = ws.new_atom(); self.expand_with_ws_into(ws, Some(var), &mut a); @@ -47,7 +61,7 @@ impl<'a> AtomView<'a> { } /// Expand an expression, returning `true` iff the expression changed. - pub fn expand_into(&self, var: Option, out: &mut Atom) -> bool { + pub fn expand_into(&self, var: Option, out: &mut Atom) -> bool { Workspace::get_local().with(|ws| self.expand_with_ws_into(ws, var, out)) } @@ -55,7 +69,7 @@ impl<'a> AtomView<'a> { pub fn expand_with_ws_into( &self, workspace: &Workspace, - var: Option, + var: Option, out: &mut Atom, ) -> bool { let changed = self.expand_no_norm(workspace, var, out); @@ -69,10 +83,122 @@ impl<'a> AtomView<'a> { changed } + /// Check if the expression is expanded, optionally in only the variable or function `var`. + pub fn is_expanded(&self, var: Option) -> bool { + match self { + AtomView::Num(_) | AtomView::Var(_) | AtomView::Fun(_) => true, + AtomView::Pow(pow_view) => { + let (base, exp) = pow_view.get_base_exp(); + if !base.is_expanded(var) || !exp.is_expanded(var) { + return false; + } + + if let AtomView::Num(n) = exp { + if let CoefficientView::Natural(n, 1) = n.get_coeff_view() { + if n.unsigned_abs() <= u32::MAX as u64 { + if matches!(base, AtomView::Add(_) | AtomView::Mul(_)) { + return var.map(|s| !base.contains(s)).unwrap_or(false); + } + } + } + } + + true + } + AtomView::Mul(mul_view) => { + for arg in mul_view { + if !arg.is_expanded(var) { + return false; + } + + if matches!(arg, AtomView::Add(_)) { + return var.map(|s| !arg.contains(s)).unwrap_or(false); + } + } + + true + } + AtomView::Add(add_view) => { + for arg in add_view { + if !arg.is_expanded(var) { + return false; + } + } + + true + } + } + } + + /// Expand the expression by converting it to a polynomial, optionally + /// only in the indeterminate `var`. The parameter `E` should be a numerical type + /// that fits the largest exponent in the expanded expression. Often, + /// `u8` or `u16` is sufficient. + pub fn expand_via_poly(&self, var: Option) -> Atom { + let var_map = var.map(|v| Arc::new(vec![v.to_owned().into()])); + + let mut out = Atom::new(); + Workspace::get_local().with(|ws| { + self.expand_via_poly_impl::(ws, var, &var_map, &mut out); + }); + out + } + + fn expand_via_poly_impl( + &self, + ws: &Workspace, + var: Option, + var_map: &Option>>, + out: &mut Atom, + ) { + if self.is_expanded(var) { + out.set_from_view(self); + return; + } + + if let Some(v) = var { + if !self.contains(v) { + out.set_from_view(self); + return; + } + } + + match self { + AtomView::Num(_) | AtomView::Var(_) | AtomView::Fun(_) => unreachable!(), + AtomView::Pow(_) => { + if let Some(v) = var_map { + *out = self.to_polynomial_in_vars::(v).flatten(true); + } else { + *out = self.to_polynomial::<_, E>(&Q, None).to_expression(); + } + } + AtomView::Mul(_) => { + if let Some(v) = var_map { + *out = self.to_polynomial_in_vars::(v).flatten(true); + } else { + *out = self.to_polynomial::<_, E>(&Q, None).to_expression(); + } + } + AtomView::Add(add_view) => { + let mut t = ws.new_atom(); + + let add = out.to_add(); + + for arg in add_view { + arg.expand_via_poly_impl::(ws, var, &var_map, &mut t); + add.extend(t.as_view()); + } + + add.as_view().normalize(ws, &mut t); + std::mem::swap(out, &mut t); + } + } + } + /// Expand an expression, but do not normalize the result. - fn expand_no_norm(&self, workspace: &Workspace, var: Option, out: &mut Atom) -> bool { + fn expand_no_norm(&self, workspace: &Workspace, var: Option, out: &mut Atom) -> bool { if let Some(s) = var { - if !self.contains_symbol(s) { + if !self.contains(s) { out.set_from_view(self); return false; } @@ -334,7 +460,16 @@ mod test { fn expand_in_var() { let exp = Atom::parse("(1+v1)^2+(1+v2)^100") .unwrap() - .expand_in(State::get_symbol("v1")); + .expand_in_symbol(State::get_symbol("v1")); + let res = Atom::parse("1+2*v1+v1^2+(v2+1)^100").unwrap(); + assert_eq!(exp, res); + } + + #[test] + fn expand_with_poly() { + let exp = Atom::parse("(1+v1)^2+(1+v2)^100") + .unwrap() + .expand_in_symbol(State::get_symbol("v1")); let res = Atom::parse("1+2*v1+v1^2+(v2+1)^100").unwrap(); assert_eq!(exp, res); } diff --git a/src/id.rs b/src/id.rs index dfa884c3..9716c8f3 100644 --- a/src/id.rs +++ b/src/id.rs @@ -1245,6 +1245,11 @@ impl Pattern { matched } + /// Replace all occurrences in `target`, where replacements are tested in the order that they are given. + pub fn replace_all_multiple(target: AtomView, replacements: &[Replacement<'_>]) -> Atom { + target.replace_all_multiple(replacements) + } + pub fn pattern_match<'a: 'b, 'b>( &'b self, target: AtomView<'a>, diff --git a/src/poly.rs b/src/poly.rs index 6ba3a0aa..1832ce56 100644 --- a/src/poly.rs +++ b/src/poly.rs @@ -1187,9 +1187,29 @@ impl<'a> AtomView<'a> { if let AtomView::Num(n) = exp { let num_n = n.get_coeff_view(); if let CoefficientView::Natural(nn, nd) = num_n { - if nd == 1 && nn > 0 && nn < u32::MAX as i64 { - let b = base.to_polynomial_in_vars_impl(var_map, poly); - return b.pow(nn as usize); + if nd == 1 && nn > 0 && nn < i32::MAX as i64 { + return base + .to_polynomial_in_vars_impl(var_map, poly) + .pow(nn as usize); + } else if nd == 1 && nn < 0 && nn > i32::MIN as i64 { + // allow x^-2 as a term if supported by the exponent + if let Ok(e) = (nn as i32).try_into() { + if let AtomView::Var(v) = base { + let s = Variable::Symbol(v.get_symbol()); + if let Some(id) = var_map.iter().position(|v| v == &s) { + let mut exp = vec![E::zero(); var_map.len()]; + exp[id] = e; + return poly.monomial(field.one(), exp); + } else { + let mut var_map = var_map.as_ref().clone(); + var_map.push(s); + let mut exp = vec![E::zero(); var_map.len()]; + exp[var_map.len() - 1] = e; + + return poly.monomial(field.one(), exp); + } + } + } } } } @@ -1529,24 +1549,24 @@ impl<'a> AtomView<'a> { } impl MultivariatePolynomial { - /// Convert the polynomial to an expression. - pub fn to_nested_expression(&self) -> Atom { + /// Convert the polynomial to an expression, optionally distributing the polynomial variables over coefficient sums. + pub fn flatten(&self, distribute: bool) -> Atom { let mut out = Atom::default(); - Workspace::get_local().with(|ws| self.to_nested_expression_into_impl(ws, &mut out)); + Workspace::get_local().with(|ws| self.flatten_impl(distribute, ws, &mut out)); out } - fn to_nested_expression_into_impl(&self, workspace: &Workspace, out: &mut Atom) { + fn flatten_impl(&self, expand: bool, ws: &Workspace, out: &mut Atom) { if self.is_zero() { - out.set_from_view(&workspace.new_num(0).as_view()); + out.set_from_view(&ws.new_num(0).as_view()); return; } let add = out.to_add(); - let mut mul_h = workspace.new_atom(); - let mut num_h = workspace.new_atom(); - let mut pow_h = workspace.new_atom(); + let mut mul_h = ws.new_atom(); + let mut num_h = ws.new_atom(); + let mut pow_h = ws.new_atom(); let vars: Vec<_> = self.variables.iter().map(|v| v.to_atom()).collect(); @@ -1570,12 +1590,25 @@ impl MultivariatePolynomial { } } - mul.extend(monomial.coefficient.as_view()); - add.extend(mul_h.as_view()); + if expand { + if let AtomView::Add(a) = monomial.coefficient.as_view() { + let mut tmp = ws.new_atom(); + for term in a { + term.mul_with_ws_into(ws, mul_h.as_view(), &mut tmp); + add.extend(tmp.as_view()); + } + } else { + mul.extend(monomial.coefficient.as_view()); + add.extend(mul_h.as_view()); + } + } else { + mul.extend(monomial.coefficient.as_view()); + add.extend(mul_h.as_view()); + } } - let mut norm = workspace.new_atom(); - out.as_view().normalize(workspace, &mut norm); + let mut norm = ws.new_atom(); + out.as_view().normalize(ws, &mut norm); std::mem::swap(norm.deref_mut(), out); } } diff --git a/src/transformer.rs b/src/transformer.rs index dc0fadc5..a065c876 100644 --- a/src/transformer.rs +++ b/src/transformer.rs @@ -1,7 +1,7 @@ use std::{sync::Arc, time::Instant}; use crate::{ - atom::{representation::FunView, Atom, AtomView, Fun, Symbol}, + atom::{representation::FunView, Atom, AtomOrView, AtomView, Fun, Symbol}, coefficient::{Coefficient, CoefficientView}, combinatorics::{partitions, unique_permutations}, domains::rational::Rational, @@ -107,13 +107,13 @@ pub enum TransformerError { #[derive(Clone)] pub enum Transformer { /// Expand the rhs. - Expand(Option), + Expand(Option, bool), /// Derive the rhs w.r.t a variable. Derivative(Symbol), /// Perform a series expansion. Series(Symbol, Atom, Rational, bool), ///Collect all terms in powers of a variable. - Collect(Symbol, Vec, Vec), + Collect(Vec>, Vec, Vec), /// Apply find-and-replace on the lhs. ReplaceAll( Pattern, @@ -164,7 +164,7 @@ pub enum Transformer { impl std::fmt::Debug for Transformer { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Transformer::Expand(s) => f.debug_tuple("Expand").field(s).finish(), + Transformer::Expand(s, _) => f.debug_tuple("Expand").field(s).finish(), Transformer::Derivative(x) => f.debug_tuple("Derivative").field(x).finish(), Transformer::Collect(x, a, b) => { f.debug_tuple("Collect").field(x).field(a).field(b).finish() @@ -471,34 +471,44 @@ impl Transformer { Self::execute_chain(cur_input, t, workspace, out)?; } - Transformer::Expand(s) => { - cur_input.expand_with_ws_into(workspace, *s, out); + Transformer::Expand(s, via_poly) => { + if *via_poly { + *out = cur_input.expand_via_poly::(s.as_ref().map(|x| x.as_view())); + } else { + cur_input.expand_with_ws_into( + workspace, + s.as_ref().map(|x| x.as_view()), + out, + ); + } } Transformer::Derivative(x) => { cur_input.derivative_with_ws_into(*x, workspace, out); } - Transformer::Collect(x, key_map, coeff_map) => cur_input.collect_into( - *x, - if key_map.is_empty() { - None - } else { - let key_map = key_map.clone(); - Some(Box::new(move |i, o| { - Workspace::get_local() - .with(|ws| Self::execute_chain(i, &key_map, ws, o).unwrap()) - })) - }, - if coeff_map.is_empty() { - None - } else { - let coeff_map = coeff_map.clone(); - Some(Box::new(move |i, o| { - Workspace::get_local() - .with(|ws| Self::execute_chain(i, &coeff_map, ws, o).unwrap()) - })) - }, - out, - ), + Transformer::Collect(x, key_map, coeff_map) => cur_input + .collect_multiple_impl::( + x, + workspace, + if key_map.is_empty() { + None + } else { + let key_map = key_map.clone(); + Some(Box::new(move |i, o| { + Workspace::get_local() + .with(|ws| Self::execute_chain(i, &key_map, ws, o).unwrap()) + })) + }, + if coeff_map.is_empty() { + None + } else { + let coeff_map = coeff_map.clone(); + Some(Box::new(move |i, o| { + Workspace::get_local() + .with(|ws| Self::execute_chain(i, &coeff_map, ws, o).unwrap()) + })) + }, + out, + ), Transformer::Series(x, expansion_point, depth, depth_is_absolute) => { if let Ok(s) = cur_input.series( *x, @@ -835,7 +845,7 @@ mod test { Transformer::execute_chain( p.as_view(), &[ - Transformer::Expand(Some(State::get_symbol("v1"))), + Transformer::Expand(Some(Atom::new_var(State::get_symbol("v1"))), false), Transformer::Derivative(State::get_symbol("v1")), ], ws, diff --git a/symbolica.pyi b/symbolica.pyi index 4472867e..730fd69d 100644 --- a/symbolica.pyi +++ b/symbolica.pyi @@ -812,19 +812,21 @@ class Expression: A list of variables """ - def expand(self, var: Optional[Expression] = None) -> Expression: + def expand(self, var: Optional[Expression] = None, via_poly: Optional[bool] = None) -> Expression: """ Expand the expression. Optionally, expand in `var` only. + + Using `via_poly=True` may give a significant speedup for large expressions. """ def collect( self, - x: Expression, + *x: Expression, key_map: Optional[Callable[[Expression], Expression]] = None, coeff_map: Optional[Callable[[Expression], Expression]] = None, ) -> Expression: """ - Collect terms involving the same power of `x`, where `x` is a variable or function name. + Collect terms involving the same power of the indeterminate(s) `x`. Return the list of key-coefficient pairs and the remainder that matched no key. Both the key (the quantity collected in) and its coefficient can be mapped using @@ -851,8 +853,8 @@ class Expression: Parameters ---------- - x: Expression - The variable to collect terms in + *x: Expression + The variable(s) or function(s) to collect terms in key_map A function to be applied to the quantity collected in coeff_map @@ -860,10 +862,10 @@ class Expression: """ def coefficient_list( - self, x: Expression + self, *x: Expression ) -> Sequence[Tuple[Expression, Expression]]: - """Collect terms involving the same power of `x`, where `x` is a variable or function name. - Return the list of key-coefficient pairs and the remainder that matched no key. + """Collect terms involving the same power of `x`, where `x` are variables or functions. + Return the list of key-coefficient pairs. Examples -------- @@ -1497,8 +1499,10 @@ class Transformer: >>> e = Transformer().expand()((1+x)**2) """ - def expand(self, var: Optional[Expression] = None) -> Transformer: - """Create a transformer that expands products and powers. + def expand(self, var: Optional[Expression] = None, via_poly: Optional[bool] = None) -> Transformer: + """Create a transformer that expands products and powers. Optionally, expand in `var` only. + + Using `via_poly=True` may give a significant speedup for large expressions. Examples -------- @@ -1786,13 +1790,12 @@ class Transformer: def collect( self, - x: Expression, + *x: Expression, key_map: Optional[Transformer] = None, coeff_map: Optional[Transformer] = None, ) -> Transformer: """ - Create a transformer that collect terms involving the same power of `x`, - where `x` is a variable or function name. + Create a transformer that collects terms involving the same power of the indeterminate(s) `x`. Return the list of key-coefficient pairs and the remainder that matched no key. Both the key (the quantity collected in) and its coefficient can be mapped using @@ -1818,8 +1821,8 @@ class Transformer: Parameters ---------- - x: Expression - The variable to collect terms in + *x: Expression + The variable(s) or function(s) to collect terms in key_map: Transformer A transformer to be applied to the quantity collected in coeff_map: Transformer From e230e4bd201eb98e1cf81c8f751690db0850b8b3 Mon Sep 17 00:00:00 2001 From: Ben Ruijl Date: Wed, 27 Nov 2024 11:19:09 +0100 Subject: [PATCH 3/6] Add collect_num and expand_num - Fix normalization of x^0 - Normalize (a+b)+(a+b) as 2*(a+b) --- src/api/python.rs | 88 ++++++++++++++++++++++- src/coefficient.rs | 107 ++++++++++++++++++++++++++-- src/collect.rs | 150 +++++++++++++++++++++++++++++++++++++++- src/domains/float.rs | 4 ++ src/domains/rational.rs | 4 ++ src/expand.rs | 129 ++++++++++++++++++++++++++++++++++ src/normalize.rs | 23 +++++- src/transformer.rs | 12 ++++ symbolica.pyi | 80 ++++++++++++++++++++- 9 files changed, 587 insertions(+), 10 deletions(-) diff --git a/src/api/python.rs b/src/api/python.rs index 405deaca..b99396b8 100644 --- a/src/api/python.rs +++ b/src/api/python.rs @@ -475,6 +475,26 @@ impl PythonTransformer { } } + /// Create a transformer that distributes numbers in the expression, for example: + /// `2*(x+y)` -> `2*x+2*y`. + /// + /// Examples + /// -------- + /// + /// >>> from symbolica import * + /// >>> x, y = Expression.symbol('x', 'y') + /// >>> e = 3*(x+y)*(4*x+5*y) + /// >>> print(Transformer().expand_num()(e)) + /// + /// yields + /// + /// ``` + /// (3*x+3*y)*(4*x+5*y) + /// ``` + pub fn expand_num(&self) -> PyResult { + return append_transformer!(self, Transformer::ExpandNum); + } + /// Create a transformer that computes the product of a list of arguments. /// /// Examples @@ -1102,6 +1122,29 @@ impl PythonTransformer { return append_transformer!(self, Transformer::Collect(xs, key_map, coeff_map)); } + /// Create a transformer that collects numerical factors by removing the numerical content from additions. + /// For example, `-2*x + 4*x^2 + 6*x^3` will be transformed into `-2*(x - 2*x^2 - 3*x^3)`. + /// + /// The first argument of the addition is normalized to a positive quantity. + /// + /// Examples + /// -------- + /// + /// >>> from symbolica import * + /// >>> + /// >>> x, y = Expression.symbol('x', 'y') + /// >>> e = (-3*x+6*y)(2*x+2*y) + /// >>> print(Transformer().collect_num()(e)) + /// + /// yields + /// + /// ``` + /// -6*(x-2*y)*(x+y) + /// ``` + pub fn collect_num(&self) -> PyResult { + return append_transformer!(self, Transformer::CollectNum); + } + /// Create a transformer that collects terms involving the literal occurrence of `x`. pub fn coefficient(&self, x: ConvertibleToExpression) -> PyResult { let a = x.to_expression().expr; @@ -3385,6 +3428,26 @@ impl PythonExpression { } } + /// Distribute numbers in the expression, for example: + /// `2*(x+y)` -> `2*x+2*y`. + /// + /// Examples + /// -------- + /// + /// >>> from symbolica import Expression + /// >>> x, y = Expression.symbol('x', 'y') + /// >>> e = 3*(x+y)*(4*x+5*y) + /// >>> print(e.expand_num()) + /// + /// yields + /// + /// ``` + /// (3*x+3*y)*(4*x+5*y) + /// ``` + pub fn expand_num(&self) -> PythonExpression { + self.expr.expand_num().into() + } + /// Collect terms involving the same power of `x`, where `x` is an indeterminate. /// Return the list of key-coefficient pairs and the remainder that matched no key. /// @@ -3479,12 +3542,35 @@ impl PythonExpression { Ok(b.into()) } + /// Collect numerical factors by removing the numerical content from additions. + /// For example, `-2*x + 4*x^2 + 6*x^3` will be transformed into `-2*(x - 2*x^2 - 3*x^3)`. + /// + /// The first argument of the addition is normalized to a positive quantity. + /// + /// Examples + /// -------- + /// + /// >>> from symbolica import Expression + /// >>> + /// >>> x, y = Expression.symbol('x', 'y') + /// >>> e = (-3*x+6*y)(2*x+2*y) + /// >>> print(e.collect_num()) + /// + /// yields + /// + /// ```log + /// -6*(x-2*y)*(x+y) + /// ``` + pub fn collect_num(&self) -> PythonExpression { + self.expr.collect_num().into() + } + /// Collect terms involving the literal occurrence of `x`. /// /// Examples /// -------- /// - /// from symbolica import Expression + /// >>> from symbolica import Expression /// >>> /// >>> x, y = Expression.symbol('x', 'y') /// >>> e = 5*x + x * y + x**2 + y*x**2 diff --git a/src/coefficient.rs b/src/coefficient.rs index 3814144a..efea5884 100644 --- a/src/coefficient.rs +++ b/src/coefficient.rs @@ -1,7 +1,7 @@ use std::{ cmp::Ordering, f64::consts::LOG2_10, - ops::{Add, Div, Mul}, + ops::{Add, Div, Mul, Neg}, sync::Arc, }; @@ -17,10 +17,10 @@ use crate::{ finite_field::{ FiniteField, FiniteFieldCore, FiniteFieldElement, FiniteFieldWorkspace, ToFiniteField, }, - float::{Float, Real, SingleFloat}, + float::{Float, NumericalFloatLike, Real, SingleFloat}, integer::{Integer, IntegerRing, Z}, rational::{Rational, Q}, - rational_polynomial::RationalPolynomial, + rational_polynomial::{FromNumeratorAndDenominator, RationalPolynomial}, EuclideanDomain, Field, InternalOrdering, Ring, }, poly::{polynomial::MultivariatePolynomial, Variable, INLINED_EXPONENTS}, @@ -41,7 +41,7 @@ pub trait ConvertToRing: Ring { /// A coefficient that can appear in a Symbolica expression. /// In most cases, this is a rational number but it can also be a finite field element or /// a rational polynomial. -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum Coefficient { Rational(Rational), Float(Float), @@ -49,8 +49,6 @@ pub enum Coefficient { RationalPolynomial(RationalPolynomial), } -impl Eq for Coefficient {} - impl From for Coefficient { fn from(value: i64) -> Self { Coefficient::Rational(value.into()) @@ -165,6 +163,15 @@ impl Coefficient { Coefficient::Rational(Rational::one()) } + pub fn is_negative(&self) -> bool { + match self { + Coefficient::Rational(r) => r.is_negative(), + Coefficient::Float(f) => f.is_negative(), + Coefficient::FiniteField(_, _) => false, + Coefficient::RationalPolynomial(r) => r.numerator.lcoeff().is_negative(), + } + } + pub fn is_zero(&self) -> bool { match self { Coefficient::Rational(r) => r.is_zero(), @@ -173,6 +180,94 @@ impl Coefficient { Coefficient::RationalPolynomial(r) => r.numerator.is_zero(), } } + + pub fn is_one(&self) -> bool { + match self { + Coefficient::Rational(r) => r.is_one(), + Coefficient::Float(f) => f.is_one(), + Coefficient::FiniteField(num, field) => { + let f = State::get_finite_field(*field); + f.is_one(num) + } + Coefficient::RationalPolynomial(r) => r.numerator.is_one(), + } + } + + pub fn gcd(&self, rhs: &Self) -> Self { + match (self, rhs) { + (Coefficient::Rational(r1), Coefficient::Rational(r2)) => { + Coefficient::Rational(r1.gcd(r2)) + } + (Coefficient::FiniteField(_n1, i1), Coefficient::FiniteField(_n2, i2)) => { + if i1 != i2 { + panic!( + "Cannot multiply numbers from different finite fields: p1={}, p2={}", + State::get_finite_field(*i1).get_prime(), + State::get_finite_field(*i2).get_prime() + ); + } + let f = State::get_finite_field(*i1); + Coefficient::FiniteField(f.one(), *i1) + } + (Coefficient::FiniteField(_, _), _) | (_, Coefficient::FiniteField(_, _)) => { + panic!("Cannot multiply finite field to non-finite number. Convert other number first?"); + } + (Coefficient::Rational(r), Coefficient::RationalPolynomial(rp)) + | (Coefficient::RationalPolynomial(rp), Coefficient::Rational(r)) => { + let p = RationalPolynomial::from_num_den( + rp.numerator.constant(r.numerator()), + rp.numerator.constant(r.denominator()), + &Z, + false, + ); + + let g = p.gcd(rp); + if g.is_constant() { + (g.numerator.lcoeff(), g.denominator.lcoeff()).into() + } else { + unreachable!() + } + } + (Coefficient::RationalPolynomial(p1), Coefficient::RationalPolynomial(p2)) => { + let r = if p1.get_variables() != p2.get_variables() { + let mut p1 = p1.clone(); + let mut p2 = p2.clone(); + p1.unify_variables(&mut p2); + p1.gcd(&p2) + } else { + p1.gcd(&p2) + }; + + if r.is_constant() { + (r.numerator.lcoeff(), r.denominator.lcoeff()).into() + } else { + Coefficient::RationalPolynomial(r) + } + } + (Coefficient::Rational(_), Coefficient::Float(f)) + | (Coefficient::Float(f), Coefficient::Rational(_)) => Coefficient::Float(f.one()), + (Coefficient::Float(f1), Coefficient::Float(_f2)) => Coefficient::Float(f1.one()), + (Coefficient::Float(_), _) | (_, Coefficient::Float(_)) => { + panic!("Cannot add float to finite-field number or rational polynomial"); + } + } + } +} + +impl Neg for Coefficient { + type Output = Coefficient; + + fn neg(self) -> Coefficient { + match self { + Coefficient::Rational(r) => Coefficient::Rational(-r), + Coefficient::Float(f) => Coefficient::Float(-f), + Coefficient::FiniteField(n, i) => { + let f = State::get_finite_field(i); + Coefficient::FiniteField(f.neg(&n), i) + } + Coefficient::RationalPolynomial(p) => Coefficient::RationalPolynomial(-p), + } + } } impl Add for Coefficient { diff --git a/src/collect.rs b/src/collect.rs index c8480b96..0e2b913d 100644 --- a/src/collect.rs +++ b/src/collect.rs @@ -1,6 +1,6 @@ use crate::{ atom::{Add, AsAtomView, Atom, AtomOrView, AtomView, Symbol}, - coefficient::CoefficientView, + coefficient::{Coefficient, CoefficientView}, domains::{integer::Z, rational::Q}, poly::{factor::Factorize, polynomial::MultivariatePolynomial, Exponent}, state::Workspace, @@ -73,6 +73,14 @@ impl Atom { pub fn factor(&self) -> Atom { self.as_view().factor() } + + /// Collect numerical factors by removing the numerical content from additions. + /// For example, `-2*x + 4*x^2 + 6*x^3` will be transformed into `-2*(x - 2*x^2 - 3*x^3)`. + /// + /// The first argument of the addition is normalized to a positive quantity. + pub fn collect_num(&self) -> Atom { + self.as_view().collect_num() + } } impl<'a> AtomView<'a> { @@ -485,12 +493,152 @@ impl<'a> AtomView<'a> { pow } + + /// Collect numerical factors by removing the numerical content from additions. + /// For example, `-2*x + 4*x^2 + 6*x^3` will be transformed into `-2*(x - 2*x^2 - 3*x^3)`. + /// + /// The first argument of the addition is normalized to a positive quantity. + pub fn collect_num(&self) -> Atom { + Workspace::get_local().with(|ws| { + let mut coeff = Atom::new(); + self.collect_num_impl(ws, &mut coeff); + coeff + }) + } + + fn collect_num_impl(&self, ws: &Workspace, out: &mut Atom) -> bool { + fn get_num(a: AtomView) -> Option { + match a { + AtomView::Num(n) => Some(n.get_coeff_view().to_owned()), + AtomView::Add(add) => { + // perform GCD of all arguments + // make sure the first argument is positive + let mut is_negative = false; + let mut gcd: Option = None; + for arg in add.iter() { + if let Some(num) = get_num(arg) { + if let Some(g) = gcd { + gcd = Some(g.gcd(&num)); + } else { + is_negative = num.is_negative(); + gcd = Some(num); + } + } + } + + if let Some(g) = gcd { + if is_negative && !g.is_negative() { + Some(-g) + } else { + Some(g) + } + } else { + None + } + } + AtomView::Mul(mul) => { + if mul.has_coefficient() { + for aa in mul.iter() { + if let AtomView::Num(n) = aa { + return Some(n.get_coeff_view().to_owned()); + } + } + + unreachable!() + } else { + None + } + } + AtomView::Pow(_) | AtomView::Var(_) | AtomView::Fun(_) => None, + } + } + + match self { + AtomView::Add(a) => { + let mut r = ws.new_atom(); + let ra = r.to_add(); + let mut na = ws.new_atom(); + let mut changed = false; + for arg in a { + changed |= arg.collect_num_impl(ws, &mut na); + ra.extend(na.as_view()); + } + + if !changed { + out.set_from_view(self); + } else { + r.as_view().normalize(ws, out); + } + + if let AtomView::Add(aa) = out.as_view() { + if let Some(n) = get_num(out.as_view()) { + let v = ws.new_num(n); + // divide every term by n + let ra = r.to_add(); + let mut div = ws.new_atom(); + for arg in aa.iter() { + arg.div_with_ws_into(ws, v.as_view(), &mut div); + ra.extend(div.as_view()); + } + + let m = div.to_mul(); + m.extend(r.as_view()); + m.extend(v.as_view()); + m.as_view().normalize(ws, out); + changed = true; + } + } + + changed + } + AtomView::Mul(m) => { + let mut r = ws.new_atom(); + let ra = r.to_mul(); + let mut na = ws.new_atom(); + let mut changed = false; + for arg in m { + changed |= arg.collect_num_impl(ws, &mut na); + ra.extend(na.as_view()); + } + + if !changed { + out.set_from_view(self); + } else { + r.as_view().normalize(ws, out); + } + + changed + } + _ => { + out.set_from_view(self); + false + } + } + } } #[cfg(test)] mod test { use crate::{atom::Atom, fun, state::State}; + #[test] + fn collect_num() { + let input = Atom::parse("2*v1+4*v1^2+6*v1^3").unwrap(); + let out = input.collect_num(); + let ref_out = Atom::parse("2*(v1+2v1^2+3v1^3)").unwrap(); + assert_eq!(out, ref_out); + + let input = Atom::parse("(-3*v1+3*v2)(2*v3+2*v4)").unwrap(); + let out = input.collect_num(); + let ref_out = Atom::parse("-6*(v4+v3)*(v1-v2)").unwrap(); + assert_eq!(out, ref_out); + + let input = Atom::parse("v1+v2+2*(v1+v2)").unwrap(); + let out = input.expand_num().collect_num(); + let ref_out = Atom::parse("3*(v1+v2)").unwrap(); + assert_eq!(out, ref_out); + } + #[test] fn coefficient_list() { let input = Atom::parse("v1*(1+v3)+v1*5*v2+f1(5,v1)+2+v2^2+v1^2+v1^3").unwrap(); diff --git a/src/domains/float.rs b/src/domains/float.rs index 6e683c5f..5efb8c55 100644 --- a/src/domains/float.rs +++ b/src/domains/float.rs @@ -1620,6 +1620,10 @@ impl Float { self.0.is_finite() } + pub fn is_negative(&self) -> bool { + self.0.is_sign_negative() + } + /// Parse a float from a string. /// Precision can be specified by a trailing backtick followed by the precision. /// For example: ```1.234`20``` for a precision of 20 decimal digits. diff --git a/src/domains/rational.rs b/src/domains/rational.rs index 0f74cca1..bb5b5ef3 100644 --- a/src/domains/rational.rs +++ b/src/domains/rational.rs @@ -672,6 +672,10 @@ impl Rational { Q.neg(self) } + pub fn gcd(&self, other: &Rational) -> Rational { + Q.gcd(self, other) + } + pub fn to_f64(&self) -> f64 { rug::Rational::from(( self.numerator.clone().to_multi_prec(), diff --git a/src/expand.rs b/src/expand.rs index e9ce049e..32ca58bc 100644 --- a/src/expand.rs +++ b/src/expand.rs @@ -39,6 +39,12 @@ impl Atom { pub fn expand_into(&self, out: &mut Atom) -> bool { self.as_view().expand_into(None, out) } + + /// Distribute numbers in the expression, for example: + /// `2*(x+y)` -> `2*x+2*y`. + pub fn expand_num(&self) -> Atom { + self.as_view().expand_num() + } } impl<'a> AtomView<'a> { @@ -422,12 +428,135 @@ impl<'a> AtomView<'a> { } } } + + /// Distribute numbers in the expression, for example: + /// `2*(x+y)` -> `2*x+2*y`. + pub fn expand_num(&self) -> Atom { + let mut a = Atom::new(); + Workspace::get_local().with(|ws| { + self.expand_num_impl(ws, &mut a); + }); + a + } + + pub fn expand_num_into(&self, out: &mut Atom) { + Workspace::get_local().with(|ws| { + self.expand_with_ws_into(ws, None, out); + }) + } + + pub fn expand_num_impl(&self, ws: &Workspace, out: &mut Atom) -> bool { + match self { + AtomView::Num(_) | AtomView::Var(_) | AtomView::Fun(_) => { + out.set_from_view(self); + false + } + AtomView::Pow(pow_view) => { + let (base, exp) = pow_view.get_base_exp(); + let mut new_base = ws.new_atom(); + let mut changed = base.expand_num_impl(ws, &mut new_base); + + let mut new_exp = ws.new_atom(); + changed |= exp.expand_num_impl(ws, &mut new_exp); + + let mut pow_h = ws.new_atom(); + pow_h.to_pow(new_base.as_view(), new_exp.as_view()); + pow_h.as_view().normalize(ws, out); + + changed + } + AtomView::Mul(mul_view) => { + if !mul_view.has_coefficient() + || !mul_view.iter().any(|a| matches!(a, AtomView::Add(_))) + { + out.set_from_view(self); + return false; + } + + let mut args: Vec<_> = mul_view.iter().collect(); + let mut sum = None; + let mut num = None; + + args.retain(|a| { + if let AtomView::Add(_) = a { + if sum.is_none() { + sum = Some(a.clone()); + false + } else { + true + } + } else if let AtomView::Num(_) = a { + if num.is_none() { + num = Some(a.clone()); + false + } else { + true + } + } else { + true + } + }); + + let mut add = ws.new_atom(); + let add_view = add.to_add(); + let n = num.unwrap(); + + let mut m = ws.new_atom(); + if let AtomView::Add(sum) = sum.unwrap() { + for a in sum.iter() { + let mm = m.to_mul(); + mm.extend(a); + mm.extend(n); + add_view.extend(m.as_view()); + } + } + + add_view.as_view().normalize(ws, &mut m); + let m2 = add.to_mul(); + for a in args { + m2.extend(a); + } + m2.extend(m.as_view()); + + m2.as_view().normalize(ws, out); + + true + } + AtomView::Add(add_view) => { + let mut changed = false; + + let mut new = ws.new_atom(); + let add = new.to_add(); + + let mut new_arg = ws.new_atom(); + for arg in add_view { + changed |= arg.expand_num_impl(ws, &mut new_arg); + add.extend(new_arg.as_view()); + } + + if !changed { + out.set_from_view(self); + return false; + } + + new.as_view().normalize(ws, out); + true + } + } + } } #[cfg(test)] mod test { use crate::{atom::Atom, state::State}; + #[test] + fn expand_num() { + let exp = Atom::parse("5+2*v3*(v1-v2)*(v4+v5)").unwrap().expand_num(); + let res = Atom::parse("5+v3*(v4+v5)*(2*v1-2*v2)").unwrap(); + assert_eq!(exp, res); + } + #[test] fn exponent() { let exp = Atom::parse("(1+v1+v2)^4").unwrap().expand(); diff --git a/src/normalize.rs b/src/normalize.rs index 0738b369..4961981a 100644 --- a/src/normalize.rs +++ b/src/normalize.rs @@ -3,7 +3,7 @@ use std::{cmp::Ordering, ops::DerefMut}; use smallvec::SmallVec; use crate::{ - atom::{Atom, AtomView, Fun, Symbol}, + atom::{representation::InlineNum, Atom, AtomView, Fun, Symbol}, coefficient::{Coefficient, CoefficientView}, domains::{float::Real, integer::Z, rational::Q}, poly::Variable, @@ -390,6 +390,17 @@ impl Atom { new_exp.extend(exp2); let mut helper2 = workspace.new_atom(); helper.as_view().normalize(workspace, &mut helper2); + + if let AtomView::Num(n) = helper2.as_view() { + if n.is_zero() { + self.to_num(1.into()); + return true; + } else if n.is_one() { + self.set_from_view(&base2); + return true; + } + } + p1.set_from_base_and_exp(base2, helper2.as_view()); p1.set_normalized(true); return true; @@ -1357,6 +1368,16 @@ impl<'a> AtomView<'a> { /// Add two atoms and normalize the result. pub(crate) fn add_normalized(&self, rhs: AtomView, ws: &Workspace, out: &mut Atom) { + // write (a+b)+(a+b) as 2*(a+b) + if *self == rhs { + let mut a = ws.new_atom(); + let m = a.to_mul(); + m.extend(*self); + m.extend(InlineNum::new(2, 1).as_view()); + a.as_view().normalize(ws, out); + return; + } + let a = out.to_add(); a.grow_capacity(self.get_byte_size() + rhs.get_byte_size()); diff --git a/src/transformer.rs b/src/transformer.rs index a065c876..e0a152cf 100644 --- a/src/transformer.rs +++ b/src/transformer.rs @@ -108,12 +108,16 @@ pub enum TransformerError { pub enum Transformer { /// Expand the rhs. Expand(Option, bool), + /// Distribute numbers. + ExpandNum, /// Derive the rhs w.r.t a variable. Derivative(Symbol), /// Perform a series expansion. Series(Symbol, Atom, Rational, bool), ///Collect all terms in powers of a variable. Collect(Vec>, Vec, Vec), + /// Collect numbers. + CollectNum, /// Apply find-and-replace on the lhs. ReplaceAll( Pattern, @@ -165,10 +169,12 @@ impl std::fmt::Debug for Transformer { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Transformer::Expand(s, _) => f.debug_tuple("Expand").field(s).finish(), + Transformer::ExpandNum => f.debug_tuple("ExpandNum").finish(), Transformer::Derivative(x) => f.debug_tuple("Derivative").field(x).finish(), Transformer::Collect(x, a, b) => { f.debug_tuple("Collect").field(x).field(a).field(b).finish() } + Transformer::CollectNum => f.debug_tuple("CollectNum").finish(), Transformer::ReplaceAll(pat, rhs, ..) => { f.debug_tuple("ReplaceAll").field(pat).field(rhs).finish() } @@ -482,6 +488,9 @@ impl Transformer { ); } } + Transformer::ExpandNum => { + cur_input.expand_num_into(out); + } Transformer::Derivative(x) => { cur_input.derivative_with_ws_into(*x, workspace, out); } @@ -509,6 +518,9 @@ impl Transformer { }, out, ), + Transformer::CollectNum => { + *out = cur_input.collect_num(); + } Transformer::Series(x, expansion_point, depth, depth_is_absolute) => { if let Ok(s) = cur_input.series( *x, diff --git a/symbolica.pyi b/symbolica.pyi index 730fd69d..adb76118 100644 --- a/symbolica.pyi +++ b/symbolica.pyi @@ -819,6 +819,24 @@ class Expression: Using `via_poly=True` may give a significant speedup for large expressions. """ + def expand_num(self) -> Expression: + """ Distribute numbers in the expression, for example: `2*(x+y)` -> `2*x+2*y`. + + Examples + -------- + + >>> from symbolica import * + >>> x, y = Expression.symbol('x', 'y') + >>> e = 3*(x+y)*(4*x+5*y) + >>> print(e.expand_num()) + + yields + + ``` + (3*x+3*y)*(4*x+5*y) + ``` + """ + def collect( self, *x: Expression, @@ -861,6 +879,27 @@ class Expression: A function to be applied to the coefficient """ + def collect_num(self) -> Expression: + """Collect numerical factors by removing the content from additions. + For example, `-2*x + 4*x^2 + 6*x^3` will be transformed into `-2*(x - 2*x^2 - 3*x^3)`. + + The first argument of the addition is normalized to a positive quantity. + + Examples + -------- + + >>> from symbolica import * + >>> x, y = Expression.symbol('x', 'y') + >>> e = (-3*x+6*y)*(2*x+2*y) + >>> print(e.collect_num()) + + yields + + ``` + -6*(x+y)*(x-2*y) + ``` + """ + def coefficient_list( self, *x: Expression ) -> Sequence[Tuple[Expression, Expression]]: @@ -1513,6 +1552,24 @@ class Transformer: >>> print(e) """ + def expand_num(self) -> Expression: + """Create a transformer that distributes numbers in the expression, for example: `2*(x+y)` -> `2*x+2*y`. + + Examples + -------- + + >>> from symbolica import * + >>> x, y = Expression.symbol('x', 'y') + >>> e = 3*(x+y)*(4*x+5*y) + >>> print(Transformer().expand_num()(e)) + + yields + + ``` + (3*x+3*y)*(4*x+5*y) + ``` + """ + def prod(self) -> Transformer: """Create a transformer that computes the product of a list of arguments. @@ -1829,6 +1886,27 @@ class Transformer: A transformer to be applied to the coefficient """ + def collect_num(self) -> Expression: + """Create a transformer that collects numerical factors by removing the content from additions. + For example, `-2*x + 4*x^2 + 6*x^3` will be transformed into `-2*(x - 2*x^2 - 3*x^3)`. + + The first argument of the addition is normalized to a positive quantity. + + Examples + -------- + + >>> from symbolica import * + >>> x, y = Expression.symbol('x', 'y') + >>> e = (-3*x+6*y)*(2*x+2*y) + >>> print(Transformer().collect_num()(e)) + + yields + + ``` + -6*(x+y)*(x-2*y) + ``` + """ + def coefficient(self, x: Expression) -> Transformer: """Create a transformer that collects terms involving the literal occurrence of `x`. """ @@ -1973,7 +2051,7 @@ class Transformer: >>> e = e.transform().stats('replace', Transformer().replace_all(f(x_), 1)).execute() yields - ```log + ``` Stats for replace: In │ 1 │ 10.00 B │ Out │ 1 │ 3.00 B │ ⧗ 40.15µs From e73ea498dcca4a59010c335ccf4a3a22f1ef30e8 Mon Sep 17 00:00:00 2001 From: Ben Ruijl Date: Mon, 2 Dec 2024 12:45:38 +0100 Subject: [PATCH 4/6] Boolean queries on Python expressions now yield Conditions - Add conversion from conditions to pattern restrictions - Add contains condition --- src/api/python.rs | 395 ++++++++++++++++++++++++++++++++++++++-------- src/atom.rs | 2 +- src/id.rs | 93 +++++++++++ symbolica.pyi | 76 ++++++--- 4 files changed, 481 insertions(+), 85 deletions(-) diff --git a/src/api/python.rs b/src/api/python.rs index b99396b8..93107d5c 100644 --- a/src/api/python.rs +++ b/src/api/python.rs @@ -52,7 +52,7 @@ use crate::{ }, graph::Graph, id::{ - Condition, ConditionResult, Match, MatchSettings, MatchStack, Pattern, + Condition, ConditionResult, Evaluate, Match, MatchSettings, MatchStack, Pattern, PatternAtomTreeIterator, PatternOrMap, PatternRestriction, ReplaceIterator, Replacement, WildcardRestriction, }, @@ -1306,7 +1306,7 @@ impl PythonTransformer { &self, lhs: ConvertibleToPattern, rhs: ConvertibleToPatternOrMap, - cond: Option, + cond: Option, non_greedy_wildcards: Option>, level_range: Option<(usize, Option)>, level_is_tree_depth: Option, @@ -1352,7 +1352,7 @@ impl PythonTransformer { Transformer::ReplaceAll( lhs.to_pattern()?.expr, rhs.to_pattern_or_map()?, - cond.map(|r| r.condition.clone()).unwrap_or_default(), + cond.map(|r| r.0).unwrap_or_default(), settings, ) ); @@ -1711,6 +1711,242 @@ impl PythonPatternRestriction { } } +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +pub enum Relation { + Eq(Atom, Atom), + Ne(Atom, Atom), + Gt(Atom, Atom), + Ge(Atom, Atom), + Lt(Atom, Atom), + Le(Atom, Atom), + Contains(Atom, Atom), + IsType(Atom, AtomType), +} + +impl std::fmt::Display for Relation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Relation::Eq(a, b) => write!(f, "{} == {}", a, b), + Relation::Ne(a, b) => write!(f, "{} != {}", a, b), + Relation::Gt(a, b) => write!(f, "{} > {}", a, b), + Relation::Ge(a, b) => write!(f, "{} >= {}", a, b), + Relation::Lt(a, b) => write!(f, "{} < {}", a, b), + Relation::Le(a, b) => write!(f, "{} <= {}", a, b), + Relation::Contains(a, b) => write!(f, "{} contains {}", a, b), + Relation::IsType(a, b) => write!(f, "{} is type {:?}", a, b), + } + } +} + +impl Evaluate for Relation { + type State<'a> = (); + + fn evaluate(&self, _state: &()) -> ConditionResult { + match self { + Relation::Eq(a, b) => (a == b).into(), + Relation::Ne(a, b) => (a != b).into(), + Relation::Gt(a, b) => (a > b).into(), + Relation::Ge(a, b) => (a >= b).into(), + Relation::Lt(a, b) => (a < b).into(), + Relation::Le(a, b) => (a <= b).into(), + Relation::Contains(a, b) => (a.contains(b)).into(), + Relation::IsType(a, b) => match a { + Atom::Var(_) => (*b == AtomType::Var).into(), + Atom::Fun(_) => (*b == AtomType::Fun).into(), + Atom::Num(_) => (*b == AtomType::Num).into(), + Atom::Add(_) => (*b == AtomType::Add).into(), + Atom::Mul(_) => (*b == AtomType::Mul).into(), + Atom::Pow(_) => (*b == AtomType::Pow).into(), + Atom::Zero => (*b == AtomType::Num).into(), + }, + } + } +} + +/// A restriction on wildcards. +#[pyclass(name = "Condition", module = "symbolica")] +#[derive(Clone)] +pub struct PythonCondition { + pub condition: Condition, +} + +impl From> for PythonCondition { + fn from(condition: Condition) -> Self { + PythonCondition { condition } + } +} + +#[pymethods] +impl PythonCondition { + pub fn __repr__(&self) -> String { + format!("{:?}", self.condition) + } + + pub fn __str__(&self) -> String { + format!("{}", self.condition) + } + + pub fn eval(&self) -> bool { + self.condition.evaluate(&()) == ConditionResult::True + } + + pub fn __bool__(&self) -> bool { + self.eval() + } + + /// Create a new pattern restriction that is the logical 'and' operation between two restrictions (i.e., both should hold). + pub fn __and__(&self, other: Self) -> PythonCondition { + (self.condition.clone() & other.condition.clone()).into() + } + + /// Create a new pattern restriction that is the logical 'or' operation between two restrictions (i.e., one of the two should hold). + pub fn __or__(&self, other: Self) -> PythonCondition { + (self.condition.clone() | other.condition.clone()).into() + } + + /// Create a new pattern restriction that takes the logical 'not' of the current restriction. + pub fn __invert__(&self) -> PythonCondition { + (!self.condition.clone()).into() + } + + /// Convert the condition to a pattern restriction. + pub fn to_req(&self) -> PyResult { + self.condition + .clone() + .try_into() + .map(|e| PythonPatternRestriction { condition: e }) + .map_err(|e| exceptions::PyValueError::new_err(e)) + } +} + +macro_rules! req_cmp_rel { + ($self:ident,$num:ident,$cmp_any_atom:ident,$c:ident) => {{ + if !$cmp_any_atom && !matches!($num.as_view(), AtomView::Num(_)) { + return Err("Can only compare to number"); + }; + + match $self.as_view() { + AtomView::Var(v) => { + let name = v.get_symbol(); + if v.get_wildcard_level() == 0 { + return Err("Only wildcards can be restricted."); + } + + Ok(PatternRestriction::Wildcard(( + name, + WildcardRestriction::Filter(Box::new(move |v: &Match| { + let k = $num.as_view(); + + if let Match::Single(m) = v { + if !$cmp_any_atom { + if let AtomView::Num(_) = m { + return m.cmp(&k).$c(); + } + } else { + return m.cmp(&k).$c(); + } + } + + false + })), + ))) + } + _ => Err("Only wildcards can be restricted."), + } + }}; +} + +impl TryFrom for PatternRestriction { + type Error = &'static str; + + fn try_from(value: Relation) -> Result { + match value { + Relation::Eq(atom, atom1) => { + return req_cmp_rel!(atom, atom1, true, is_eq); + } + Relation::Ne(atom, atom1) => { + return req_cmp_rel!(atom, atom1, true, is_ne); + } + Relation::Gt(atom, atom1) => { + return req_cmp_rel!(atom, atom1, true, is_gt); + } + Relation::Ge(atom, atom1) => { + return req_cmp_rel!(atom, atom1, true, is_ge); + } + Relation::Lt(atom, atom1) => { + return req_cmp_rel!(atom, atom1, true, is_lt); + } + Relation::Le(atom, atom1) => { + return req_cmp_rel!(atom, atom1, true, is_le); + } + Relation::Contains(atom, atom1) => { + if let Atom::Var(v) = atom { + let name = v.get_symbol(); + if name.get_wildcard_level() == 0 { + return Err("Only wildcards can be restricted."); + } + + Ok(PatternRestriction::Wildcard(( + name, + WildcardRestriction::Filter(Box::new(move |m| match m { + Match::Single(v) => v.contains(atom1.as_view()), + Match::Multiple(_, v) => v.iter().any(|x| x.contains(atom1.as_view())), + Match::FunctionName(_) => false, + })), + ))) + } else { + Err("LHS must be wildcard") + } + } + Relation::IsType(atom, atom_type) => { + if let Atom::Var(v) = atom { + Ok(PatternRestriction::Wildcard(( + v.get_symbol(), + WildcardRestriction::IsAtomType(atom_type), + ))) + } else { + Err("LHS must be wildcard") + } + } + } + } +} + +impl TryFrom> for Condition { + type Error = &'static str; + + fn try_from(value: Condition) -> Result { + Ok(match value { + Condition::True => Condition::True, + Condition::False => Condition::False, + Condition::Yield(r) => Condition::Yield(r.try_into()?), + Condition::And(a) => Condition::And(Box::new((a.0.try_into()?, a.1.try_into()?))), + Condition::Or(a) => Condition::Or(Box::new((a.0.try_into()?, a.1.try_into()?))), + Condition::Not(a) => Condition::Not(Box::new((*a).try_into()?)), + }) + } +} + +pub struct ConvertibleToPatternRestriction(Condition); + +impl<'a> FromPyObject<'a> for ConvertibleToPatternRestriction { + fn extract_bound(ob: &Bound<'a, pyo3::PyAny>) -> PyResult { + if let Ok(a) = ob.extract::() { + Ok(ConvertibleToPatternRestriction(a.condition)) + } else if let Ok(a) = ob.extract::() { + Ok(ConvertibleToPatternRestriction( + a.condition + .try_into() + .map_err(|e| exceptions::PyValueError::new_err(e))?, + )) + } else { + Err(exceptions::PyTypeError::new_err( + "Cannot convert to pattern restriction", + )) + } + } +} + impl<'a> FromPyObject<'a> for ConvertibleToExpression { fn extract_bound(ob: &Bound<'a, pyo3::PyAny>) -> PyResult { if let Ok(a) = ob.extract::() { @@ -1723,13 +1959,13 @@ impl<'a> FromPyObject<'a> for ConvertibleToExpression { Ok(ConvertibleToExpression(Atom::new_num(i).into())) } else if let Ok(_) = ob.extract::() { // disallow direct string conversion - Err(exceptions::PyValueError::new_err( + Err(exceptions::PyTypeError::new_err( "Cannot convert to expression", )) } else if let Ok(f) = ob.extract::() { Ok(ConvertibleToExpression(Atom::new_num(f.0).into())) } else { - Err(exceptions::PyValueError::new_err( + Err(exceptions::PyTypeError::new_err( "Cannot convert to expression", )) } @@ -1741,13 +1977,13 @@ impl<'a> FromPyObject<'a> for Symbol { if let Ok(a) = ob.extract::() { match a.expr.as_view() { AtomView::Var(v) => Ok(v.get_symbol()), - e => Err(exceptions::PyValueError::new_err(format!( + e => Err(exceptions::PyTypeError::new_err(format!( "Expected variable instead of {}", e ))), } } else { - Err(exceptions::PyValueError::new_err("Not a valid variable")) + Err(exceptions::PyTypeError::new_err("Not a valid variable")) } } } @@ -2088,6 +2324,10 @@ impl PythonExpression { Err(exceptions::PyValueError::new_err( "Illegal character in name", )) + } else if name.chars().next().unwrap().is_numeric() { + Err(exceptions::PyValueError::new_err( + "Name cannot start with a number", + )) } else { Ok(name) } @@ -2813,8 +3053,13 @@ impl PythonExpression { /// >>> e.contains(x) # True /// >>> e.contains(x*y*z) # True /// >>> e.contains(x*y) # False - pub fn contains(&self, s: ConvertibleToExpression) -> bool { - self.expr.contains(s.to_expression().expr.as_view()) + pub fn contains(&self, s: ConvertibleToExpression) -> PythonCondition { + PythonCondition { + condition: Condition::Yield(Relation::Contains( + self.expr.clone(), + s.to_expression().expr, + )), + } } /// Get all symbols in the current expression, optionally including function symbols. @@ -2935,6 +3180,35 @@ impl PythonExpression { } } + /// Create a pattern restriction that filters for expressions that contain `a`. + pub fn req_contains(&self, a: PythonExpression) -> PyResult { + match self.expr.as_view() { + AtomView::Var(v) => { + let name = v.get_symbol(); + if v.get_wildcard_level() == 0 { + return Err(exceptions::PyTypeError::new_err( + "Only wildcards can be restricted.", + )); + } + + Ok(PythonPatternRestriction { + condition: ( + name, + WildcardRestriction::Filter(Box::new(move |m| match m { + Match::Single(v) => v.contains(a.expr.as_view()), + Match::Multiple(_, v) => v.iter().any(|x| x.contains(a.expr.as_view())), + Match::FunctionName(_) => false, + })), + ) + .into(), + }) + } + _ => Err(exceptions::PyTypeError::new_err( + "Only wildcards can be restricted.", + )), + } + } + /// Create a pattern restriction that treats the wildcard as a literal variable, /// so that it only matches to itself. pub fn req_lit(&self) -> PyResult { @@ -2957,41 +3231,45 @@ impl PythonExpression { } } - /// Compare two expressions. - fn __richcmp__(&self, other: ConvertibleToExpression, op: CompareOp) -> PyResult { - match op { - CompareOp::Eq => Ok(self.expr == other.to_expression().expr), - CompareOp::Ne => Ok(self.expr != other.to_expression().expr), - _ => { - let other = other.to_expression(); - if let n1 @ AtomView::Num(_) = self.expr.as_view() { - if let n2 @ AtomView::Num(_) = other.expr.as_view() { - return Ok(match op { - CompareOp::Eq => n1 == n2, - CompareOp::Ge => n1 >= n2, - CompareOp::Gt => n1 > n2, - CompareOp::Le => n1 <= n2, - CompareOp::Lt => n1 < n2, - CompareOp::Ne => n1 != n2, - }); + /// Test if the expression is of a certain type. + pub fn is_type(&self, atom_type: PythonAtomType) -> PythonCondition { + PythonCondition { + condition: Condition::Yield(Relation::IsType( + self.expr.clone(), + match atom_type { + PythonAtomType::Num => AtomType::Num, + PythonAtomType::Var => AtomType::Var, + PythonAtomType::Add => AtomType::Add, + PythonAtomType::Mul => AtomType::Mul, + PythonAtomType::Pow => AtomType::Pow, + PythonAtomType::Fn => AtomType::Fun, + }, + )), } } - Err(exceptions::PyTypeError::new_err(format!( - "Inequalities between expression that are not numbers are not allowed in {} {} {}", - self.__str__()?, + /// Compare two expressions. If one of the expressions is not a number, an + /// internal ordering will be used. + fn __richcmp__(&self, other: ConvertibleToExpression, op: CompareOp) -> PythonCondition { match op { - CompareOp::Eq => "==", - CompareOp::Ge => ">=", - CompareOp::Gt => ">", - CompareOp::Le => "<=", - CompareOp::Lt => "<", - CompareOp::Ne => "!=", - }, - other.__str__()?, - ) - )) - } + CompareOp::Eq => PythonCondition { + condition: Relation::Eq(self.expr.clone(), other.to_expression().expr).into(), + }, + CompareOp::Ne => PythonCondition { + condition: Relation::Ne(self.expr.clone(), other.to_expression().expr).into(), + }, + CompareOp::Ge => PythonCondition { + condition: Relation::Ge(self.expr.clone(), other.to_expression().expr).into(), + }, + CompareOp::Gt => PythonCondition { + condition: Relation::Gt(self.expr.clone(), other.to_expression().expr).into(), + }, + CompareOp::Le => PythonCondition { + condition: Relation::Le(self.expr.clone(), other.to_expression().expr).into(), + }, + CompareOp::Lt => PythonCondition { + condition: Relation::Lt(self.expr.clone(), other.to_expression().expr).into(), + }, } } @@ -3976,14 +4254,12 @@ impl PythonExpression { pub fn pattern_match( &self, lhs: ConvertibleToPattern, - cond: Option, + cond: Option, level_range: Option<(usize, Option)>, level_is_tree_depth: Option, allow_new_wildcards_on_rhs: Option, ) -> PyResult { - let conditions = cond - .map(|r| r.condition.clone()) - .unwrap_or(Condition::default()); + let conditions = cond.map(|r| r.0).unwrap_or(Condition::default()); let settings = MatchSettings { level_range: level_range.unwrap_or((0, None)), level_is_tree_depth: level_is_tree_depth.unwrap_or(false), @@ -4016,15 +4292,13 @@ impl PythonExpression { pub fn matches( &self, lhs: ConvertibleToPattern, - cond: Option, + cond: Option, level_range: Option<(usize, Option)>, level_is_tree_depth: Option, allow_new_wildcards_on_rhs: Option, ) -> PyResult { let pat = lhs.to_pattern()?.expr; - let conditions = cond - .map(|r| r.condition.clone()) - .unwrap_or(Condition::default()); + let conditions = cond.map(|r| r.0).unwrap_or(Condition::default()); let settings = MatchSettings { level_range: level_range.unwrap_or((0, None)), level_is_tree_depth: level_is_tree_depth.unwrap_or(false), @@ -4067,14 +4341,12 @@ impl PythonExpression { &self, lhs: ConvertibleToPattern, rhs: ConvertibleToPatternOrMap, - cond: Option, + cond: Option, level_range: Option<(usize, Option)>, level_is_tree_depth: Option, allow_new_wildcards_on_rhs: Option, ) -> PyResult { - let conditions = cond - .map(|r| r.condition.clone()) - .unwrap_or(Condition::default()); + let conditions = cond.map(|r| r.0.clone()).unwrap_or(Condition::default()); let settings = MatchSettings { level_range: level_range.unwrap_or((0, None)), level_is_tree_depth: level_is_tree_depth.unwrap_or(false), @@ -4138,7 +4410,7 @@ impl PythonExpression { &self, pattern: ConvertibleToPattern, rhs: ConvertibleToPatternOrMap, - cond: Option, + cond: Option, non_greedy_wildcards: Option>, level_range: Option<(usize, Option)>, level_is_tree_depth: Option, @@ -4185,15 +4457,11 @@ impl PythonExpression { let mut expr_ref = self.expr.as_view(); + let cond = cond.map(|r| r.0); + let mut out = RecycledAtom::new(); let mut out2 = RecycledAtom::new(); - while pattern.replace_all_into( - expr_ref, - rhs, - cond.as_ref().map(|r| &r.condition), - Some(&settings), - &mut out, - ) { + while pattern.replace_all_into(expr_ref, rhs, cond.as_ref(), Some(&settings), &mut out) { if !repeat.unwrap_or(false) { break; } @@ -4881,7 +5149,7 @@ impl PythonReplacement { pub fn new( pattern: ConvertibleToPattern, rhs: ConvertibleToPatternOrMap, - cond: Option, + cond: Option, non_greedy_wildcards: Option>, level_range: Option<(usize, Option)>, level_is_tree_depth: Option, @@ -4925,10 +5193,7 @@ impl PythonReplacement { settings.rhs_cache_size = rhs_cache_size; } - let cond = cond - .as_ref() - .map(|r| r.condition.clone()) - .unwrap_or(Condition::default()); + let cond = cond.map(|r| r.0).unwrap_or(Condition::default()); Ok(Self { pattern, diff --git a/src/atom.rs b/src/atom.rs index 8e892508..ddd45158 100644 --- a/src/atom.rs +++ b/src/atom.rs @@ -100,7 +100,7 @@ impl Symbol { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum AtomType { Num, Var, diff --git a/src/id.rs b/src/id.rs index 9716c8f3..41fddf1f 100644 --- a/src/id.rs +++ b/src/id.rs @@ -1345,6 +1345,41 @@ pub enum Condition { False, } +impl std::fmt::Display for Condition { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Condition::And(a) => write!(f, "({}) & ({})", a.0, a.1), + Condition::Or(o) => write!(f, "{} | {}", o.0, o.1), + Condition::Not(n) => write!(f, "!({})", n), + Condition::True => write!(f, "True"), + Condition::False => write!(f, "False"), + Condition::Yield(t) => write!(f, "{}", t), + } + } +} + +pub trait Evaluate { + type State<'a>; + + /// Evaluate a condition. + fn evaluate<'a>(&self, state: &Self::State<'a>) -> ConditionResult; +} + +impl Evaluate for Condition { + type State<'a> = T::State<'a>; + + fn evaluate(&self, state: &T::State<'_>) -> ConditionResult { + match self { + Condition::And(a) => a.0.evaluate(state) & a.1.evaluate(state), + Condition::Or(o) => o.0.evaluate(state) | o.1.evaluate(state), + Condition::Not(n) => !n.evaluate(state), + Condition::True => ConditionResult::True, + Condition::False => ConditionResult::False, + Condition::Yield(t) => t.evaluate(state), + } + } +} + impl From for Condition { fn from(value: T) -> Self { Condition::Yield(value) @@ -1430,6 +1465,64 @@ impl From for ConditionResult { } } +impl Evaluate for Condition { + type State<'a> = MatchStack<'a, 'a>; + + fn evaluate(&self, state: &MatchStack) -> ConditionResult { + match self { + Condition::And(a) => a.0.evaluate(state) & a.1.evaluate(state), + Condition::Or(o) => o.0.evaluate(state) | o.1.evaluate(state), + Condition::Not(n) => !n.evaluate(state), + Condition::True => ConditionResult::True, + Condition::False => ConditionResult::False, + Condition::Yield(t) => match t { + PatternRestriction::Wildcard((v, r)) => { + if let Some((_, value)) = state.stack.iter().find(|(k, _)| k == v) { + match r { + WildcardRestriction::IsAtomType(t) => match value { + Match::Single(AtomView::Num(_)) => *t == AtomType::Num, + Match::Single(AtomView::Var(_)) => *t == AtomType::Var, + Match::Single(AtomView::Add(_)) => *t == AtomType::Add, + Match::Single(AtomView::Mul(_)) => *t == AtomType::Mul, + Match::Single(AtomView::Pow(_)) => *t == AtomType::Pow, + Match::Single(AtomView::Fun(_)) => *t == AtomType::Fun, + _ => false, + }, + WildcardRestriction::IsLiteralWildcard(wc) => match value { + Match::Single(AtomView::Var(v)) => wc == &v.get_symbol(), + _ => false, + }, + WildcardRestriction::Length(min, max) => match value { + Match::Single(_) | Match::FunctionName(_) => { + *min <= 1 && max.map(|m| m >= 1).unwrap_or(true) + } + Match::Multiple(_, slice) => { + *min <= slice.len() + && max.map(|m| m >= slice.len()).unwrap_or(true) + } + }, + WildcardRestriction::Filter(f) => f(value), + WildcardRestriction::Cmp(v2, f) => { + if let Some((_, value2)) = state.stack.iter().find(|(k, _)| k == v2) + { + f(value, value2) + } else { + return ConditionResult::Inconclusive; + } + } + WildcardRestriction::NotGreedy => true, + } + .into() + } else { + ConditionResult::Inconclusive + } + } + PatternRestriction::MatchStack(mf) => mf(state), + }, + } + } +} + impl Condition { /// Check if the conditions on `var` are met fn check_possible(&self, var: Symbol, value: &Match, stack: &MatchStack) -> ConditionResult { diff --git a/symbolica.pyi b/symbolica.pyi index adb76118..43ce8da2 100644 --- a/symbolica.pyi +++ b/symbolica.pyi @@ -517,7 +517,7 @@ class Expression: transformations can be applied. """ - def contains(self, a: Expression | int | float | Decimal) -> bool: + def contains(self, a: Expression | int | float | Decimal) -> Condition: """Returns true iff `self` contains `a` literally. Examples @@ -569,6 +569,16 @@ class Expression: Yields `f(x)*f(1)`. """ + def is_type(self, atom_type: AtomType) -> Condition: + """ + Test if the expression is of a certain type. + """ + + def req_contains(self, a: Expression) -> PatternRestriction: + """ + Create a pattern restriction that filters for expressions that contain `a`. + """ + def req_lit(self) -> PatternRestriction: """ Create a pattern restriction that treats the wildcard as a literal variable, @@ -746,34 +756,34 @@ class Expression: >>> e = e.replace_all(f(x_,y_), 1, x_.req_cmp_ge(y_)) """ - def __eq__(self, other: Expression | int | float | Decimal) -> bool: + def __eq__(self, other: Expression | int | float | Decimal) -> Condition: """ Compare two expressions. """ - def __neq__(self, other: Expression | int | float | Decimal) -> bool: + def __neq__(self, other: Expression | int | float | Decimal) -> Condition: """ Compare two expressions. """ - def __lt__(self, other: Expression | int | float | Decimal) -> bool: + def __lt__(self, other: Expression | int | float | Decimal) -> Condition: """ - Compare two expressions. Both expressions must be a number. + Compare two expressions. If any of the two expressions is not a rational number, an interal ordering is used. """ - def __le__(self, other: Expression | int | float | Decimal) -> bool: + def __le__(self, other: Expression | int | float | Decimal) -> Condition: """ - Compare two expressions. Both expressions must be a number. + Compare two expressions. If any of the two expressions is not a rational number, an interal ordering is used. """ - def __gt__(self, other: Expression | int | float | Decimal) -> bool: + def __gt__(self, other: Expression | int | float | Decimal) -> Condition: """ - Compare two expressions. Both expressions must be a number. + Compare two expressions. If any of the two expressions is not a rational number, an interal ordering is used. """ - def __ge__(self, other: Expression | int | float | Decimal) -> bool: + def __ge__(self, other: Expression | int | float | Decimal) -> Condition: """ - Compare two expressions. Both expressions must be a number. + Compare two expressions. If any of the two expressions is not a rational number, an interal ordering is used. """ def __iter__(self) -> Iterator[Expression]: @@ -1056,7 +1066,7 @@ class Expression: def match( self, lhs: Transformer | Expression | int | float | Decimal, - cond: Optional[PatternRestriction] = None, + cond: Optional[PatternRestriction | Condition] = None, level_range: Optional[Tuple[int, Optional[int]]] = None, level_is_tree_depth: Optional[bool] = False, allow_new_wildcards_on_rhs: Optional[bool] = False, @@ -1083,7 +1093,7 @@ class Expression: def matches( self, lhs: Transformer | Expression | int | float | Decimal, - cond: Optional[PatternRestriction] = None, + cond: Optional[PatternRestriction | Condition] = None, level_range: Optional[Tuple[int, Optional[int]]] = None, level_is_tree_depth: Optional[bool] = False, allow_new_wildcards_on_rhs: Optional[bool] = False, @@ -1104,7 +1114,7 @@ class Expression: self, lhs: Transformer | Expression | int | float | Decimal, rhs: Transformer | Expression | Callable[[dict[Expression, Expression]], Expression] | int | float | Decimal, - cond: Optional[PatternRestriction] = None, + cond: Optional[PatternRestriction | Condition] = None, level_range: Optional[Tuple[int, Optional[int]]] = None, level_is_tree_depth: Optional[bool] = False, allow_new_wildcards_on_rhs: Optional[bool] = False, @@ -1150,7 +1160,7 @@ class Expression: self, pattern: Transformer | Expression | int | float | Decimal, rhs: Transformer | Expression | Callable[[dict[Expression, Expression]], Expression] | int | float | Decimal, - cond: Optional[PatternRestriction] = None, + cond: Optional[PatternRestriction | Condition] = None, non_greedy_wildcards: Optional[Sequence[Expression]] = None, level_range: Optional[Tuple[int, Optional[int]]] = None, level_is_tree_depth: Optional[bool] = False, @@ -1167,7 +1177,7 @@ class Expression: >>> x, w1_, w2_ = Expression.symbol('x','w1_','w2_') >>> f = Expression.symbol('f') >>> e = f(3,x) - >>> r = e.replace_all(f(w1_,w2_), f(w1_ - 1, w2_**2), (w1_ >= 1) & w2_.is_var()) + >>> r = e.replace_all(f(w1_,w2_), f(w1_ - 1, w2_**2), w1_ >= 1) >>> print(r) Parameters @@ -1464,7 +1474,7 @@ class Replacement: cls, pattern: Transformer | Expression | int | float | Decimal, rhs: Transformer | Expression | Callable[[dict[Expression, Expression]], Expression] | int | float | Decimal, - cond: Optional[PatternRestriction] = None, + cond: Optional[PatternRestriction | Condition] = None, non_greedy_wildcards: Optional[Sequence[Expression]] = None, level_range: Optional[Tuple[int, Optional[int]]] = None, level_is_tree_depth: Optional[bool] = False, @@ -1518,6 +1528,34 @@ class PatternRestriction: """ +class Condition: + """Relations that evaluate to booleans""" + + def eval(self) -> bool: + """Evaluate the condition.""" + + def __repr__(self) -> str: + """Return a string representation of the condition.""" + + def __str__(self) -> str: + """Return a string representation of the condition.""" + + def __bool__(self) -> bool: + """Return the boolean value of the condition.""" + + def __and__(self, other: Condition) -> Condition: + """Create a condition that is the logical and operation between two conditions (i.e., both should hold).""" + + def __or__(self, other: Condition) -> Condition: + """Create a condition that is the logical 'or' operation between two conditions (i.e., at least one of the two should hold).""" + + def __invert__(self) -> Condition: + """Create a condition that takes the logical 'not' of the current condition.""" + + def to_req(self) -> PatternRestriction: + """Convert the condition to a pattern restriction.""" + + class CompareOp: """One of the following comparison operators: `<`,`>`,`<=`,`>=`,`==`,`!=`.""" @@ -1955,7 +1993,7 @@ class Transformer: self, pat: Transformer | Expression | int | float | Decimal, rhs: Transformer | Expression | Callable[[dict[Expression, Expression]], Expression] | int | float | Decimal, - cond: Optional[PatternRestriction] = None, + cond: Optional[PatternRestriction | Condition] = None, non_greedy_wildcards: Optional[Sequence[Expression]] = None, level_range: Optional[Tuple[int, Optional[int]]] = None, level_is_tree_depth: Optional[bool] = False, @@ -1970,7 +2008,7 @@ class Transformer: >>> x, w1_, w2_ = Expression.symbol('x','w1_','w2_') >>> f = Expression.symbol('f') >>> e = f(3,x) - >>> r = e.transform().replace_all(f(w1_,w2_), f(w1_ - 1, w2_**2), (w1_ >= 1) & w2_.is_var()) + >>> r = e.transform().replace_all(f(w1_,w2_), f(w1_ - 1, w2_**2), w1_ >= 1) >>> print(r) Parameters From 88a9cbd132311a2b0f266d31feb12cdfb58525df Mon Sep 17 00:00:00 2001 From: Ben Ruijl Date: Tue, 3 Dec 2024 12:12:00 +0100 Subject: [PATCH 5/6] Add control flow transformers - Add if_then and if_changed - Add break_chain - Add condition functions for transformers --- src/api/python.rs | 340 +++++++++++++++++++++++++++++++-------------- src/id.rs | 201 +++++++++++++++++++++++---- src/transformer.rs | 58 ++++++-- symbolica.pyi | 79 +++++++++++ 4 files changed, 537 insertions(+), 141 deletions(-) diff --git a/src/api/python.rs b/src/api/python.rs index 93107d5c..4050a827 100644 --- a/src/api/python.rs +++ b/src/api/python.rs @@ -53,8 +53,8 @@ use crate::{ graph::Graph, id::{ Condition, ConditionResult, Evaluate, Match, MatchSettings, MatchStack, Pattern, - PatternAtomTreeIterator, PatternOrMap, PatternRestriction, ReplaceIterator, Replacement, - WildcardRestriction, + PatternAtomTreeIterator, PatternOrMap, PatternRestriction, Relation, ReplaceIterator, + Replacement, WildcardRestriction, }, numerical_integration::{ContinuousGrid, DiscreteGrid, Grid, MonteCarloRng, Sample}, parser::Token, @@ -443,6 +443,50 @@ impl PythonTransformer { } } + /// Compare two expressions. If one of the expressions is not a number, an + /// internal ordering will be used. + fn __richcmp__(&self, other: ConvertibleToPattern, op: CompareOp) -> PyResult { + Ok(match op { + CompareOp::Eq => PythonCondition { + condition: Relation::Eq(self.expr.clone(), other.to_pattern()?.expr).into(), + }, + CompareOp::Ne => PythonCondition { + condition: Relation::Ne(self.expr.clone(), other.to_pattern()?.expr).into(), + }, + CompareOp::Ge => PythonCondition { + condition: Relation::Ge(self.expr.clone(), other.to_pattern()?.expr).into(), + }, + CompareOp::Gt => PythonCondition { + condition: Relation::Gt(self.expr.clone(), other.to_pattern()?.expr).into(), + }, + CompareOp::Le => PythonCondition { + condition: Relation::Le(self.expr.clone(), other.to_pattern()?.expr).into(), + }, + CompareOp::Lt => PythonCondition { + condition: Relation::Lt(self.expr.clone(), other.to_pattern()?.expr).into(), + }, + }) + } + + /// Returns true iff `self` contains `a` literally. + /// + /// Examples + /// -------- + /// >>> from symbolica import * + /// >>> x, y, z = Expression.symbol('x', 'y', 'z') + /// >>> e = x * y * z + /// >>> e.contains(x) # True + /// >>> e.contains(x*y*z) # True + /// >>> e.contains(x*y) # False + pub fn contains(&self, s: ConvertibleToPattern) -> PyResult { + Ok(PythonCondition { + condition: Condition::Yield(Relation::Contains( + self.expr.clone(), + s.to_pattern()?.expr, + )), + }) + } + /// Create a transformer that expands products and powers. /// /// Examples @@ -917,6 +961,119 @@ impl PythonTransformer { return append_transformer!(self, Transformer::Repeat(rep_chain)); } + /// Evaluate the condition and apply the `if_block` if the condition is true, otherwise apply the `else_block`. + /// The expression that is the input of the transformer is the input for the condition, the `if_block` and the `else_block`. + /// + /// Examples + /// -------- + /// >>> t = T.map_terms(T.if_then(T.contains(x), T.print())) + /// >>> t(x + y + 4) + /// + /// prints `x`. + #[pyo3(signature = (condition, if_block, else_block = None))] + pub fn if_then( + &self, + condition: PythonCondition, + if_block: PythonTransformer, + else_block: Option, + ) -> PyResult { + let Pattern::Transformer(t1) = if_block.expr else { + return Err(exceptions::PyValueError::new_err( + "Argument must be a transformer", + )); + }; + + let t2 = if let Some(e) = else_block { + if let Pattern::Transformer(t2) = e.expr { + t2 + } else { + return Err(exceptions::PyValueError::new_err( + "Argument must be a transformer", + )); + } + } else { + Box::new((None, vec![])) + }; + + if t1.0.is_some() || t2.0.is_some() { + return Err(exceptions::PyValueError::new_err( + "Transformers in a repeat must be unbound. Use Transformer() to create it.", + )); + } + + return append_transformer!(self, Transformer::IfElse(condition.condition, t1.1, t2.1)); + } + + /// Execute the `condition` transformer. If the result of the `condition` transformer is different from the input expression, + /// apply the `if_block`, otherwise apply the `else_block`. The input expression of the `if_block` is the output + /// of the `condition` transformer. + /// + /// Examples + /// -------- + /// >>> t = T.map_terms(T.if_changed(T.replace_all(x, y), T.print())) + /// >>> print(t(x + y + 4)) + /// + /// prints + /// ```log + /// y + /// 2*y+4 + /// ``` + #[pyo3(signature = (condition, if_block, else_block = None))] + pub fn if_changed( + &self, + condition: PythonTransformer, + if_block: PythonTransformer, + else_block: Option, + ) -> PyResult { + let Pattern::Transformer(t0) = condition.expr else { + return Err(exceptions::PyValueError::new_err( + "Argument must be a transformer", + )); + }; + + let Pattern::Transformer(t1) = if_block.expr else { + return Err(exceptions::PyValueError::new_err( + "Argument must be a transformer", + )); + }; + + let t2 = if let Some(e) = else_block { + if let Pattern::Transformer(t2) = e.expr { + t2 + } else { + return Err(exceptions::PyValueError::new_err( + "Argument must be a transformer", + )); + } + } else { + Box::new((None, vec![])) + }; + + if t0.0.is_some() || t1.0.is_some() || t2.0.is_some() { + return Err(exceptions::PyValueError::new_err( + "Transformers in a repeat must be unbound. Use Transformer() to create it.", + )); + } + + return append_transformer!(self, Transformer::IfChanged(t0.1, t1.1, t2.1)); + } + + /// Break the current chain and all higher-level chains containing `if` transformers. + /// + /// Examples + /// -------- + /// >>> from symbolica import * + /// >>> t = T.map_terms(T.repeat( + /// >>> T.replace_all(y, 4), + /// >>> T.if_changed(T.replace_all(x, y), + /// >>> T.break_chain()), + /// >>> T.print() # print of y is never reached + /// >>> )) + /// >>> print(t(x)) + pub fn break_chain(&self) -> PyResult { + return append_transformer!(self, Transformer::BreakChain); + } + /// Chain several transformers. `chain(A,B,C)` is the same as `A.B.C`, /// where `A`, `B`, `C` are transformers. /// @@ -981,6 +1138,7 @@ impl PythonTransformer { workspace, &mut out, &MatchStack::new(&Condition::default(), &MatchSettings::default()), + None, ) }) .map_err(|e| match e { @@ -1138,7 +1296,7 @@ impl PythonTransformer { /// /// yields /// - /// ``` + /// ```log /// -6*(x-2*y)*(x+y) /// ``` pub fn collect_num(&self) -> PyResult { @@ -1711,58 +1869,6 @@ impl PythonPatternRestriction { } } -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub enum Relation { - Eq(Atom, Atom), - Ne(Atom, Atom), - Gt(Atom, Atom), - Ge(Atom, Atom), - Lt(Atom, Atom), - Le(Atom, Atom), - Contains(Atom, Atom), - IsType(Atom, AtomType), -} - -impl std::fmt::Display for Relation { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Relation::Eq(a, b) => write!(f, "{} == {}", a, b), - Relation::Ne(a, b) => write!(f, "{} != {}", a, b), - Relation::Gt(a, b) => write!(f, "{} > {}", a, b), - Relation::Ge(a, b) => write!(f, "{} >= {}", a, b), - Relation::Lt(a, b) => write!(f, "{} < {}", a, b), - Relation::Le(a, b) => write!(f, "{} <= {}", a, b), - Relation::Contains(a, b) => write!(f, "{} contains {}", a, b), - Relation::IsType(a, b) => write!(f, "{} is type {:?}", a, b), - } - } -} - -impl Evaluate for Relation { - type State<'a> = (); - - fn evaluate(&self, _state: &()) -> ConditionResult { - match self { - Relation::Eq(a, b) => (a == b).into(), - Relation::Ne(a, b) => (a != b).into(), - Relation::Gt(a, b) => (a > b).into(), - Relation::Ge(a, b) => (a >= b).into(), - Relation::Lt(a, b) => (a < b).into(), - Relation::Le(a, b) => (a <= b).into(), - Relation::Contains(a, b) => (a.contains(b)).into(), - Relation::IsType(a, b) => match a { - Atom::Var(_) => (*b == AtomType::Var).into(), - Atom::Fun(_) => (*b == AtomType::Fun).into(), - Atom::Num(_) => (*b == AtomType::Num).into(), - Atom::Add(_) => (*b == AtomType::Add).into(), - Atom::Mul(_) => (*b == AtomType::Mul).into(), - Atom::Pow(_) => (*b == AtomType::Pow).into(), - Atom::Zero => (*b == AtomType::Num).into(), - }, - } - } -} - /// A restriction on wildcards. #[pyclass(name = "Condition", module = "symbolica")] #[derive(Clone)] @@ -1786,11 +1892,15 @@ impl PythonCondition { format!("{}", self.condition) } - pub fn eval(&self) -> bool { - self.condition.evaluate(&()) == ConditionResult::True + pub fn eval(&self) -> PyResult { + Ok(self + .condition + .evaluate(&None) + .map_err(|e| exceptions::PyValueError::new_err(e))? + == ConditionResult::True) } - pub fn __bool__(&self) -> bool { + pub fn __bool__(&self) -> PyResult { self.eval() } @@ -1821,37 +1931,45 @@ impl PythonCondition { macro_rules! req_cmp_rel { ($self:ident,$num:ident,$cmp_any_atom:ident,$c:ident) => {{ - if !$cmp_any_atom && !matches!($num.as_view(), AtomView::Num(_)) { - return Err("Can only compare to number"); - }; - - match $self.as_view() { - AtomView::Var(v) => { - let name = v.get_symbol(); - if v.get_wildcard_level() == 0 { - return Err("Only wildcards can be restricted."); + let num = if !$cmp_any_atom { + if let Pattern::Literal(a) = $num { + if let AtomView::Num(_) = a.as_view() { + a + } else { + return Err("Can only compare to number"); } + } else { + return Err("Can only compare to number"); + } + } else if let Pattern::Literal(a) = $num { + a + } else { + return Err("Pattern must be literal"); + }; - Ok(PatternRestriction::Wildcard(( - name, - WildcardRestriction::Filter(Box::new(move |v: &Match| { - let k = $num.as_view(); + if let Pattern::Wildcard(name) = $self { + if name.get_wildcard_level() == 0 { + return Err("Only wildcards can be restricted."); + } - if let Match::Single(m) = v { - if !$cmp_any_atom { - if let AtomView::Num(_) = m { - return m.cmp(&k).$c(); - } - } else { - return m.cmp(&k).$c(); + Ok(PatternRestriction::Wildcard(( + name, + WildcardRestriction::Filter(Box::new(move |v: &Match| { + if let Match::Single(m) = v { + if !$cmp_any_atom { + if let AtomView::Num(_) = m { + return m.cmp(&num.as_view()).$c(); } + } else { + return m.cmp(&num.as_view()).$c(); } + } - false - })), - ))) - } - _ => Err("Only wildcards can be restricted."), + false + })), + ))) + } else { + Err("Only wildcards can be restricted.") } }}; } @@ -1880,18 +1998,28 @@ impl TryFrom for PatternRestriction { return req_cmp_rel!(atom, atom1, true, is_le); } Relation::Contains(atom, atom1) => { - if let Atom::Var(v) = atom { - let name = v.get_symbol(); + if let Pattern::Wildcard(name) = atom { if name.get_wildcard_level() == 0 { return Err("Only wildcards can be restricted."); } + if !matches!(&atom1, &Pattern::Literal(_)) { + return Err("Pattern must be literal"); + } + Ok(PatternRestriction::Wildcard(( name, - WildcardRestriction::Filter(Box::new(move |m| match m { - Match::Single(v) => v.contains(atom1.as_view()), - Match::Multiple(_, v) => v.iter().any(|x| x.contains(atom1.as_view())), - Match::FunctionName(_) => false, + WildcardRestriction::Filter(Box::new(move |m| { + let val = if let Pattern::Literal(a) = &atom1 { + a.as_view() + } else { + unreachable!() + }; + match m { + Match::Single(v) => v.contains(val), + Match::Multiple(_, v) => v.iter().any(|x| x.contains(val)), + Match::FunctionName(_) => false, + } })), ))) } else { @@ -1899,9 +2027,9 @@ impl TryFrom for PatternRestriction { } } Relation::IsType(atom, atom_type) => { - if let Atom::Var(v) = atom { + if let Pattern::Wildcard(name) = atom { Ok(PatternRestriction::Wildcard(( - v.get_symbol(), + name, WildcardRestriction::IsAtomType(atom_type), ))) } else { @@ -3056,8 +3184,8 @@ impl PythonExpression { pub fn contains(&self, s: ConvertibleToExpression) -> PythonCondition { PythonCondition { condition: Condition::Yield(Relation::Contains( - self.expr.clone(), - s.to_expression().expr, + self.expr.into_pattern(), + s.to_expression().expr.into_pattern(), )), } } @@ -3235,7 +3363,7 @@ impl PythonExpression { pub fn is_type(&self, atom_type: PythonAtomType) -> PythonCondition { PythonCondition { condition: Condition::Yield(Relation::IsType( - self.expr.clone(), + self.expr.into_pattern(), match atom_type { PythonAtomType::Num => AtomType::Num, PythonAtomType::Var => AtomType::Var, @@ -3245,32 +3373,32 @@ impl PythonExpression { PythonAtomType::Fn => AtomType::Fun, }, )), - } - } + } + } /// Compare two expressions. If one of the expressions is not a number, an /// internal ordering will be used. - fn __richcmp__(&self, other: ConvertibleToExpression, op: CompareOp) -> PythonCondition { - match op { + fn __richcmp__(&self, other: ConvertibleToPattern, op: CompareOp) -> PyResult { + Ok(match op { CompareOp::Eq => PythonCondition { - condition: Relation::Eq(self.expr.clone(), other.to_expression().expr).into(), + condition: Relation::Eq(self.expr.into_pattern(), other.to_pattern()?.expr).into(), }, CompareOp::Ne => PythonCondition { - condition: Relation::Ne(self.expr.clone(), other.to_expression().expr).into(), + condition: Relation::Ne(self.expr.into_pattern(), other.to_pattern()?.expr).into(), }, CompareOp::Ge => PythonCondition { - condition: Relation::Ge(self.expr.clone(), other.to_expression().expr).into(), + condition: Relation::Ge(self.expr.into_pattern(), other.to_pattern()?.expr).into(), }, CompareOp::Gt => PythonCondition { - condition: Relation::Gt(self.expr.clone(), other.to_expression().expr).into(), + condition: Relation::Gt(self.expr.into_pattern(), other.to_pattern()?.expr).into(), }, CompareOp::Le => PythonCondition { - condition: Relation::Le(self.expr.clone(), other.to_expression().expr).into(), + condition: Relation::Le(self.expr.into_pattern(), other.to_pattern()?.expr).into(), }, CompareOp::Lt => PythonCondition { - condition: Relation::Lt(self.expr.clone(), other.to_expression().expr).into(), + condition: Relation::Lt(self.expr.into_pattern(), other.to_pattern()?.expr).into(), }, - } + }) } /// Create a pattern restriction that passes when the wildcard is smaller than a number `num`. @@ -3719,7 +3847,7 @@ impl PythonExpression { /// /// yields /// - /// ``` + /// ```log /// (3*x+3*y)*(4*x+5*y) /// ``` pub fn expand_num(&self) -> PythonExpression { diff --git a/src/id.rs b/src/id.rs index 41fddf1f..180eb2b9 100644 --- a/src/id.rs +++ b/src/id.rs @@ -23,6 +23,16 @@ pub enum Pattern { Transformer(Box<(Option, Vec)>), } +impl std::fmt::Display for Pattern { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Ok(a) = self.to_atom() { + a.fmt(f) + } else { + std::fmt::Debug::fmt(self, f) + } + } +} + pub trait MatchMap: Fn(&MatchStack) -> Atom + DynClone + Send + Sync {} dyn_clone::clone_trait_object!(MatchMap); impl Atom> MatchMap for T {} @@ -417,8 +427,13 @@ impl<'a> AtomView<'a> { match r.rhs { PatternOrMap::Pattern(rhs) => { - rhs.substitute_wildcards(workspace, &mut rhs_subs, &match_stack) - .unwrap(); // TODO: escalate? + rhs.substitute_wildcards( + workspace, + &mut rhs_subs, + &match_stack, + None, + ) + .unwrap(); // TODO: escalate? } PatternOrMap::Map(f) => { let mut rhs = f(&match_stack); @@ -938,6 +953,7 @@ impl Pattern { workspace: &Workspace, out: &mut Atom, match_stack: &MatchStack, + transformer_input: Option<&Pattern>, ) -> Result<(), TransformerError> { match self { Pattern::Wildcard(name) => { @@ -1017,7 +1033,12 @@ impl Pattern { } let mut handle = workspace.new_atom(); - arg.substitute_wildcards(workspace, &mut handle, match_stack)?; + arg.substitute_wildcards( + workspace, + &mut handle, + match_stack, + transformer_input, + )?; func.add_arg(handle.as_view()); } @@ -1055,7 +1076,12 @@ impl Pattern { } let mut handle = workspace.new_atom(); - arg.substitute_wildcards(workspace, &mut handle, match_stack)?; + arg.substitute_wildcards( + workspace, + &mut handle, + match_stack, + transformer_input, + )?; out.set_from_view(&handle.as_view()); } @@ -1099,7 +1125,12 @@ impl Pattern { } let mut handle = workspace.new_atom(); - arg.substitute_wildcards(workspace, &mut handle, match_stack)?; + arg.substitute_wildcards( + workspace, + &mut handle, + match_stack, + transformer_input, + )?; mul.extend(handle.as_view()); } mul_h.as_view().normalize(workspace, out); @@ -1140,7 +1171,12 @@ impl Pattern { } let mut handle = workspace.new_atom(); - arg.substitute_wildcards(workspace, &mut handle, match_stack)?; + arg.substitute_wildcards( + workspace, + &mut handle, + match_stack, + transformer_input, + )?; add.extend(handle.as_view()); } add_h.as_view().normalize(workspace, out); @@ -1150,14 +1186,19 @@ impl Pattern { } Pattern::Transformer(p) => { let (pat, ts) = &**p; - let pat = pat.as_ref().ok_or_else(|| { - TransformerError::ValueError( + + let pat = if let Some(p) = pat.as_ref() { + p + } else if let Some(input_p) = transformer_input { + input_p + } else { + Err(TransformerError::ValueError( "Transformer is missing an expression to act on.".to_owned(), - ) - })?; + ))? + }; let mut handle = workspace.new_atom(); - pat.substitute_wildcards(workspace, &mut handle, match_stack)?; + pat.substitute_wildcards(workspace, &mut handle, match_stack, transformer_input)?; Transformer::execute_chain(handle.as_view(), ts, workspace, out)?; } @@ -1362,21 +1403,21 @@ pub trait Evaluate { type State<'a>; /// Evaluate a condition. - fn evaluate<'a>(&self, state: &Self::State<'a>) -> ConditionResult; + fn evaluate<'a>(&self, state: &Self::State<'a>) -> Result; } impl Evaluate for Condition { type State<'a> = T::State<'a>; - fn evaluate(&self, state: &T::State<'_>) -> ConditionResult { - match self { - Condition::And(a) => a.0.evaluate(state) & a.1.evaluate(state), - Condition::Or(o) => o.0.evaluate(state) | o.1.evaluate(state), - Condition::Not(n) => !n.evaluate(state), + fn evaluate(&self, state: &T::State<'_>) -> Result { + Ok(match self { + Condition::And(a) => a.0.evaluate(state)? & a.1.evaluate(state)?, + Condition::Or(o) => o.0.evaluate(state)? | o.1.evaluate(state)?, + Condition::Not(n) => !n.evaluate(state)?, Condition::True => ConditionResult::True, Condition::False => ConditionResult::False, - Condition::Yield(t) => t.evaluate(state), - } + Condition::Yield(t) => t.evaluate(state)?, + }) } } @@ -1465,14 +1506,120 @@ impl From for ConditionResult { } } +impl ConditionResult { + pub fn is_true(&self) -> bool { + matches!(self, ConditionResult::True) + } + + pub fn is_false(&self) -> bool { + matches!(self, ConditionResult::False) + } + + pub fn is_inconclusive(&self) -> bool { + matches!(self, ConditionResult::Inconclusive) + } +} + +#[derive(Clone, Debug)] +pub enum Relation { + Eq(Pattern, Pattern), + Ne(Pattern, Pattern), + Gt(Pattern, Pattern), + Ge(Pattern, Pattern), + Lt(Pattern, Pattern), + Le(Pattern, Pattern), + Contains(Pattern, Pattern), + IsType(Pattern, AtomType), +} + +impl std::fmt::Display for Relation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Relation::Eq(a, b) => write!(f, "{} == {}", a, b), + Relation::Ne(a, b) => write!(f, "{} != {}", a, b), + Relation::Gt(a, b) => write!(f, "{} > {}", a, b), + Relation::Ge(a, b) => write!(f, "{} >= {}", a, b), + Relation::Lt(a, b) => write!(f, "{} < {}", a, b), + Relation::Le(a, b) => write!(f, "{} <= {}", a, b), + Relation::Contains(a, b) => write!(f, "{} contains {}", a, b), + Relation::IsType(a, b) => write!(f, "{} is type {:?}", a, b), + } + } +} + +impl Evaluate for Relation { + type State<'a> = Option>; + + fn evaluate(&self, state: &Option) -> Result { + Workspace::get_local().with(|ws| { + let mut out1 = ws.new_atom(); + let mut out2 = ws.new_atom(); + let c = Condition::default(); + let s = MatchSettings::default(); + let m = MatchStack::new(&c, &s); + let pat = state.map(|x| x.into_pattern()); + + Ok(match self { + Relation::Eq(a, b) + | Relation::Ne(a, b) + | Relation::Gt(a, b) + | Relation::Ge(a, b) + | Relation::Lt(a, b) + | Relation::Le(a, b) + | Relation::Contains(a, b) => { + a.substitute_wildcards(ws, &mut out1, &m, pat.as_ref()) + .map_err(|e| match e { + TransformerError::Interrupt => "Interrupted by user".into(), + TransformerError::ValueError(v) => v, + })?; + b.substitute_wildcards(ws, &mut out2, &m, pat.as_ref()) + .map_err(|e| match e { + TransformerError::Interrupt => "Interrupted by user".into(), + TransformerError::ValueError(v) => v, + })?; + + match self { + Relation::Eq(_, _) => out1 == out2, + Relation::Ne(_, _) => out1 != out2, + Relation::Gt(_, _) => out1.as_view() > out2.as_view(), + Relation::Ge(_, _) => out1.as_view() >= out2.as_view(), + Relation::Lt(_, _) => out1.as_view() < out2.as_view(), + Relation::Le(_, _) => out1.as_view() <= out2.as_view(), + Relation::Contains(_, _) => out1.contains(out2.as_view()), + _ => unreachable!(), + } + } + Relation::IsType(a, b) => { + a.substitute_wildcards(ws, &mut out1, &m, pat.as_ref()) + .map_err(|e| match e { + TransformerError::Interrupt => "Interrupted by user".into(), + TransformerError::ValueError(v) => v, + })?; + + match out1.as_ref() { + Atom::Var(_) => (*b == AtomType::Var).into(), + Atom::Fun(_) => (*b == AtomType::Fun).into(), + Atom::Num(_) => (*b == AtomType::Num).into(), + Atom::Add(_) => (*b == AtomType::Add).into(), + Atom::Mul(_) => (*b == AtomType::Mul).into(), + Atom::Pow(_) => (*b == AtomType::Pow).into(), + Atom::Zero => (*b == AtomType::Num).into(), + } + } + } + .into()) + }) + } +} + impl Evaluate for Condition { type State<'a> = MatchStack<'a, 'a>; - fn evaluate(&self, state: &MatchStack) -> ConditionResult { - match self { - Condition::And(a) => a.0.evaluate(state) & a.1.evaluate(state), - Condition::Or(o) => o.0.evaluate(state) | o.1.evaluate(state), - Condition::Not(n) => !n.evaluate(state), + fn evaluate(&self, state: &MatchStack) -> Result { + Ok(match self { + Condition::And(a) => a.0.evaluate(state)? & a.1.evaluate(state)?, + Condition::Or(o) => o.0.evaluate(state)? | o.1.evaluate(state)?, + Condition::Not(n) => !n.evaluate(state)?, Condition::True => ConditionResult::True, Condition::False => ConditionResult::False, Condition::Yield(t) => match t { @@ -1507,7 +1654,7 @@ impl Evaluate for Condition { { f(value, value2) } else { - return ConditionResult::Inconclusive; + return Ok(ConditionResult::Inconclusive); } } WildcardRestriction::NotGreedy => true, @@ -1519,7 +1666,7 @@ impl Evaluate for Condition { } PatternRestriction::MatchStack(mf) => mf(state), }, - } + }) } } @@ -3125,7 +3272,7 @@ impl<'a: 'b, 'b> ReplaceIterator<'a, 'b> { match self.rhs { PatternOrMap::Pattern(p) => { - p.substitute_wildcards(ws, &mut new_rhs, pattern_match.match_stack) + p.substitute_wildcards(ws, &mut new_rhs, pattern_match.match_stack, None) .unwrap(); // TODO: escalate? } PatternOrMap::Map(f) => { diff --git a/src/transformer.rs b/src/transformer.rs index e0a152cf..e6c29797 100644 --- a/src/transformer.rs +++ b/src/transformer.rs @@ -1,11 +1,14 @@ -use std::{sync::Arc, time::Instant}; +use std::{ops::ControlFlow, sync::Arc, time::Instant}; use crate::{ atom::{representation::FunView, Atom, AtomOrView, AtomView, Fun, Symbol}, coefficient::{Coefficient, CoefficientView}, combinatorics::{partitions, unique_permutations}, domains::rational::Rational, - id::{Condition, MatchSettings, Pattern, PatternOrMap, PatternRestriction, Replacement}, + id::{ + Condition, Evaluate, MatchSettings, Pattern, PatternOrMap, PatternRestriction, Relation, + Replacement, + }, printer::{AtomPrinter, PrintOptions}, state::{RecycledAtom, State, Workspace}, }; @@ -106,6 +109,9 @@ pub enum TransformerError { /// Operations that take a pattern as the input and produce an expression #[derive(Clone)] pub enum Transformer { + IfElse(Condition, Vec, Vec), + IfChanged(Vec, Vec, Vec), + BreakChain, /// Expand the rhs. Expand(Option, bool), /// Distribute numbers. @@ -168,6 +174,9 @@ pub enum Transformer { impl std::fmt::Debug for Transformer { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { + Transformer::IfElse(_, _, _) => f.debug_tuple("IfElse").finish(), + Transformer::IfChanged(_, _, _) => f.debug_tuple("IfChanged").finish(), + Transformer::BreakChain => f.debug_tuple("BreakChain").finish(), Transformer::Expand(s, _) => f.debug_tuple("Expand").field(s).finish(), Transformer::ExpandNum => f.debug_tuple("ExpandNum").finish(), Transformer::Derivative(x) => f.debug_tuple("Derivative").field(x).finish(), @@ -415,17 +424,17 @@ impl Transformer { input: AtomView<'_>, workspace: &Workspace, out: &mut Atom, - ) -> Result<(), TransformerError> { + ) -> Result, TransformerError> { Transformer::execute_chain(input, std::slice::from_ref(self), workspace, out) } - /// Apply a chain of transformers to `orig_input`. + /// Apply a chain of transformers to `input`. pub fn execute_chain( input: AtomView<'_>, chain: &[Transformer], workspace: &Workspace, out: &mut Atom, - ) -> Result<(), TransformerError> { + ) -> Result, TransformerError> { out.set_from_view(&input); let mut tmp = workspace.new_atom(); for t in chain { @@ -433,6 +442,39 @@ impl Transformer { let cur_input = tmp.as_view(); match t { + Transformer::IfElse(cond, t1, t2) => { + if cond + .evaluate(&Some(cur_input)) + .map_err(|e| TransformerError::ValueError(e))? + .is_true() + { + if Transformer::execute_chain(cur_input, t1, workspace, out)?.is_break() { + return Ok(ControlFlow::Break(())); + } + } else if Transformer::execute_chain(cur_input, t2, workspace, out)?.is_break() + { + return Ok(ControlFlow::Break(())); + } + } + Transformer::IfChanged(cond, t1, t2) => { + Transformer::execute_chain(cur_input, cond, workspace, out)?; + std::mem::swap(out, &mut tmp); + + if tmp.as_view() != out.as_view() { + if Transformer::execute_chain(tmp.as_view(), t1, workspace, out)?.is_break() + { + return Ok(ControlFlow::Break(())); + } + } else if Transformer::execute_chain(tmp.as_view(), t2, workspace, out)? + .is_break() + { + return Ok(ControlFlow::Break(())); + } + } + Transformer::BreakChain => { + std::mem::swap(out, &mut tmp); + return Ok(ControlFlow::Break(())); + } Transformer::Map(f) => { f(cur_input, out)?; } @@ -504,7 +546,7 @@ impl Transformer { let key_map = key_map.clone(); Some(Box::new(move |i, o| { Workspace::get_local() - .with(|ws| Self::execute_chain(i, &key_map, ws, o).unwrap()) + .with(|ws| Self::execute_chain(i, &key_map, ws, o).unwrap()); })) }, if coeff_map.is_empty() { @@ -513,7 +555,7 @@ impl Transformer { let coeff_map = coeff_map.clone(); Some(Box::new(move |i, o| { Workspace::get_local() - .with(|ws| Self::execute_chain(i, &coeff_map, ws, o).unwrap()) + .with(|ws| Self::execute_chain(i, &coeff_map, ws, o).unwrap()); })) }, out, @@ -832,7 +874,7 @@ impl Transformer { } } - Ok(()) + Ok(ControlFlow::Continue(())) } } diff --git a/symbolica.pyi b/symbolica.pyi index 43ce8da2..80368055 100644 --- a/symbolica.pyi +++ b/symbolica.pyi @@ -1576,6 +1576,85 @@ class Transformer: >>> e = Transformer().expand()((1+x)**2) """ + def __eq__(self, other: Transformer | Expression | int | float | Decimal) -> Condition: + """ + Compare two transformers. + """ + + def __neq__(self, other: Transformer | Expression | int | float | Decimal) -> Condition: + """ + Compare two transformers. + """ + + def __lt__(self, other: Transformer | Expression | int | float | Decimal) -> Condition: + """ + Compare two transformers. If any of the two expressions is not a rational number, an interal ordering is used. + """ + + def __le__(self, other: Transformer | Expression | int | float | Decimal) -> Condition: + """ + Compare two transformers. If any of the two expressions is not a rational number, an interal ordering is used. + """ + + def __gt__(self, other: Transformer | Expression | int | float | Decimal) -> Condition: + """ + Compare two transformers. If any of the two expressions is not a rational number, an interal ordering is used. + """ + + def __ge__(self, other: Transformer | Expression | int | float | Decimal) -> Condition: + """ + Compare two transformers. If any of the two expressions is not a rational number, an interal ordering is used. + """ + + def contains(self, element: Transformer | Expression | int | float | Decimal) -> Condition: + """ + Create a transformer that checks if the expression contains the given `element`. + """ + + def if_then(self, condition: Condition, if_block: Transformer, else_block: Optional[Transformer] = None) -> Transformer: + """Evaluate the condition and apply the `if_block` if the condition is true, otherwise apply the `else_block`. + The expression that is the input of the transformer is the input for the condition, the `if_block` and the `else_block`. + + Examples + -------- + >>> t = T.map_terms(T.if_then(T.contains(x), T.print())) + >>> t(x + y + 4) + + prints `x`. + """ + + def if_changed(self, condition: Transformer, if_block: Transformer, else_block: Optional[Transformer] = None) -> Transformer: + """Execute the `condition` transformer. If the result of the `condition` transformer is different from the input expression, + apply the `if_block`, otherwise apply the `else_block`. The input expression of the `if_block` is the output + of the `condition` transformer. + + Examples + -------- + >>> t = T.map_terms(T.if_changed(T.replace_all(x, y), T.print())) + >>> print(t(x + y + 4)) + + prints + ``` + y + 2*y+4 + ``` + """ + + def break_chain(self) -> Transformer: + """Break the current chain and all higher-level chains containing `if` transformers. + + Examples + -------- + >>> from symbolica import * + >>> t = T.map_terms(T.repeat( + >>> T.replace_all(y, 4), + >>> T.if_changed(T.replace_all(x, y), + >>> T.break_chain()), + >>> T.print() # print of y is never reached + >>> )) + >>> print(t(x)) + """ + def expand(self, var: Optional[Expression] = None, via_poly: Optional[bool] = None) -> Transformer: """Create a transformer that expands products and powers. Optionally, expand in `var` only. From fe31c83cd065982473e4c65225cd635538826320 Mon Sep 17 00:00:00 2001 From: Ben Ruijl Date: Wed, 4 Dec 2024 17:48:48 +0100 Subject: [PATCH 6/6] Add replace_map for replacements using a function - Do not set a static trace level - Allow transformers on rhs of contains - Add is_type condition on transformers --- Cargo.toml | 5 +- src/api/python.rs | 25 +++++-- src/atom.rs | 1 + src/id.rs | 162 ++++++++++++++++++++++++++++++++++++++++++++++ symbolica.pyi | 7 +- 5 files changed, 193 insertions(+), 7 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index bea8e8fe..54768068 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,7 @@ crate-type = ["lib"] name = "symbolica" [features] -default = [] +default = ["tracing_only_warnings"] # if using this, make sure jemalloc is compiled with --disable-initial-exec-tls # if symbolica is used as a dynamic library (as is the case for the Python API) faster_alloc = ["tikv-jemallocator"] @@ -37,6 +37,7 @@ python_api = ["pyo3", "bincode"] python_no_module = ["python_api"] # build a module that is independent of the specific Python version python_abi3 = ["pyo3/abi3", "pyo3/abi3-py37"] +tracing_only_warnings = ["tracing/release_max_level_warn"] [dependencies.pyo3] features = ["extension-module", "abi3", "py-clone"] @@ -67,6 +68,6 @@ smallvec = "1.13" smartstring = "1.0" tikv-jemallocator = {version = "0.5.4", optional = true} tinyjson = "2.5" -tracing = {version = "0.1", features = ["max_level_trace", "release_max_level_warn"]} +tracing = "0.1" wide = "0.7" wolfram-library-link = {version = "0.2.9", optional = true} diff --git a/src/api/python.rs b/src/api/python.rs index 4050a827..920ff00b 100644 --- a/src/api/python.rs +++ b/src/api/python.rs @@ -468,6 +468,23 @@ impl PythonTransformer { }) } + /// Test if the expression is of a certain type. + pub fn is_type(&self, atom_type: PythonAtomType) -> PythonCondition { + PythonCondition { + condition: Condition::Yield(Relation::IsType( + self.expr.clone(), + match atom_type { + PythonAtomType::Num => AtomType::Num, + PythonAtomType::Var => AtomType::Var, + PythonAtomType::Add => AtomType::Add, + PythonAtomType::Mul => AtomType::Mul, + PythonAtomType::Pow => AtomType::Pow, + PythonAtomType::Fn => AtomType::Fun, + }, + )), + } + } + /// Returns true iff `self` contains `a` literally. /// /// Examples @@ -3181,13 +3198,13 @@ impl PythonExpression { /// >>> e.contains(x) # True /// >>> e.contains(x*y*z) # True /// >>> e.contains(x*y) # False - pub fn contains(&self, s: ConvertibleToExpression) -> PythonCondition { - PythonCondition { + pub fn contains(&self, s: ConvertibleToPattern) -> PyResult { + Ok(PythonCondition { condition: Condition::Yield(Relation::Contains( self.expr.into_pattern(), - s.to_expression().expr.into_pattern(), + s.to_pattern()?.expr, )), - } + }) } /// Get all symbols in the current expression, optionally including function symbols. diff --git a/src/atom.rs b/src/atom.rs index ddd45158..40cc9542 100644 --- a/src/atom.rs +++ b/src/atom.rs @@ -499,6 +499,7 @@ impl<'a> AtomView<'a> { } } +/// A mathematical expression. #[derive(Clone)] pub enum Atom { Num(Num), diff --git a/src/id.rs b/src/id.rs index 180eb2b9..c22edbd7 100644 --- a/src/id.rs +++ b/src/id.rs @@ -161,6 +161,24 @@ impl Atom { ) -> bool { self.as_view().replace_all_multiple_into(replacements, out) } + + /// Replace part of an expression by calling the map `m` on each subexpression. + /// The function `m` must return `true` if the expression was replaced and must write the new expression to `out`. + /// A [Context] object is passed to the function, which contains information about the current position in the expression. + pub fn replace_map bool>(&self, m: &F) -> Atom { + self.as_view().replace_map(m) + } +} + +/// The context of an atom. +#[derive(Clone, Copy, Debug)] +pub struct Context { + /// The level of the function in the expression tree. + pub function_level: usize, + /// The type of the parent atom. + pub parent_type: Option, + /// The index of the atom in the parent. + pub index: usize, } impl<'a> AtomView<'a> { @@ -326,6 +344,134 @@ impl<'a> AtomView<'a> { false } + /// Replace part of an expression by calling the map `m` on each subexpression. + /// The function `m` must return `true` if the expression was replaced and must write the new expression to `out`. + /// A [Context] object is passed to the function, which contains information about the current position in the expression. + pub fn replace_map bool>(&self, m: &F) -> Atom { + let mut out = Atom::new(); + self.replace_map_into(m, &mut out); + out + } + + /// Replace part of an expression by calling the map `m` on each subexpression. + /// The function `m` must return `true` if the expression was replaced and must write the new expression to `out`. + /// A [Context] object is passed to the function, which contains information about the current position in the expression. + pub fn replace_map_into bool>( + &self, + m: &F, + out: &mut Atom, + ) { + let context = Context { + function_level: 0, + parent_type: None, + index: 0, + }; + Workspace::get_local().with(|ws| { + self.replace_map_impl(ws, m, context, out); + }); + } + + fn replace_map_impl bool>( + &self, + ws: &Workspace, + m: &F, + mut context: Context, + out: &mut Atom, + ) -> bool { + if m(*self, &context, out) { + return true; + } + + let mut changed = false; + match self { + AtomView::Num(_) | AtomView::Var(_) => { + out.set_from_view(self); + } + AtomView::Fun(f) => { + let mut fun = ws.new_atom(); + let fun = fun.to_fun(f.get_symbol()); + + context.parent_type = Some(AtomType::Fun); + context.function_level += 1; + + for (i, arg) in f.iter().enumerate() { + context.index = i; + + let mut arg_h = ws.new_atom(); + changed |= arg.replace_map_impl(ws, m, context, &mut arg_h); + fun.add_arg(arg_h.as_view()); + } + + if changed { + fun.as_view().normalize(ws, out); + } else { + out.set_from_view(self); + } + } + AtomView::Pow(p) => { + let (base, exp) = p.get_base_exp(); + + context.parent_type = Some(AtomType::Pow); + context.index = 0; + + let mut base_h = ws.new_atom(); + changed |= base.replace_map_impl(ws, m, context, &mut base_h); + + context.index = 1; + let mut exp_h = ws.new_atom(); + changed |= exp.replace_map_impl(ws, m, context, &mut exp_h); + + if changed { + let mut pow_h = ws.new_atom(); + pow_h.to_pow(base_h.as_view(), exp_h.as_view()); + pow_h.as_view().normalize(ws, out); + } else { + out.set_from_view(self); + } + } + AtomView::Mul(mm) => { + let mut mul_h = ws.new_atom(); + let mul = mul_h.to_mul(); + + context.parent_type = Some(AtomType::Mul); + + for (i, child) in mm.iter().enumerate() { + context.index = i; + let mut child_h = ws.new_atom(); + changed |= child.replace_map_impl(ws, m, context, &mut child_h); + mul.extend(child_h.as_view()); + } + + if changed { + mul_h.as_view().normalize(ws, out); + } else { + out.set_from_view(self); + } + } + AtomView::Add(a) => { + let mut add_h = ws.new_atom(); + let add = add_h.to_add(); + + context.parent_type = Some(AtomType::Add); + + for (i, child) in a.iter().enumerate() { + context.index = i; + let mut child_h = ws.new_atom(); + changed |= child.replace_map_impl(ws, m, context, &mut child_h); + add.extend(child_h.as_view()); + } + + if changed { + add_h.as_view().normalize(ws, out); + } else { + out.set_from_view(self); + } + } + } + + changed + } + /// Replace all occurrences of the patterns, where replacements are tested in the order that they are given. pub fn replace_all( &self, @@ -3313,6 +3459,22 @@ mod test { use super::Pattern; + #[test] + fn replace_map() { + let a = Atom::parse("v1 + f1(1,2, f1((1+v1)^2), (v1+v2)^2)").unwrap(); + + let r = a.replace_map(&|arg, context, out| { + if context.function_level > 0 { + arg.expand_into(None, out) + } else { + false + } + }); + + let res = Atom::parse("v1+f1(1,2,f1(2*v1+v1^2+1),v1^2+v2^2+2*v1*v2)").unwrap(); + assert_eq!(r, res); + } + #[test] fn overlap() { let a = Atom::parse("(v1*(v2+v2^2+1)+v2^2 + v2)").unwrap(); diff --git a/symbolica.pyi b/symbolica.pyi index 80368055..1edd1739 100644 --- a/symbolica.pyi +++ b/symbolica.pyi @@ -517,7 +517,7 @@ class Expression: transformations can be applied. """ - def contains(self, a: Expression | int | float | Decimal) -> Condition: + def contains(self, a: Transformer | Expression | int | float | Decimal) -> Condition: """Returns true iff `self` contains `a` literally. Examples @@ -1606,6 +1606,11 @@ class Transformer: Compare two transformers. If any of the two expressions is not a rational number, an interal ordering is used. """ + def is_type(self, atom_type: AtomType) -> Condition: + """ + Test if the transformed expression is of a certain type. + """ + def contains(self, element: Transformer | Expression | int | float | Decimal) -> Condition: """ Create a transformer that checks if the expression contains the given `element`.