From fea7a041fdf82c883de2d78674b95be494cd8adb Mon Sep 17 00:00:00 2001 From: Ben Ruijl Date: Tue, 10 Dec 2024 17:00:42 +0100 Subject: [PATCH 1/7] Redesign error propagating floats - Use absolute error so that zeros are supported - Add i() function for floats - Add automatic evaluation of complex i - Move EXP, E, etc from State to Atom --- src/api/python.rs | 16 +- src/atom.rs | 52 ++++++- src/coefficient.rs | 4 +- src/derivative.rs | 34 ++--- src/domains/dual.rs | 7 + src/domains/float.rs | 354 ++++++++++++++++++++++++++++++------------- src/evaluate.rs | 81 +++++----- src/id.rs | 6 +- src/normalize.rs | 38 ++--- src/poly/evaluate.rs | 5 +- src/poly/series.rs | 17 +-- src/printer.rs | 6 +- src/state.rs | 24 +-- src/transformer.rs | 30 ++-- 14 files changed, 438 insertions(+), 236 deletions(-) diff --git a/src/api/python.rs b/src/api/python.rs index 920ff00b..816043a5 100644 --- a/src/api/python.rs +++ b/src/api/python.rs @@ -2656,14 +2656,14 @@ impl PythonExpression { #[classattr] #[pyo3(name = "E")] pub fn e() -> PythonExpression { - Atom::new_var(State::E).into() + Atom::new_var(Atom::E).into() } /// The mathematical constant `π`. #[classattr] #[pyo3(name = "PI")] pub fn pi() -> PythonExpression { - Atom::new_var(State::PI).into() + Atom::new_var(Atom::PI).into() } /// The mathematical constant `i`, where @@ -2671,42 +2671,42 @@ impl PythonExpression { #[classattr] #[pyo3(name = "I")] pub fn i() -> PythonExpression { - Atom::new_var(State::I).into() + Atom::new_var(Atom::I).into() } /// The built-in function that converts a rational polynomial to a coefficient. #[classattr] #[pyo3(name = "COEFF")] pub fn coeff() -> PythonExpression { - Atom::new_var(State::COEFF).into() + Atom::new_var(Atom::COEFF).into() } /// The built-in cosine function. #[classattr] #[pyo3(name = "COS")] pub fn cos() -> PythonExpression { - Atom::new_var(State::COS).into() + Atom::new_var(Atom::COS).into() } /// The built-in sine function. #[classattr] #[pyo3(name = "SIN")] pub fn sin() -> PythonExpression { - Atom::new_var(State::SIN).into() + Atom::new_var(Atom::SIN).into() } /// The built-in exponential function. #[classattr] #[pyo3(name = "EXP")] pub fn exp() -> PythonExpression { - Atom::new_var(State::EXP).into() + Atom::new_var(Atom::EXP).into() } /// The built-in logarithm function. #[classattr] #[pyo3(name = "LOG")] pub fn log() -> PythonExpression { - Atom::new_var(State::LOG).into() + Atom::new_var(Atom::LOG).into() } /// Return all defined symbol names (function names and variables). diff --git a/src/atom.rs b/src/atom.rs index 40cc9542..10bb9ebe 100644 --- a/src/atom.rs +++ b/src/atom.rs @@ -7,7 +7,7 @@ use crate::{ coefficient::Coefficient, parser::Token, printer::{AtomPrinter, PrintOptions}, - state::{RecycledAtom, Workspace}, + state::{RecycledAtom, State, Workspace}, transformer::StatsOptions, }; use std::{cmp::Ordering, hash::Hash, ops::DerefMut, str::FromStr}; @@ -511,6 +511,56 @@ pub enum Atom { Zero, } +impl Atom { + /// The built-in function represents a list of function arguments. + pub const ARG: Symbol = State::ARG; + /// The built-in function that converts a rational polynomial to a coefficient. + pub const COEFF: Symbol = State::COEFF; + /// The exponent function. + pub const EXP: Symbol = State::EXP; + /// The logarithm function. + pub const LOG: Symbol = State::LOG; + /// The sine function. + pub const SIN: Symbol = State::SIN; + /// The cosine function. + pub const COS: Symbol = State::COS; + /// The square root function. + pub const SQRT: Symbol = State::SQRT; + /// The built-in function that represents an abstract derivative. + pub const DERIVATIVE: Symbol = State::DERIVATIVE; + /// The constant e, the base of the natural logarithm. + pub const E: Symbol = State::E; + /// The constant i, the imaginary unit. + pub const I: Symbol = State::I; + /// The mathematical constant `π`. + pub const PI: Symbol = State::PI; + + /// Exponentiate the atom. + pub fn exp(&self) -> Atom { + FunctionBuilder::new(Atom::EXP).add_arg(self).finish() + } + + /// Take the logarithm of the atom. + pub fn log(&self) -> Atom { + FunctionBuilder::new(Atom::LOG).add_arg(self).finish() + } + + /// Take the sine the atom. + pub fn sin(&self) -> Atom { + FunctionBuilder::new(Atom::SIN).add_arg(self).finish() + } + + /// Take the cosine the atom. + pub fn cos(&self) -> Atom { + FunctionBuilder::new(Atom::COS).add_arg(self).finish() + } + + /// Take the square root of the atom. + pub fn sqrt(&self) -> Atom { + FunctionBuilder::new(Atom::SQRT).add_arg(self).finish() + } +} + impl Default for Atom { /// Create an atom that represents the number 0. #[inline] diff --git a/src/coefficient.rs b/src/coefficient.rs index efea5884..cfb2a331 100644 --- a/src/coefficient.rs +++ b/src/coefficient.rs @@ -1365,13 +1365,13 @@ impl<'a> AtomView<'a> { let s = v.get_symbol(); match s { - State::PI => { + Atom::PI => { out.to_num(Coefficient::Float(Float::with_val( binary_prec, rug::float::Constant::Pi, ))); } - State::E => { + Atom::E => { out.to_num(Coefficient::Float(Float::with_val(binary_prec, 1).exp())); } _ => { diff --git a/src/derivative.rs b/src/derivative.rs index 9080066e..6d8df1b1 100644 --- a/src/derivative.rs +++ b/src/derivative.rs @@ -9,7 +9,7 @@ use crate::{ combinatorics::CombinationWithReplacementIterator, domains::{atom::AtomField, integer::Integer, rational::Rational}, poly::{series::Series, Variable}, - state::{State, Workspace}, + state::Workspace, }; impl Atom { @@ -83,7 +83,7 @@ impl<'a> AtomView<'a> { // detect if the function to derive is the derivative function itself // if so, derive the last argument of the derivative function and set // a flag to later accumulate previous derivatives - let (to_derive, f, is_der) = if f_orig.get_symbol() == State::DERIVATIVE { + let (to_derive, f, is_der) = if f_orig.get_symbol() == Atom::DERIVATIVE { let to_derive = f_orig.iter().last().unwrap(); ( to_derive, @@ -113,29 +113,29 @@ impl<'a> AtomView<'a> { // derive special functions if f.get_nargs() == 1 - && [State::EXP, State::LOG, State::SIN, State::COS].contains(&f.get_symbol()) + && [Atom::EXP, Atom::LOG, Atom::SIN, Atom::COS].contains(&f.get_symbol()) { let mut fn_der = workspace.new_atom(); match f.get_symbol() { - State::EXP => { + Atom::EXP => { fn_der.set_from_view(self); } - State::LOG => { + Atom::LOG => { let mut n = workspace.new_atom(); n.to_num((-1).into()); fn_der.to_pow(f.iter().next().unwrap(), n.as_view()); } - State::SIN => { - let p = fn_der.to_fun(State::COS); + Atom::SIN => { + let p = fn_der.to_fun(Atom::COS); p.add_arg(f.iter().next().unwrap()); } - State::COS => { + Atom::COS => { let mut n = workspace.new_atom(); n.to_num((-1).into()); let mut sin = workspace.new_atom(); - let sin_fun = sin.to_fun(State::SIN); + let sin_fun = sin.to_fun(Atom::SIN); sin_fun.add_arg(f.iter().next().unwrap()); let m = fn_der.to_mul(); @@ -167,7 +167,7 @@ impl<'a> AtomView<'a> { let mut n = workspace.new_atom(); let mut mul = workspace.new_atom(); for (index, arg_der) in args_der { - let p = fn_der.to_fun(State::DERIVATIVE); + let p = fn_der.to_fun(Atom::DERIVATIVE); if is_der { for (i, x_orig) in f_orig.iter().take(f.get_nargs()).enumerate() { @@ -218,7 +218,7 @@ impl<'a> AtomView<'a> { if exp_der_non_zero { // create log(base) let mut log_base = workspace.new_atom(); - let lb = log_base.to_fun(State::LOG); + let lb = log_base.to_fun(Atom::LOG); lb.add_arg(base); if let Atom::Mul(m) = exp_der.deref_mut() { @@ -418,11 +418,11 @@ impl<'a> AtomView<'a> { } match f.get_symbol() { - State::COS => args_series[0].cos(), - State::SIN => args_series[0].sin(), - State::EXP => args_series[0].exp(), - State::LOG => args_series[0].log(), - State::SQRT => args_series[0].rpow((1, 2).into()), + Atom::COS => args_series[0].cos(), + Atom::SIN => args_series[0].sin(), + Atom::EXP => args_series[0].exp(), + Atom::LOG => args_series[0].log(), + Atom::SQRT => args_series[0].rpow((1, 2).into()), _ => { // TODO: also check for log(x)? if args_series @@ -461,7 +461,7 @@ impl<'a> AtomView<'a> { CombinationWithReplacementIterator::new(args_series.len(), i); while let Some(x) = it.next() { - let mut f_der = FunctionBuilder::new(State::DERIVATIVE); + let mut f_der = FunctionBuilder::new(Atom::DERIVATIVE); let mut term = info.one(); for (arg, pow) in x.iter().enumerate() { if *pow > 0 { diff --git a/src/domains/dual.rs b/src/domains/dual.rs index 15491dfe..f023f157 100644 --- a/src/domains/dual.rs +++ b/src/domains/dual.rs @@ -759,6 +759,13 @@ macro_rules! create_hyperdual_from_components { res } + #[inline(always)] + fn i(&self) -> Option { + let mut res = self.zero(); + res.values[0] = self.values[0].i()?; + Some(res) + } + #[inline(always)] fn norm(&self) -> Self { let n = self.values[0].norm(); diff --git a/src/domains/float.rs b/src/domains/float.rs index 5efb8c55..d7fb1e67 100644 --- a/src/domains/float.rs +++ b/src/domains/float.rs @@ -363,6 +363,8 @@ pub trait Real: NumericalFloatLike { fn euler(&self) -> Self; /// The golden ratio, 1.6180339887... fn phi(&self) -> Self; + /// The imaginary unit, if it exists. + fn i(&self) -> Option; fn norm(&self) -> Self; fn sqrt(&self) -> Self; @@ -535,6 +537,11 @@ impl Real for f64 { 1.6180339887498948 } + #[inline(always)] + fn i(&self) -> Option { + None + } + #[inline(always)] fn norm(&self) -> Self { f64::abs(*self) @@ -928,6 +935,11 @@ impl Real for F64 { 1.6180339887498948.into() } + #[inline(always)] + fn i(&self) -> Option { + None + } + #[inline(always)] fn norm(&self) -> Self { self.0.norm().into() @@ -1821,6 +1833,11 @@ impl Real for Float { (self.one() + self.from_i64(5).sqrt()) / 2 } + #[inline(always)] + fn i(&self) -> Option { + None + } + #[inline(always)] fn norm(&self) -> Self { self.0.clone().abs().into() @@ -1973,7 +1990,7 @@ impl Rational { #[derive(Copy, Clone)] pub struct ErrorPropagatingFloat { value: T, - prec: f64, + abs_err: f64, } impl Neg for ErrorPropagatingFloat { @@ -1983,7 +2000,7 @@ impl Neg for ErrorPropagatingFloat { fn neg(self) -> Self::Output { ErrorPropagatingFloat { value: -self.value, - prec: self.prec, + abs_err: self.abs_err, } } } @@ -1993,13 +2010,9 @@ impl Add<&ErrorPropagatingFloat> for ErrorPropagatingFloat #[inline] fn add(self, rhs: &Self) -> Self::Output { - // TODO: handle r = 0 - let r = self.value.clone() + &rhs.value; ErrorPropagatingFloat { - prec: (self.get_num().to_f64().abs() * self.prec - + rhs.get_num().to_f64().abs() * rhs.prec) - / r.clone().to_f64().abs(), - value: r, + abs_err: self.abs_err + rhs.abs_err, + value: self.value + &rhs.value, } } } @@ -2036,9 +2049,20 @@ impl Mul<&ErrorPropagatingFloat> for ErrorPropagatingFloat #[inline] fn mul(self, rhs: &Self) -> Self::Output { - ErrorPropagatingFloat { - value: self.value.clone() * &rhs.value, - prec: self.prec + rhs.prec, + let value = self.value.clone() * &rhs.value; + let r = rhs.value.to_f64().abs(); + let s = self.value.to_f64().abs(); + + if s == 0. && r == 0. { + return ErrorPropagatingFloat { + value, + abs_err: self.abs_err * rhs.abs_err, + }; + } else { + ErrorPropagatingFloat { + value, + abs_err: self.abs_err * r + rhs.abs_err * s, + } } } } @@ -2048,10 +2072,10 @@ impl> Add for ErrorPropa #[inline] fn add(self, rhs: Rational) -> Self::Output { - let v = self.value.to_f64(); - let prec = self.prec * v.abs() / (v + rhs.to_f64()).abs(); - let r = self.value + rhs; - ErrorPropagatingFloat { prec, value: r }.truncate() + ErrorPropagatingFloat { + abs_err: self.abs_err, + value: self.value + rhs, + } } } @@ -2070,9 +2094,10 @@ impl> Mul for ErrorPropa #[inline] fn mul(self, rhs: Rational) -> Self::Output { ErrorPropagatingFloat { + abs_err: self.abs_err * rhs.to_f64().abs(), value: self.value * rhs, - prec: self.prec, } + .truncate() } } @@ -2082,9 +2107,10 @@ impl> Div for ErrorPropa #[inline] fn div(self, rhs: Rational) -> Self::Output { ErrorPropagatingFloat { + abs_err: self.abs_err * rhs.inv().to_f64().abs(), value: self.value.clone() / rhs, - prec: self.prec, } + .truncate() } } @@ -2102,10 +2128,7 @@ impl Div<&ErrorPropagatingFloat> for ErrorPropagatingFloat #[inline] fn div(self, rhs: &Self) -> Self::Output { - ErrorPropagatingFloat { - value: self.value.clone() / &rhs.value, - prec: self.prec + rhs.prec, - } + self * rhs.inv() } } @@ -2178,26 +2201,51 @@ impl DivAssign> for ErrorPropagating } } -impl ErrorPropagatingFloat { +impl ErrorPropagatingFloat { /// Create a new precision tracking float with a number of precise decimal digits `prec`. /// The `prec` must be smaller than the precision of the underlying float. + /// + /// If the value provided is 0, the precision argument is interpreted as an accuracy ( + /// the number of digits of the absolute error). pub fn new(value: T, prec: f64) -> Self { - ErrorPropagatingFloat { - value, - prec: 10f64.pow(-prec), + let r = value.to_f64().abs(); + + if r == 0. { + ErrorPropagatingFloat { + abs_err: 10f64.pow(-prec), + value, + } + } else { + ErrorPropagatingFloat { + abs_err: 10f64.pow(-prec) * r, + value, + } } } - /// Get the number. - #[inline(always)] - pub fn get_num(&self) -> &T { - &self.value + pub fn get_absolute_error(&self) -> f64 { + self.abs_err + } + + pub fn get_relative_error(&self) -> f64 { + self.abs_err / self.value.to_f64().abs() } /// Get the precision in number of decimal digits. #[inline(always)] - pub fn get_precision(&self) -> f64 { - -self.prec.log10() + pub fn get_precision(&self) -> Option { + let r = self.value.to_f64().abs(); + if r == 0. { + return None; + } else { + Some(-(self.abs_err / r).log10()) + } + } + + /// Get the accuracy in number of decimal digits. + #[inline(always)] + pub fn get_accuracy(&self) -> f64 { + -self.abs_err.log10() } /// Truncate the precision to the maximal number of stable decimal digits @@ -2205,36 +2253,56 @@ impl ErrorPropagatingFloat { #[inline(always)] pub fn truncate(mut self) -> Self { if self.value.fixed_precision() { - self.prec = self.prec.max(self.value.get_epsilon()); + self.abs_err = self + .abs_err + .max(self.value.get_epsilon() * self.value.to_f64()); } self } } -impl fmt::Display for ErrorPropagatingFloat { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - let p = self.get_precision() as usize; +impl ErrorPropagatingFloat { + pub fn new_with_accuracy(value: T, acc: f64) -> Self { + ErrorPropagatingFloat { + value, + abs_err: 10f64.pow(-acc), + } + } - if p == 0 { - f.write_char('0') + /// Get the number. + #[inline(always)] + pub fn get_num(&self) -> &T { + &self.value + } +} + +impl fmt::Display for ErrorPropagatingFloat { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + if let Some(p) = self.get_precision() { + if p < 0. { + return f.write_char('0'); + } else { + f.write_fmt(format_args!("{0:.1$}", self.value, p as usize)) + } } else { - f.write_fmt(format_args!( - "{0:.1$e}", - self.value, - self.get_precision() as usize - )) + f.write_char('0') } } } -impl Debug for ErrorPropagatingFloat { +impl Debug for ErrorPropagatingFloat { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { Debug::fmt(&self.value, f)?; - f.write_fmt(format_args!("`{}", self.get_precision())) + + if let Some(p) = self.get_precision() { + f.write_fmt(format_args!("`{:.2}", p)) + } else { + f.write_fmt(format_args!("``{:.2}", -self.abs_err.log10())) + } } } -impl LowerExp for ErrorPropagatingFloat { +impl LowerExp for ErrorPropagatingFloat { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { Display::fmt(self, f) } @@ -2265,49 +2333,81 @@ impl NumericalFloatLike for ErrorPropagatingFloat { fn zero(&self) -> Self { ErrorPropagatingFloat { value: self.value.zero(), - prec: 2f64.pow(-(self.value.get_precision() as f64)), + abs_err: 2f64.pow(-(self.value.get_precision() as f64)), } } fn new_zero() -> Self { ErrorPropagatingFloat { value: T::new_zero(), - prec: 2f64.powi(-53), + abs_err: 2f64.powi(-53), } } fn one(&self) -> Self { ErrorPropagatingFloat { value: self.value.one(), - prec: 2f64.pow(-(self.value.get_precision() as f64)), + abs_err: 2f64.pow(-(self.value.get_precision() as f64)), } } fn pow(&self, e: u64) -> Self { + let i = self.to_f64().abs(); + + if i == 0. { + return ErrorPropagatingFloat { + value: self.value.pow(e), + abs_err: self.abs_err.pow(e as f64), + }; + } + + let r = self.value.pow(e); ErrorPropagatingFloat { - value: self.value.pow(e), - prec: self.prec * e as f64, + abs_err: self.abs_err * e as f64 * r.to_f64().abs() / i, + value: r, } } fn inv(&self) -> Self { + let r = self.value.inv(); + let rr = r.to_f64().abs(); ErrorPropagatingFloat { - value: self.value.inv(), - prec: self.prec, + abs_err: self.abs_err * rr * rr, + value: r, } } + /// Convert from a `usize`. fn from_usize(&self, a: usize) -> Self { - ErrorPropagatingFloat { - value: self.value.from_usize(a), - prec: self.prec, + let v = self.value.from_usize(a); + let r = v.to_f64().abs(); + if r == 0. { + ErrorPropagatingFloat { + value: v, + abs_err: 2f64.pow(-(self.value.get_precision() as f64)), + } + } else { + ErrorPropagatingFloat { + value: v, + abs_err: 2f64.pow(-(self.value.get_precision() as f64)) * r, + } } } + /// Convert from a `i64`. fn from_i64(&self, a: i64) -> Self { - ErrorPropagatingFloat { - value: self.value.from_i64(a), - prec: self.prec, + let v = self.value.from_i64(a); + let r = v.to_f64().abs(); + if r == 0. { + ErrorPropagatingFloat { + value: v, + abs_err: 2f64.pow(-(self.value.get_precision() as f64)), + } + } else { + ErrorPropagatingFloat { + value: v, + abs_err: 2f64.pow(-(self.value.get_precision() as f64)) * r, + } } } @@ -2318,7 +2418,7 @@ impl NumericalFloatLike for ErrorPropagatingFloat { } fn get_epsilon(&self) -> f64 { - 2.0f64.powi(-(self.get_precision() as i32)) + 2.0f64.powi(-(self.value.get_precision() as i32)) } #[inline(always)] @@ -2327,9 +2427,10 @@ impl NumericalFloatLike for ErrorPropagatingFloat { } fn sample_unit(&self, rng: &mut R) -> Self { + let v = self.value.sample_unit(rng); ErrorPropagatingFloat { - value: self.value.sample_unit(rng), - prec: self.prec, + abs_err: self.abs_err * v.to_f64().abs(), + value: v, } } } @@ -2348,9 +2449,16 @@ impl SingleFloat for ErrorPropagatingFloat { } fn from_rational(&self, rat: &Rational) -> Self { - ErrorPropagatingFloat { - value: self.value.from_rational(rat), - prec: self.prec, + if rat.is_zero() { + ErrorPropagatingFloat { + value: self.value.from_rational(rat), + abs_err: self.abs_err, + } + } else { + ErrorPropagatingFloat { + value: self.value.from_rational(rat), + abs_err: self.abs_err * rat.to_f64(), + } } } } @@ -2372,45 +2480,59 @@ impl RealNumberLike for ErrorPropagatingFloat { impl Real for ErrorPropagatingFloat { fn pi(&self) -> Self { + let v = self.value.pi(); ErrorPropagatingFloat { - value: self.value.pi(), - prec: self.prec, + abs_err: 2f64.pow(-(self.value.get_precision() as f64)) * v.to_f64(), + value: v, } } fn e(&self) -> Self { + let v = self.value.e(); ErrorPropagatingFloat { - value: self.value.e(), - prec: self.prec, + abs_err: 2f64.pow(-(self.value.get_precision() as f64)) * v.to_f64(), + value: v, } } fn euler(&self) -> Self { + let v = self.value.euler(); ErrorPropagatingFloat { - value: self.value.euler(), - prec: self.prec, + abs_err: 2f64.pow(-(self.value.get_precision() as f64)) * v.to_f64(), + value: v, } } fn phi(&self) -> Self { + let v = self.value.phi(); ErrorPropagatingFloat { - value: self.value.phi(), - prec: self.prec, + abs_err: 2f64.pow(-(self.value.get_precision() as f64)) * v.to_f64(), + value: v, } } + #[inline(always)] + fn i(&self) -> Option { + Some(ErrorPropagatingFloat { + value: self.value.i()?, + abs_err: 2f64.pow(-(self.value.get_precision() as f64)), + }) + } + fn norm(&self) -> Self { ErrorPropagatingFloat { + abs_err: self.abs_err, value: self.value.norm(), - prec: self.prec, } - .truncate() } fn sqrt(&self) -> Self { + let v = self.value.sqrt(); + let r = v.to_f64().abs(); + ErrorPropagatingFloat { - value: self.value.sqrt(), - prec: self.prec / 2., + abs_err: self.abs_err / (2. * r), + value: v, } .truncate() } @@ -2418,23 +2540,24 @@ impl Real for ErrorPropagatingFloat { fn log(&self) -> Self { let r = self.value.log(); ErrorPropagatingFloat { - prec: self.prec / r.clone().to_f64().abs(), + abs_err: self.abs_err / self.value.to_f64().abs(), value: r, } .truncate() } fn exp(&self) -> Self { + let v = self.value.exp(); ErrorPropagatingFloat { - prec: self.value.to_f64().abs() * self.prec, - value: self.value.exp(), + abs_err: v.to_f64().abs() * self.abs_err, + value: v, } .truncate() } fn sin(&self) -> Self { ErrorPropagatingFloat { - prec: self.prec * self.value.to_f64().abs() / self.value.tan().to_f64().abs(), + abs_err: self.abs_err * self.value.to_f64().cos().abs(), value: self.value.sin(), } .truncate() @@ -2442,7 +2565,7 @@ impl Real for ErrorPropagatingFloat { fn cos(&self) -> Self { ErrorPropagatingFloat { - prec: self.prec * self.value.to_f64().abs() * self.value.tan().to_f64().abs(), + abs_err: self.abs_err * self.value.to_f64().sin().abs(), value: self.value.cos(), } .truncate() @@ -2451,8 +2574,9 @@ impl Real for ErrorPropagatingFloat { fn tan(&self) -> Self { let t = self.value.tan(); let tt = t.to_f64().abs(); + ErrorPropagatingFloat { - prec: self.prec * self.value.to_f64().abs() * (tt.inv() + tt), + abs_err: self.abs_err * (1. + tt * tt), value: t, } .truncate() @@ -2461,9 +2585,9 @@ impl Real for ErrorPropagatingFloat { fn asin(&self) -> Self { let v = self.value.to_f64(); let t = self.value.asin(); - let tt = (1. - v * v).sqrt() * t.to_f64().abs(); + let tt = (1. - v * v).sqrt(); ErrorPropagatingFloat { - prec: self.prec * v.abs() / tt, + abs_err: self.abs_err / tt, value: t, } .truncate() @@ -2472,9 +2596,9 @@ impl Real for ErrorPropagatingFloat { fn acos(&self) -> Self { let v = self.value.to_f64(); let t = self.value.acos(); - let tt = (1. - v * v).sqrt() * t.to_f64().abs(); + let tt = (1. - v * v).sqrt(); ErrorPropagatingFloat { - prec: self.prec * v.abs() / tt, + abs_err: self.abs_err / tt, value: t, } .truncate() @@ -2485,9 +2609,9 @@ impl Real for ErrorPropagatingFloat { let r = self.clone() / x; let r2 = r.value.to_f64().abs(); - let tt = (1. + r2 * r2) * t.clone().to_f64().abs(); + let tt = 1. + r2 * r2; ErrorPropagatingFloat { - prec: r.prec * r2 / tt, + abs_err: r.abs_err / tt, value: t, } .truncate() @@ -2495,7 +2619,7 @@ impl Real for ErrorPropagatingFloat { fn sinh(&self) -> Self { ErrorPropagatingFloat { - prec: self.prec * self.value.to_f64().abs() / self.value.tanh().to_f64().abs(), + abs_err: self.abs_err * self.value.cosh().to_f64().abs(), value: self.value.sinh(), } .truncate() @@ -2503,7 +2627,7 @@ impl Real for ErrorPropagatingFloat { fn cosh(&self) -> Self { ErrorPropagatingFloat { - prec: self.prec * self.value.to_f64().abs() * self.value.tanh().to_f64().abs(), + abs_err: self.abs_err * self.value.sinh().to_f64().abs(), value: self.value.cosh(), } .truncate() @@ -2513,7 +2637,7 @@ impl Real for ErrorPropagatingFloat { let t = self.value.tanh(); let tt = t.clone().to_f64().abs(); ErrorPropagatingFloat { - prec: self.prec * self.value.to_f64().abs() * (tt.inv() - tt), + abs_err: self.abs_err * (1. - tt * tt), value: t, } .truncate() @@ -2522,9 +2646,9 @@ impl Real for ErrorPropagatingFloat { fn asinh(&self) -> Self { let v = self.value.to_f64(); let t = self.value.asinh(); - let tt = (1. + v * v).sqrt() * t.to_f64().abs(); + let tt = (1. + v * v).sqrt(); ErrorPropagatingFloat { - prec: self.prec * v.abs() / tt, + abs_err: self.abs_err / tt, value: t, } .truncate() @@ -2533,9 +2657,9 @@ impl Real for ErrorPropagatingFloat { fn acosh(&self) -> Self { let v = self.value.to_f64(); let t = self.value.acosh(); - let tt = (v * v - 1.).sqrt() * t.to_f64().abs(); + let tt = (v * v - 1.).sqrt(); ErrorPropagatingFloat { - prec: self.prec * v.abs() / tt, + abs_err: self.abs_err / tt, value: t, } .truncate() @@ -2544,19 +2668,30 @@ impl Real for ErrorPropagatingFloat { fn atanh(&self) -> Self { let v = self.value.to_f64(); let t = self.value.atanh(); - let tt = (1. - v * v) * t.to_f64().abs(); + let tt = 1. - v * v; ErrorPropagatingFloat { - prec: self.prec * v.abs() / tt, + abs_err: self.abs_err / tt, value: t, } .truncate() } fn powf(&self, e: &Self) -> Self { - let v = self.value.to_f64().abs(); + let i = self.to_f64().abs(); + + if i == 0. { + return ErrorPropagatingFloat { + value: self.value.powf(&e.value), + abs_err: 0., + }; + } + + let r = self.value.powf(&e.value); ErrorPropagatingFloat { - value: self.value.powf(&e.value), - prec: (self.prec + e.prec * v.ln().abs()) * e.value.clone().to_f64().abs(), + abs_err: (self.abs_err * e.value.to_f64() + i * e.abs_err * i.ln().abs()) + * r.to_f64().abs() + / i, + value: r, } .truncate() } @@ -2653,6 +2788,11 @@ macro_rules! simd_impl { 1.6180339887498948.into() } + #[inline(always)] + fn i(&self) -> Option { + None + } + #[inline(always)] fn norm(&self) -> Self { (*self).abs() @@ -3603,6 +3743,11 @@ impl Real for Complex { Complex::new(self.re.phi(), self.im.zero()) } + #[inline(always)] + fn i(&self) -> Option { + Some(self.i()) + } + #[inline] fn norm(&self) -> Self { Complex::new(self.norm_squared().sqrt(), self.im.zero()) @@ -3784,7 +3929,7 @@ mod test { + b.powf(&a); assert_eq!(r.value, 17293.219725825093); // error is 14.836811363436391 when the f64 could have theoretically grown in between - assert_eq!(r.get_precision(), 14.836795991431746); + assert_eq!(r.get_precision(), Some(14.836795991431746)); } #[test] @@ -3792,16 +3937,15 @@ mod test { let a = ErrorPropagatingFloat::new(0.0000000123456789, 9.) .exp() .log(); - assert_eq!(a.get_precision(), 8.046104745509947); + assert_eq!(a.get_precision(), Some(8.046104745509947)); } #[test] fn large_cancellation() { let a = ErrorPropagatingFloat::new(Float::with_val(200, 1e-50), 60.); let r = (a.exp() - a.one()) / a; - println!("{}", r.value.prec()); - assert_eq!(format!("{}", r), "1.000000000e0"); - assert_eq!(r.get_precision(), 10.205999132807323); + assert_eq!(format!("{}", r), "1.000000000"); + assert_eq!(r.get_precision(), Some(10.205999132796238)); } #[test] diff --git a/src/evaluate.rs b/src/evaluate.rs index a2687f71..ba176cc2 100644 --- a/src/evaluate.rs +++ b/src/evaluate.rs @@ -718,11 +718,11 @@ impl ExpressionEvaluator { self.stack[*r] = self.stack[*b].powf(&self.stack[*e]); } Instr::BuiltinFun(r, s, arg) => match s.0 { - State::EXP => self.stack[*r] = self.stack[*arg].exp(), - State::LOG => self.stack[*r] = self.stack[*arg].log(), - State::SIN => self.stack[*r] = self.stack[*arg].sin(), - State::COS => self.stack[*r] = self.stack[*arg].cos(), - State::SQRT => self.stack[*r] = self.stack[*arg].sqrt(), + Atom::EXP => self.stack[*r] = self.stack[*arg].exp(), + Atom::LOG => self.stack[*r] = self.stack[*arg].log(), + Atom::SIN => self.stack[*r] = self.stack[*arg].sin(), + Atom::COS => self.stack[*r] = self.stack[*arg].cos(), + Atom::SQRT => self.stack[*r] = self.stack[*arg].sqrt(), _ => unreachable!(), }, } @@ -1359,23 +1359,23 @@ impl ExpressionEvaluator { *out += format!("\tZ{} = pow({}, {});\n", o, base, exp).as_str(); } Instr::BuiltinFun(o, s, a) => match s.0 { - State::EXP => { + Atom::EXP => { let arg = format!("Z{}", a); *out += format!("\tZ{} = exp({});\n", o, arg).as_str(); } - State::LOG => { + Atom::LOG => { let arg = format!("Z{}", a); *out += format!("\tZ{} = log({});\n", o, arg).as_str(); } - State::SIN => { + Atom::SIN => { let arg = format!("Z{}", a); *out += format!("\tZ{} = sin({});\n", o, arg).as_str(); } - State::COS => { + Atom::COS => { let arg = format!("Z{}", a); *out += format!("\tZ{} = cos({});\n", o, arg).as_str(); } - State::SQRT => { + Atom::SQRT => { let arg = format!("Z{}", a); *out += format!("\tZ{} = sqrt({});\n", o, arg).as_str(); } @@ -1931,19 +1931,19 @@ impl ExpressionEvaluator { let arg = get_input!(*a); match s.0 { - State::EXP => { + Atom::EXP => { *out += format!("\tZ[{}] = exp({});\n", o, arg).as_str(); } - State::LOG => { + Atom::LOG => { *out += format!("\tZ[{}] = log({});\n", o, arg).as_str(); } - State::SIN => { + Atom::SIN => { *out += format!("\tZ[{}] = sin({});\n", o, arg).as_str(); } - State::COS => { + Atom::COS => { *out += format!("\tZ[{}] = cos({});\n", o, arg).as_str(); } - State::SQRT => { + Atom::SQRT => { *out += format!("\tZ[{}] = sqrt({});\n", o, arg).as_str(); } _ => unreachable!(), @@ -2122,19 +2122,19 @@ impl ExpressionEvaluator { let arg = get_input!(*a); match s.0 { - State::EXP => { + Atom::EXP => { *out += format!("\tZ[{}] = exp({});\n", o, arg).as_str(); } - State::LOG => { + Atom::LOG => { *out += format!("\tZ[{}] = log({});\n", o, arg).as_str(); } - State::SIN => { + Atom::SIN => { *out += format!("\tZ[{}] = sin({});\n", o, arg).as_str(); } - State::COS => { + Atom::COS => { *out += format!("\tZ[{}] = cos({});\n", o, arg).as_str(); } - State::SQRT => { + Atom::SQRT => { *out += format!("\tZ[{}] = sqrt({});\n", o, arg).as_str(); } _ => unreachable!(), @@ -3383,11 +3383,11 @@ impl EvalTree { Expression::BuiltinFun(s, a) => { let arg = self.evaluate_impl(a, subexpressions, params, args); match s.0 { - State::EXP => arg.exp(), - State::LOG => arg.log(), - State::SIN => arg.sin(), - State::COS => arg.cos(), - State::SQRT => arg.sqrt(), + Atom::EXP => arg.exp(), + Atom::LOG => arg.log(), + Atom::SIN => arg.sin(), + Atom::COS => arg.cos(), + Atom::SQRT => arg.sqrt(), _ => unreachable!(), } } @@ -3815,31 +3815,31 @@ impl EvalTree { } Expression::ReadArg(s) => args[*s].to_string(), Expression::BuiltinFun(s, a) => match s.0 { - State::EXP => { + Atom::EXP => { let mut r = "exp(".to_string(); r += &self.export_cpp_impl(a, args); r.push(')'); r } - State::LOG => { + Atom::LOG => { let mut r = "log(".to_string(); r += &self.export_cpp_impl(a, args); r.push(')'); r } - State::SIN => { + Atom::SIN => { let mut r = "sin(".to_string(); r += &self.export_cpp_impl(a, args); r.push(')'); r } - State::COS => { + Atom::COS => { let mut r = "cos(".to_string(); r += &self.export_cpp_impl(a, args); r.push(')'); r } - State::SQRT => { + Atom::SQRT => { let mut r = "sqrt(".to_string(); r += &self.export_cpp_impl(a, args); r.push(')'); @@ -3930,7 +3930,7 @@ impl<'a> AtomView<'a> { } AtomView::Fun(f) => { let name = f.get_symbol(); - if [State::EXP, State::LOG, State::SIN, State::COS, State::SQRT].contains(&name) { + if [Atom::EXP, Atom::LOG, Atom::SIN, Atom::COS, Atom::SQRT].contains(&name) { assert!(f.get_nargs() == 1); let arg = f.iter().next().unwrap(); let arg_eval = arg.to_eval_tree_impl(fn_map, params, args, funcs)?; @@ -4065,8 +4065,11 @@ impl<'a> AtomView<'a> { ), }, AtomView::Var(v) => match v.get_symbol() { - State::E => Ok(coeff_map(&1.into()).e()), - State::PI => Ok(coeff_map(&1.into()).pi()), + Atom::E => Ok(coeff_map(&1.into()).e()), + Atom::PI => Ok(coeff_map(&1.into()).pi()), + Atom::I => coeff_map(&1.into()) + .i() + .ok_or_else(|| "Numerical type does not support imaginary unit".to_string()), _ => Err(format!( "Variable {} not in constant map", State::get_name(v.get_symbol()) @@ -4074,17 +4077,17 @@ impl<'a> AtomView<'a> { }, AtomView::Fun(f) => { let name = f.get_symbol(); - if [State::EXP, State::LOG, State::SIN, State::COS, State::SQRT].contains(&name) { + if [Atom::EXP, Atom::LOG, Atom::SIN, Atom::COS, Atom::SQRT].contains(&name) { assert!(f.get_nargs() == 1); let arg = f.iter().next().unwrap(); let arg_eval = arg.evaluate(coeff_map, const_map, function_map, cache)?; return Ok(match f.get_symbol() { - State::EXP => arg_eval.exp(), - State::LOG => arg_eval.log(), - State::SIN => arg_eval.sin(), - State::COS => arg_eval.cos(), - State::SQRT => arg_eval.sqrt(), + Atom::EXP => arg_eval.exp(), + Atom::LOG => arg_eval.log(), + Atom::SIN => arg_eval.sin(), + Atom::COS => arg_eval.cos(), + Atom::SQRT => arg_eval.sqrt(), _ => unreachable!(), }); } diff --git a/src/id.rs b/src/id.rs index c22edbd7..b00795f0 100644 --- a/src/id.rs +++ b/src/id.rs @@ -8,7 +8,7 @@ use crate::{ representation::{InlineVar, ListSlice}, AsAtomView, Atom, AtomType, AtomView, Num, SliceType, Symbol, }, - state::{State, Workspace}, + state::Workspace, transformer::{Transformer, TransformerError}, }; @@ -2109,7 +2109,7 @@ impl<'a> Match<'a> { // to update the coefficient flag } SliceType::Arg => { - let fun = out.to_fun(State::ARG); + let fun = out.to_fun(Atom::ARG); for arg in wargs { fun.add_arg(*arg); } @@ -2124,7 +2124,7 @@ impl<'a> Match<'a> { out.set_from_view(&wargs[0]); } SliceType::Empty => { - let f = out.to_fun(State::ARG); + let f = out.to_fun(Atom::ARG); f.set_normalized(true); } }, diff --git a/src/normalize.rs b/src/normalize.rs index 4961981a..cfb2653c 100644 --- a/src/normalize.rs +++ b/src/normalize.rs @@ -275,7 +275,7 @@ impl<'a> AtomView<'a> { /// Simplify logs in the argument of the exponential function. fn simplify_exp_log(&self, ws: &Workspace, out: &mut Atom) -> bool { if let AtomView::Fun(f) = self { - if f.get_symbol() == State::LOG && f.get_nargs() == 1 { + if f.get_symbol() == Atom::LOG && f.get_nargs() == 1 { out.set_from_view(&f.iter().next().unwrap()); return true; } @@ -333,7 +333,7 @@ impl<'a> AtomView<'a> { if changed { let mut new_exp = ws.new_atom(); // TODO: change to e^() - new_exp.to_fun(State::EXP).add_arg(aa.as_view()); + new_exp.to_fun(Atom::EXP).add_arg(aa.as_view()); m.extend(new_exp.as_view()); @@ -489,7 +489,7 @@ impl Atom { // x * x => x^2 if self.as_view() == other.as_view() { if let AtomView::Var(v) = self.as_view() { - if v.get_symbol() == State::I { + if v.get_symbol() == Atom::I { self.to_num((-1).into()); return true; } @@ -821,7 +821,7 @@ impl<'a> AtomView<'a> { #[inline(always)] fn add_arg(f: &mut Fun, a: AtomView) { if let AtomView::Fun(fa) = a { - if fa.get_symbol() == State::ARG { + if fa.get_symbol() == Atom::ARG { // flatten f(arg(...)) = f(...) for aa in fa.iter() { f.add_arg(aa); @@ -872,16 +872,16 @@ impl<'a> AtomView<'a> { out_f.set_normalized(true); - if [State::COS, State::SIN, State::EXP, State::LOG].contains(&id) + if [Atom::COS, Atom::SIN, Atom::EXP, Atom::LOG].contains(&id) && out_f.to_fun_view().get_nargs() == 1 { let arg = out_f.to_fun_view().iter().next().unwrap(); if let AtomView::Num(n) = arg { - if n.is_zero() && id != State::LOG || n.is_one() && id == State::LOG { - if id == State::COS || id == State::EXP { + if n.is_zero() && id != Atom::LOG || n.is_one() && id == Atom::LOG { + if id == Atom::COS || id == Atom::EXP { out.to_num(Coefficient::one()); return; - } else if id == State::SIN || id == State::LOG { + } else if id == Atom::SIN || id == Atom::LOG { out.to_num(Coefficient::zero()); return; } @@ -889,22 +889,22 @@ impl<'a> AtomView<'a> { if let CoefficientView::Float(f) = n.get_coeff_view() { match id { - State::COS => { + Atom::COS => { let r = f.to_float().cos(); out.to_num(Coefficient::Float(r)); return; } - State::SIN => { + Atom::SIN => { let r = f.to_float().sin(); out.to_num(Coefficient::Float(r)); return; } - State::EXP => { + Atom::EXP => { let r = f.to_float().exp(); out.to_num(Coefficient::Float(r)); return; } - State::LOG => { + Atom::LOG => { let r = f.to_float().log(); out.to_num(Coefficient::Float(r)); return; @@ -915,10 +915,10 @@ impl<'a> AtomView<'a> { } } - if id == State::EXP && out_f.to_fun_view().get_nargs() == 1 { + if id == Atom::EXP && out_f.to_fun_view().get_nargs() == 1 { let arg = out_f.to_fun_view().iter().next().unwrap(); // simplify logs inside exp - if arg.contains_symbol(State::LOG) { + if arg.contains_symbol(Atom::LOG) { let mut buffer = workspace.new_atom(); if arg.simplify_exp_log(workspace, &mut buffer) { out.set_from_view(&buffer.as_view()); @@ -928,7 +928,7 @@ impl<'a> AtomView<'a> { } // try to turn the argument into a number - if id == State::COEFF && out_f.to_fun_view().get_nargs() == 1 { + if id == Atom::COEFF && out_f.to_fun_view().get_nargs() == 1 { let arg = out_f.to_fun_view().iter().next().unwrap(); if let AtomView::Num(_) = arg { let mut buffer = workspace.new_atom(); @@ -1197,7 +1197,7 @@ impl<'a> AtomView<'a> { base_handle.to_num(new_base_num); exp_handle.to_num(new_exp_num); } else if let AtomView::Var(v) = base_handle.as_view() { - if v.get_symbol() == State::I { + if v.get_symbol() == Atom::I { if let CoefficientView::Natural(n, d) = exp_num { let mut new_base = workspace.new_atom(); @@ -1533,7 +1533,7 @@ impl<'a> AtomView<'a> { #[cfg(test)] mod test { - use crate::{atom::Atom, state::State}; + use crate::atom::Atom; #[test] fn pow_apart() { @@ -1564,8 +1564,8 @@ mod test { #[test] fn mul_complex_i() { - let res = Atom::new_var(State::I) * &Atom::new_var(State::E) * &Atom::new_var(State::I); - let refr = -Atom::new_var(State::E); + let res = Atom::new_var(Atom::I) * &Atom::new_var(Atom::E) * &Atom::new_var(Atom::I); + let refr = -Atom::new_var(Atom::E); assert_eq!(res, refr); } diff --git a/src/poly/evaluate.rs b/src/poly/evaluate.rs index dd7f15d4..e9388656 100644 --- a/src/poly/evaluate.rs +++ b/src/poly/evaluate.rs @@ -20,7 +20,6 @@ use crate::{ atom::{Atom, AtomView}, domains::{float::Real, Ring}, evaluate::EvaluationFn, - state::State, }; use super::{polynomial::MultivariatePolynomial, PositiveExponent}; @@ -1678,7 +1677,7 @@ impl<'a> std::fmt::Display for InstructionSetPrinter<'a> { None } } else if let super::Variable::Symbol(i) = x { - if [State::E, State::I, State::PI].contains(i) { + if [Atom::E, Atom::I, Atom::PI].contains(i) { None } else { Some(format!("T {}", x.to_string())) @@ -1855,7 +1854,7 @@ impl ExpressionEvaluator { None } } else if let super::Variable::Symbol(i) = x { - if [State::E, State::I, State::PI].contains(i) { + if [Atom::E, Atom::I, Atom::PI].contains(i) { None } else { Some(x.clone()) diff --git a/src/poly/series.rs b/src/poly/series.rs index f23f17de..b161ca4a 100644 --- a/src/poly/series.rs +++ b/src/poly/series.rs @@ -15,7 +15,6 @@ use crate::{ EuclideanDomain, InternalOrdering, Ring, SelfRing, }, printer::{PrintOptions, PrintState}, - state::State, }; use super::Variable; @@ -923,7 +922,7 @@ impl Series { }; // construct the constant term, log(x) in the argument will be turned into x - let e = FunctionBuilder::new(State::EXP).add_arg(&c).finish(); + let e = FunctionBuilder::new(Atom::EXP).add_arg(&c).finish(); // split the true constant part and the x-dependent part let var = self.variable.to_atom() - &self.expansion_point; @@ -961,7 +960,7 @@ impl Series { .mul_exp_units(-self.shift) - self.one(); - let mut e = self.constant(FunctionBuilder::new(State::LOG).add_arg(&c).finish()); + let mut e = self.constant(FunctionBuilder::new(Atom::LOG).add_arg(&c).finish()); let mut sp = p.clone(); for i in 1..=self.order { let s = sp.clone().div_coeff(&Atom::new_num(i as i64)); @@ -1007,13 +1006,13 @@ impl Series { let p = self.clone().remove_constant(); - let mut e = self.constant(FunctionBuilder::new(State::SIN).add_arg(&c).finish()); + let mut e = self.constant(FunctionBuilder::new(Atom::SIN).add_arg(&c).finish()); let mut sp = p.clone(); for i in 1..=self.order { let mut b = if i % 2 == 1 { - FunctionBuilder::new(State::COS).add_arg(&c).finish() + FunctionBuilder::new(Atom::COS).add_arg(&c).finish() } else { - FunctionBuilder::new(State::SIN).add_arg(&c).finish() + FunctionBuilder::new(Atom::SIN).add_arg(&c).finish() }; if i % 4 >= 2 { @@ -1063,13 +1062,13 @@ impl Series { let p = self.clone().remove_constant(); - let mut e = self.constant(FunctionBuilder::new(State::COS).add_arg(&c).finish()); + let mut e = self.constant(FunctionBuilder::new(Atom::COS).add_arg(&c).finish()); let mut sp = p.clone(); for i in 1..=self.order { let mut b = if i % 2 == 1 { - FunctionBuilder::new(State::SIN).add_arg(&c).finish() + FunctionBuilder::new(Atom::SIN).add_arg(&c).finish() } else { - -FunctionBuilder::new(State::COS).add_arg(&c).finish() + -FunctionBuilder::new(Atom::COS).add_arg(&c).finish() }; if i % 4 < 2 { diff --git a/src/printer.rs b/src/printer.rs index ff4e788e..42c34a00 100644 --- a/src/printer.rs +++ b/src/printer.rs @@ -459,9 +459,9 @@ impl<'a> FormattedPrintVar for VarView<'a> { if opts.latex { match id { - State::E => f.write_char('e'), - State::PI => f.write_str("\\pi"), - State::I => f.write_char('i'), + Atom::E => f.write_char('e'), + Atom::PI => f.write_str("\\pi"), + Atom::I => f.write_char('i'), _ => f.write_str(name), } } else if opts.color_builtin_symbols && name.ends_with('_') { diff --git a/src/state.rs b/src/state.rs index 547483b0..0ac508f1 100644 --- a/src/state.rs +++ b/src/state.rs @@ -88,17 +88,17 @@ impl Default for State { } impl State { - pub const ARG: Symbol = Symbol::init_fn(0, 0, false, false, false, false); - pub const COEFF: Symbol = Symbol::init_fn(1, 0, false, false, false, false); - pub const EXP: Symbol = Symbol::init_fn(2, 0, false, false, false, false); - pub const LOG: Symbol = Symbol::init_fn(3, 0, false, false, false, false); - pub const SIN: Symbol = Symbol::init_fn(4, 0, false, false, false, false); - pub const COS: Symbol = Symbol::init_fn(5, 0, false, false, false, false); - pub const SQRT: Symbol = Symbol::init_fn(6, 0, false, false, false, false); - pub const DERIVATIVE: Symbol = Symbol::init_fn(7, 0, false, false, false, false); - pub const E: Symbol = Symbol::init_var(8, 0); - pub const I: Symbol = Symbol::init_var(9, 0); - pub const PI: Symbol = Symbol::init_var(10, 0); + pub(crate) const ARG: Symbol = Symbol::init_fn(0, 0, false, false, false, false); + pub(crate) const COEFF: Symbol = Symbol::init_fn(1, 0, false, false, false, false); + pub(crate) const EXP: Symbol = Symbol::init_fn(2, 0, false, false, false, false); + pub(crate) const LOG: Symbol = Symbol::init_fn(3, 0, false, false, false, false); + pub(crate) const SIN: Symbol = Symbol::init_fn(4, 0, false, false, false, false); + pub(crate) const COS: Symbol = Symbol::init_fn(5, 0, false, false, false, false); + pub(crate) const SQRT: Symbol = Symbol::init_fn(6, 0, false, false, false, false); + pub(crate) const DERIVATIVE: Symbol = Symbol::init_fn(7, 0, false, false, false, false); + pub(crate) const E: Symbol = Symbol::init_var(8, 0); + pub(crate) const I: Symbol = Symbol::init_var(9, 0); + pub(crate) const PI: Symbol = Symbol::init_var(10, 0); pub const BUILTIN_VAR_LIST: [&'static str; 11] = [ "arg", "coeff", "exp", "log", "sin", "cos", "sqrt", "der", "𝑒", "𝑖", "𝜋", @@ -888,7 +888,7 @@ mod tests { if f.get_nargs() == 1 { let arg = f.iter().next().unwrap(); if let AtomView::Fun(f2) = arg { - if f2.get_symbol() == State::EXP { + if f2.get_symbol() == Atom::EXP { if f2.get_nargs() == 1 { out.set_from_view(&f2.iter().next().unwrap()); return true; diff --git a/src/transformer.rs b/src/transformer.rs index e6c29797..f9ec3204 100644 --- a/src/transformer.rs +++ b/src/transformer.rs @@ -10,7 +10,7 @@ use crate::{ Replacement, }, printer::{AtomPrinter, PrintOptions}, - state::{RecycledAtom, State, Workspace}, + state::{RecycledAtom, Workspace}, }; use ahash::HashMap; use colored::Colorize; @@ -238,7 +238,7 @@ impl FunView<'_> { #[inline(always)] fn add_arg(f: &mut Fun, a: AtomView) { if let AtomView::Fun(fa) = a { - if fa.get_symbol() == State::ARG { + if fa.get_symbol() == Atom::ARG { // flatten f(arg(...)) = f(...) for aa in fa.iter() { f.add_arg(aa); @@ -502,9 +502,9 @@ impl Transformer { } Transformer::ForEach(t) => { if let AtomView::Fun(f) = cur_input { - if f.get_symbol() == State::ARG { + if f.get_symbol() == Atom::ARG { let mut ff = workspace.new_atom(); - let ff = ff.to_fun(State::ARG); + let ff = ff.to_fun(Atom::ARG); let mut a = workspace.new_atom(); for arg in f { @@ -598,7 +598,7 @@ impl Transformer { } Transformer::Product => { if let AtomView::Fun(f) = cur_input { - if f.get_symbol() == State::ARG { + if f.get_symbol() == Atom::ARG { let mut mul_h = workspace.new_atom(); let mul = mul_h.to_mul(); @@ -615,7 +615,7 @@ impl Transformer { } Transformer::Sum => { if let AtomView::Fun(f) = cur_input { - if f.get_symbol() == State::ARG { + if f.get_symbol() == Atom::ARG { let mut add_h = workspace.new_atom(); let add = add_h.to_add(); @@ -632,7 +632,7 @@ impl Transformer { } Transformer::ArgCount(only_for_arg_fun) => { if let AtomView::Fun(f) = cur_input { - if !*only_for_arg_fun || f.get_symbol() == State::ARG { + if !*only_for_arg_fun || f.get_symbol() == Atom::ARG { let n_args = f.get_nargs(); out.to_num((n_args as i64).into()); } else { @@ -654,7 +654,7 @@ impl Transformer { Transformer::Split => match cur_input { AtomView::Mul(m) => { let mut arg_h = workspace.new_atom(); - let arg = arg_h.to_fun(State::ARG); + let arg = arg_h.to_fun(Atom::ARG); for factor in m { arg.add_arg(factor); @@ -664,7 +664,7 @@ impl Transformer { } AtomView::Add(a) => { let mut arg_h = workspace.new_atom(); - let arg = arg_h.to_fun(State::ARG); + let arg = arg_h.to_fun(Atom::ARG); for summand in a { arg.add_arg(summand); @@ -678,7 +678,7 @@ impl Transformer { }, Transformer::Partition(bins, fill_last, repeat) => { if let AtomView::Fun(f) = cur_input { - if f.get_symbol() == State::ARG { + if f.get_symbol() == Atom::ARG { let args: Vec<_> = f.iter().collect(); let mut sum_h = workspace.new_atom(); @@ -721,12 +721,12 @@ impl Transformer { } Transformer::Sort => { if let AtomView::Fun(f) = cur_input { - if f.get_symbol() == State::ARG { + if f.get_symbol() == Atom::ARG { let mut args: Vec<_> = f.iter().collect(); args.sort(); let mut fun_h = workspace.new_atom(); - let fun = fun_h.to_fun(State::ARG); + let fun = fun_h.to_fun(Atom::ARG); for arg in args { fun.add_arg(arg); @@ -774,7 +774,7 @@ impl Transformer { } Transformer::Deduplicate => { if let AtomView::Fun(f) = cur_input { - if f.get_symbol() == State::ARG { + if f.get_symbol() == Atom::ARG { let args: Vec<_> = f.iter().collect(); let mut args_dedup: Vec<_> = Vec::with_capacity(args.len()); @@ -786,7 +786,7 @@ impl Transformer { } let mut fun_h = workspace.new_atom(); - let fun = fun_h.to_fun(State::ARG); + let fun = fun_h.to_fun(Atom::ARG); for arg in args_dedup { fun.add_arg(arg); @@ -801,7 +801,7 @@ impl Transformer { } Transformer::Permutations(f_name) => { if let AtomView::Fun(f) = cur_input { - if f.get_symbol() == State::ARG { + if f.get_symbol() == Atom::ARG { let args: Vec<_> = f.iter().collect(); let mut sum_h = workspace.new_atom(); From fe10bf2dff349a8c739f5b7b25fc4179ab73f14d Mon Sep 17 00:00:00 2001 From: Ben Ruijl Date: Thu, 12 Dec 2024 15:37:34 +0100 Subject: [PATCH 2/7] Add zero test for expressions - Add test to see if expression is a polynomial - Add custom normalization for atom fields - Collect numbers from divisions --- src/collect.rs | 37 ++++++++- src/domains/atom.rs | 92 +++++++++++++++++++--- src/domains/float.rs | 16 ++++ src/evaluate.rs | 179 ++++++++++++++++++++++++++++++++++++++++++- src/id.rs | 170 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 480 insertions(+), 14 deletions(-) diff --git a/src/collect.rs b/src/collect.rs index 0e2b913d..cd048c3b 100644 --- a/src/collect.rs +++ b/src/collect.rs @@ -549,7 +549,23 @@ impl<'a> AtomView<'a> { None } } - AtomView::Pow(_) | AtomView::Var(_) | AtomView::Fun(_) => None, + AtomView::Pow(p) => { + let (b, e) = p.get_base_exp(); + if let Ok(e) = i64::try_from(e) { + if let Some(n) = get_num(b) { + if let Coefficient::Rational(r) = n { + if e < 0 { + return Some(r.pow((-e) as u64).inv().into()); + } else { + return Some(r.pow(e as u64).into()); + } + } + } + } + + None + } + AtomView::Var(_) | AtomView::Fun(_) => None, } } @@ -609,6 +625,25 @@ impl<'a> AtomView<'a> { changed } + AtomView::Pow(p) => { + let (b, e) = p.get_base_exp(); + + let mut changed = false; + let mut nb = ws.new_atom(); + changed |= b.collect_num_impl(ws, &mut nb); + let mut ne = ws.new_atom(); + changed |= e.collect_num_impl(ws, &mut ne); + + if !changed { + out.set_from_view(self); + } else { + let mut np = ws.new_atom(); + np.to_pow(nb.as_view(), ne.as_view()); + np.as_view().normalize(ws, out); + } + + changed + } _ => { out.set_from_view(self); false diff --git a/src/domains/atom.rs b/src/domains/atom.rs index 85da7dcd..71384881 100644 --- a/src/domains/atom.rs +++ b/src/domains/atom.rs @@ -7,10 +7,33 @@ use super::{ integer::Integer, Derivable, EuclideanDomain, Field, InternalOrdering, Ring, SelfRing, }; +use dyn_clone::DynClone; use rand::Rng; -#[derive(Clone, Copy, PartialEq, Eq, Hash)] -pub struct AtomField {} +pub trait Map: Fn(AtomView, &mut Atom) -> bool + DynClone + Send + Sync {} +dyn_clone::clone_trait_object!(Map); +impl, &mut Atom) -> bool> Map for T {} + +/// The field of general expressions. +#[derive(Clone)] +pub struct AtomField { + /// Perform a cancellation check of numerators and denominators after a division. + pub cancel_check_on_division: bool, + /// A custom normalization function applied after every operation. + pub custom_normalization: Option>, +} + +impl PartialEq for AtomField { + fn eq(&self, _other: &Self) -> bool { + true + } +} + +impl Eq for AtomField {} + +impl std::hash::Hash for AtomField { + fn hash(&self, _state: &mut H) {} +} impl Default for AtomField { fn default() -> Self { @@ -20,7 +43,34 @@ impl Default for AtomField { impl AtomField { pub fn new() -> AtomField { - AtomField {} + AtomField { + custom_normalization: None, + cancel_check_on_division: false, + } + } + + #[inline(always)] + fn normalize(&self, r: Atom) -> Atom { + if let Some(f) = &self.custom_normalization { + let mut res = Atom::new(); + if f(r.as_view(), &mut res) { + res + } else { + r + } + } else { + r + } + } + + #[inline(always)] + fn normalize_mut(&self, r: &mut Atom) { + if let Some(f) = &self.custom_normalization { + let mut res = Atom::new(); + if f(r.as_view(), &mut res) { + std::mem::swap(r, &mut res); + } + } } } @@ -46,39 +96,44 @@ impl Ring for AtomField { type Element = Atom; fn add(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { - a + b + self.normalize(a + b) } fn sub(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { - a - b + self.normalize(a - b) } fn mul(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { - a * b + self.normalize(a * b) } fn add_assign(&self, a: &mut Self::Element, b: &Self::Element) { *a = &*a + b; + self.normalize_mut(a); } fn sub_assign(&self, a: &mut Self::Element, b: &Self::Element) { *a = &*a - b; + self.normalize_mut(a); } fn mul_assign(&self, a: &mut Self::Element, b: &Self::Element) { *a = self.mul(a, b); + self.normalize_mut(a); } fn add_mul_assign(&self, a: &mut Self::Element, b: &Self::Element, c: &Self::Element) { *a = &*a + self.mul(b, c); + self.normalize_mut(a); } fn sub_mul_assign(&self, a: &mut Self::Element, b: &Self::Element, c: &Self::Element) { *a = &*a - self.mul(b, c); + self.normalize_mut(a); } fn neg(&self, a: &Self::Element) -> Self::Element { - -a + self.normalize(-a) } fn zero(&self) -> Self::Element { @@ -90,11 +145,12 @@ impl Ring for AtomField { } fn pow(&self, b: &Self::Element, e: u64) -> Self::Element { - b.npow(Integer::from(e)) + self.normalize(b.npow(Integer::from(e))) } + /// Check if the result could be 0 using a statistical method. fn is_zero(a: &Self::Element) -> bool { - a.is_zero() + !a.as_view().zero_test(10, f64::EPSILON).is_false() } fn is_one(&self, a: &Self::Element) -> bool { @@ -162,7 +218,7 @@ impl EuclideanDomain for AtomField { } fn quot_rem(&self, a: &Self::Element, b: &Self::Element) -> (Self::Element, Self::Element) { - (a / b, self.zero()) + (self.div(a, b), self.zero()) } fn gcd(&self, _a: &Self::Element, _b: &Self::Element) -> Self::Element { @@ -173,16 +229,28 @@ impl EuclideanDomain for AtomField { impl Field for AtomField { fn div(&self, a: &Self::Element, b: &Self::Element) -> Self::Element { - a / b + let r = a / b; + + self.normalize(if self.cancel_check_on_division { + r.cancel() + } else { + r + }) } fn div_assign(&self, a: &mut Self::Element, b: &Self::Element) { *a = self.div(a, b); + + if self.cancel_check_on_division { + *a = a.cancel(); + } + + self.normalize_mut(a); } fn inv(&self, a: &Self::Element) -> Self::Element { let one = Atom::new_num(1); - self.div(&one, a) + self.normalize(self.div(&one, a)) } } diff --git a/src/domains/float.rs b/src/domains/float.rs index d7fb1e67..c7cea6f2 100644 --- a/src/domains/float.rs +++ b/src/domains/float.rs @@ -1993,6 +1993,22 @@ pub struct ErrorPropagatingFloat { abs_err: f64, } +impl From for ErrorPropagatingFloat { + fn from(value: f64) -> Self { + if value == 0. { + ErrorPropagatingFloat { + value, + abs_err: f64::EPSILON, + } + } else { + ErrorPropagatingFloat { + value, + abs_err: f64::EPSILON * value.abs(), + } + } + } +} + impl Neg for ErrorPropagatingFloat { type Output = Self; diff --git a/src/evaluate.rs b/src/evaluate.rs index ba176cc2..48ad50dd 100644 --- a/src/evaluate.rs +++ b/src/evaluate.rs @@ -17,10 +17,13 @@ use crate::{ coefficient::CoefficientView, combinatorics::unique_permutations, domains::{ - float::{Complex, NumericalFloatLike, Real}, + float::{ + Complex, ErrorPropagatingFloat, NumericalFloatLike, Real, RealNumberLike, SingleFloat, + }, integer::Integer, rational::Rational, }, + id::ConditionResult, state::State, LicenseManager, }; @@ -237,6 +240,12 @@ impl Atom { optimization_settings.verbose, )) } + + /// Check if the expression could be 0, using (potentially) numerical sampling with + /// a given tolerance and number of iterations. + pub fn zero_test(&self, iterations: usize, tolerance: f64) -> ConditionResult { + self.as_view().zero_test(iterations, tolerance) + } } #[derive(Debug, Clone)] @@ -4155,6 +4164,164 @@ impl<'a> AtomView<'a> { } } } + + /// Check if the expression could be 0, using (potentially) numerical sampling with + /// a given tolerance and number of iterations. + pub fn zero_test(&self, iterations: usize, tolerance: f64) -> ConditionResult { + match self { + AtomView::Num(num_view) => { + if num_view.is_zero() { + ConditionResult::True + } else { + ConditionResult::False + } + } + AtomView::Var(_) => ConditionResult::False, + AtomView::Fun(_) => ConditionResult::False, + AtomView::Pow(p) => p.get_base().zero_test(iterations, tolerance), + AtomView::Mul(mul_view) => { + let mut is_zero = ConditionResult::False; + for arg in mul_view { + match arg.zero_test(iterations, tolerance) { + ConditionResult::True => return ConditionResult::True, + ConditionResult::False => {} + ConditionResult::Inconclusive => { + is_zero = ConditionResult::Inconclusive; + } + } + } + + is_zero + } + AtomView::Add(_) => { + // an expanded polynomial is only zero if it is a literal zero + if self.is_polynomial(false, true).is_some() { + ConditionResult::False + } else { + self.zero_test_impl(iterations, tolerance) + } + } + } + } + + fn zero_test_impl(&self, iterations: usize, tolerance: f64) -> ConditionResult { + // collect all variables and functions and fill in random variables + + let mut rng = rand::thread_rng(); + + if self.contains_symbol(State::I) { + let mut vars: HashMap<_, _> = self + .get_all_indeterminates(true) + .into_iter() + .filter_map(|x| { + let s = x.get_symbol().unwrap(); + if !State::is_builtin(s) || s == Atom::DERIVATIVE { + Some((x, Complex::new(0f64.into(), 0f64.into()))) + } else { + None + } + }) + .collect(); + + let mut cache = HashMap::default(); + + for _ in 0..iterations { + cache.clear(); + + for x in vars.values_mut() { + *x = x.sample_unit(&mut rng); + } + + let r = self + .evaluate( + |x| { + Complex::new( + ErrorPropagatingFloat::new( + 0f64.from_rational(x), + -0f64.get_epsilon().log10(), + ), + ErrorPropagatingFloat::new( + 0f64.zero(), + -0f64.get_epsilon().log10(), + ), + ) + }, + &vars, + &HashMap::default(), + &mut cache, + ) + .unwrap(); + + let res_re = r.re.get_num().to_f64(); + let res_im = r.im.get_num().to_f64(); + if res_re.is_finite() + && (res_re - r.re.get_absolute_error() > 0. + || res_re + r.re.get_absolute_error() < 0.) + || res_im.is_finite() + && (res_im - r.im.get_absolute_error() > 0. + || res_im + r.im.get_absolute_error() < 0.) + { + return ConditionResult::False; + } + + if vars.len() == 0 && r.re.get_absolute_error() < tolerance { + return ConditionResult::True; + } + } + + ConditionResult::Inconclusive + } else { + let mut vars: HashMap<_, ErrorPropagatingFloat> = self + .get_all_indeterminates(true) + .into_iter() + .filter_map(|x| { + let s = x.get_symbol().unwrap(); + if !State::is_builtin(s) || s == Atom::DERIVATIVE { + Some((x, 0f64.into())) + } else { + None + } + }) + .collect(); + + let mut cache = HashMap::default(); + + for _ in 0..iterations { + cache.clear(); + + for x in vars.values_mut() { + *x = x.sample_unit(&mut rng); + } + + let r = self + .evaluate( + |x| { + ErrorPropagatingFloat::new( + 0f64.from_rational(x), + -0f64.get_epsilon().log10(), + ) + }, + &vars, + &HashMap::default(), + &mut cache, + ) + .unwrap(); + + let res = r.get_num().to_f64(); + if res.is_finite() + && (res - r.get_absolute_error() > 0. || res + r.get_absolute_error() < 0.) + { + return ConditionResult::False; + } + + if vars.len() == 0 && r.get_absolute_error() < tolerance { + return ConditionResult::True; + } + } + + ConditionResult::Inconclusive + } + } } #[cfg(test)] @@ -4165,6 +4332,7 @@ mod test { atom::Atom, domains::{float::Float, rational::Rational}, evaluate::{EvaluationFn, FunctionMap, OptimizationSettings}, + id::ConditionResult, state::State, }; @@ -4304,4 +4472,13 @@ mod test { let r = e_f64.evaluate_single(&[1.1]); assert!((r - 1622709.2254269677).abs() / 1622709.2254269677 < 1e-10); } + + #[test] + fn zero_test() { + let e = Atom::parse("(sin(v1)^2-sin(v1))(sin(v1)^2+sin(v1))^2 - (1/4 sin(2v1)^2-1/2 sin(2v1)cos(v1)-2 cos(v1)^2+1/2 sin(2v1)cos(v1)^3+3 cos(v1)^4-cos(v1)^6)").unwrap(); + assert_eq!(e.zero_test(10, f64::EPSILON), ConditionResult::Inconclusive); + + let e = Atom::parse("x + (1+x)^2 + (x+2)*5").unwrap(); + assert_eq!(e.zero_test(10, f64::EPSILON), ConditionResult::False); + } } diff --git a/src/id.rs b/src/id.rs index b00795f0..fd92a636 100644 --- a/src/id.rs +++ b/src/id.rs @@ -122,6 +122,21 @@ impl Atom { self.as_view().contains(s.as_atom_view()) } + /// Check if the expression can be considered a polynomial in some variables, including + /// redefinitions. For example `f(x)+y` is considered a polynomial in `f(x)` and `y`, whereas + /// `f(x)+x` is not a polynomial. + /// + /// Rational powers or powers in variables are not rewritten, e.g. `x^(2y)` is not considered + /// polynomial in `x^y`. + pub fn is_polynomial( + &self, + allow_not_expanded: bool, + allow_negative_powers: bool, + ) -> Option>> { + self.as_view() + .is_polynomial(allow_not_expanded, allow_negative_powers) + } + /// Replace all occurrences of the pattern. pub fn replace_all( &self, @@ -344,6 +359,150 @@ impl<'a> AtomView<'a> { false } + /// Check if the expression can be considered a polynomial in some variables, including + /// redefinitions. For example `f(x)+y` is considered a polynomial in `f(x)` and `y`, whereas + /// `f(x)+x` is not a polynomial. + /// + /// Rational powers or powers in variables are not rewritten, e.g. `x^(2y)` is not considered + /// polynomial in `x^y`. + pub fn is_polynomial( + &self, + allow_not_expanded: bool, + allow_negative_powers: bool, + ) -> Option>> { + let mut vars = HashMap::default(); + let mut symbol_cache = HashSet::default(); + if self.is_polynomial_impl( + allow_not_expanded, + allow_negative_powers, + &mut vars, + &mut symbol_cache, + ) { + symbol_cache.clear(); + for (k, v) in vars { + if v { + symbol_cache.insert(k); + } + } + + Some(symbol_cache) + } else { + None + } + } + + fn is_polynomial_impl( + &self, + allow_not_expanded: bool, + allow_negative_powers: bool, + variables: &mut HashMap, bool>, + symbol_cache: &mut HashSet>, + ) -> bool { + if let Some(x) = variables.get(self) { + return *x; + } + + macro_rules! block_check { + ($e: expr) => { + symbol_cache.clear(); + $e.get_all_indeterminates_impl(true, symbol_cache); + for x in symbol_cache.drain() { + if variables.contains_key(&x) { + return false; + } else { + variables.insert(x, false); // disallow at any level + } + } + + variables.insert(*$e, true); // overwrites block above + }; + } + + match self { + AtomView::Num(_) => true, + AtomView::Var(_) => { + variables.insert(*self, true); + true + } + AtomView::Fun(_) => { + block_check!(self); + true + } + AtomView::Pow(pow_view) => { + // x^y is allowed if x and y do not appear elsewhere + let (base, exp) = pow_view.get_base_exp(); + + if let AtomView::Num(_) = exp { + let (positive, integer) = if let Ok(k) = i64::try_from(exp) { + (k >= 0, true) + } else { + (false, false) + }; + + if integer && (allow_negative_powers || positive) { + if variables.get(&base) == Some(&true) { + return true; + } + + if allow_not_expanded && positive { + // do not consider (x+y)^-2 a polynomial in x and y + return base.is_polynomial_impl( + allow_not_expanded, + allow_negative_powers, + variables, + symbol_cache, + ); + } + + // turn the base into a variable + block_check!(&base); + return true; + } + } + + block_check!(self); + true + } + AtomView::Mul(mul_view) => { + for child in mul_view { + if !allow_not_expanded { + if let AtomView::Add(_) = child { + if variables.get(&child) == Some(&true) { + continue; + } + + block_check!(&child); + continue; + } + } + + if !child.is_polynomial_impl( + allow_not_expanded, + allow_negative_powers, + variables, + symbol_cache, + ) { + return false; + } + } + true + } + AtomView::Add(add_view) => { + for child in add_view { + if !child.is_polynomial_impl( + allow_not_expanded, + allow_negative_powers, + variables, + symbol_cache, + ) { + return false; + } + } + true + } + } + } + /// Replace part of an expression by calling the map `m` on each subexpression. /// The function `m` must return `true` if the expression was replaced and must write the new expression to `out`. /// A [Context] object is passed to the function, which contains information about the current position in the expression. @@ -3650,4 +3809,15 @@ mod test { let expr = p.replace_all(expr.as_view(), &rhs, None, None); assert_eq!(expr, Atom::new_num(1)); } + + #[test] + fn is_polynomial() { + let e = Atom::parse("v1^2 + (1+v5)^3 / v1 + (1+v3)*(1+v4)^v7 + v1^2 + (v1+v2)^3").unwrap(); + let vars = e.as_view().is_polynomial(true, true).unwrap(); + assert_eq!(vars.len(), 5); + + let e = Atom::parse("(1+v5)^(3/2) / v6 + (1+v3)*(1+v4)^v7 + (v1+v2)^3").unwrap(); + let vars = e.as_view().is_polynomial(false, false).unwrap(); + assert_eq!(vars.len(), 5); + } } From 70338e7e9903e2dce063e238ff494edc9d4b93b4 Mon Sep 17 00:00:00 2001 From: Ben Ruijl Date: Thu, 12 Dec 2024 17:55:14 +0100 Subject: [PATCH 3/7] Add conversion from system of atoms to matrix - Add matrix augmentation function - Expose Gaussian elimination methods - Improved AsAtomView trait --- examples/collect.rs | 8 +-- examples/solve_linear_system.rs | 13 ++-- src/api/python.rs | 10 +-- src/atom.rs | 47 +++++++----- src/atom/representation.rs | 2 + src/collect.rs | 69 +++++++++--------- src/id.rs | 2 +- src/solve.rs | 122 ++++++++++++++++++++++++++------ src/tensors/matrix.rs | 39 ++++++---- src/transformer.rs | 2 +- 10 files changed, 213 insertions(+), 101 deletions(-) diff --git a/examples/collect.rs b/examples/collect.rs index b1c9a2ef..01d7133c 100644 --- a/examples/collect.rs +++ b/examples/collect.rs @@ -2,11 +2,11 @@ use symbolica::{atom::Atom, fun, state::State}; fn main() { let input = Atom::parse("x*(1+a)+x*5*y+f(5,x)+2+y^2+x^2 + x^3").unwrap(); - let x = State::get_symbol("x").into(); + let x = Atom::new_var(State::get_symbol("x")); let key = State::get_symbol("key"); let coeff = State::get_symbol("val"); - let r = input.coefficient_list::(std::slice::from_ref(&x)); + let r = input.coefficient_list::(std::slice::from_ref(&x)); println!("> Coefficient list:"); for (key, val) in r { @@ -14,7 +14,7 @@ fn main() { } println!("> Collect in x:"); - let out = input.collect::( + let out = input.collect::( &x, Some(Box::new(|x, out| { out.set_from_view(&x); @@ -24,7 +24,7 @@ fn main() { println!("\t{}", out); println!("> Collect in x with wrapping:"); - let out = input.collect::( + let out = input.collect::( &x, Some(Box::new(move |a, out| { out.set_from_view(&a); diff --git a/examples/solve_linear_system.rs b/examples/solve_linear_system.rs index 4ba3992a..b08b5ae9 100644 --- a/examples/solve_linear_system.rs +++ b/examples/solve_linear_system.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use symbolica::{ - atom::{Atom, AtomView}, + atom::{representation::InlineVar, Atom, AtomView}, domains::{ integer::Z, rational::Q, @@ -13,15 +13,14 @@ use symbolica::{ }; fn solve() { - let x = State::get_symbol("x"); - let y = State::get_symbol("y"); - let z = State::get_symbol("z"); + let x = State::get_symbol("x").into(); + let y = State::get_symbol("y").into(); + let z = State::get_symbol("z").into(); let eqs = ["c*x + f(c)*y + z - 1", "x + c*y + z/c - 2", "(c-1)x + c*z"]; - let atoms: Vec<_> = eqs.iter().map(|e| Atom::parse(e).unwrap()).collect(); - let system: Vec<_> = atoms.iter().map(|x| x.as_view()).collect(); + let system: Vec<_> = eqs.iter().map(|e| Atom::parse(e).unwrap()).collect(); - let sol = AtomView::solve_linear_system::(&system, &[x, y, z]).unwrap(); + let sol = AtomView::solve_linear_system::(&system, &[x, y, z]).unwrap(); for (v, s) in ["x", "y", "z"].iter().zip(&sol) { println!("{} = {}", v, s); diff --git a/src/api/python.rs b/src/api/python.rs index 816043a5..b8744dbb 100644 --- a/src/api/python.rs +++ b/src/api/python.rs @@ -3907,7 +3907,7 @@ impl PythonExpression { for a in x { if let Ok(r) = a.extract::() { if matches!(r.expr, Atom::Var(_) | Atom::Fun(_)) { - xs.push(r.expr.into()); + xs.push(r.expr); } else { return Err(exceptions::PyValueError::new_err( "Collect must be done wrt a variable or function", @@ -3920,7 +3920,7 @@ impl PythonExpression { } } - let b = self.expr.collect_multiple::( + let b = self.expr.collect_multiple::( &Arc::new(xs), if let Some(key_map) = key_map { Some(Box::new(move |key, out| { @@ -4038,7 +4038,7 @@ impl PythonExpression { for a in x { if let Ok(r) = a.extract::() { if matches!(r.expr, Atom::Var(_) | Atom::Fun(_)) { - xs.push(r.expr.into()); + xs.push(r.expr); } else { return Err(exceptions::PyValueError::new_err( "Collect must be done wrt a variable or function", @@ -4051,7 +4051,7 @@ impl PythonExpression { } } - let list = self.expr.coefficient_list::(&xs); + let list = self.expr.coefficient_list::(&xs); let py_list: Vec<_> = list .into_iter() @@ -4700,7 +4700,7 @@ impl PythonExpression { } } - let res = AtomView::solve_linear_system::(&system_b, &vars).map_err(|e| { + let res = AtomView::solve_linear_system::(&system_b, &vars).map_err(|e| { exceptions::PyValueError::new_err(format!("Could not solve system: {}", e)) })?; diff --git a/src/atom.rs b/src/atom.rs index 10bb9ebe..76e26e53 100644 --- a/src/atom.rs +++ b/src/atom.rs @@ -171,6 +171,13 @@ impl std::fmt::Display for AtomView<'_> { } } +impl From for Atom { + /// Convert a symbol to an atom. This will allocate memory. + fn from(symbol: Symbol) -> Atom { + Atom::new_var(symbol) + } +} + impl<'a> From> for AtomView<'a> { fn from(n: NumView<'a>) -> AtomView<'a> { AtomView::Num(n) @@ -308,34 +315,40 @@ impl<'a> AtomOrView<'a> { /// A trait for any type that can be converted into an `AtomView`. /// To be used for functions that accept any argument that can be /// converted to an `AtomView`. -pub trait AsAtomView<'a>: Copy + Sized { - fn as_atom_view(self) -> AtomView<'a>; +pub trait AsAtomView { + fn as_atom_view(&self) -> AtomView; } -impl<'a> AsAtomView<'a> for AtomView<'a> { - fn as_atom_view(self) -> AtomView<'a> { - self +impl<'a> AsAtomView for AtomView<'a> { + fn as_atom_view(&self) -> AtomView<'a> { + *self } } -impl<'a> AsAtomView<'a> for &'a InlineVar { - fn as_atom_view(self) -> AtomView<'a> { +impl AsAtomView for InlineVar { + fn as_atom_view(&self) -> AtomView { self.as_view() } } -impl<'a> AsAtomView<'a> for &'a InlineNum { - fn as_atom_view(self) -> AtomView<'a> { +impl AsAtomView for InlineNum { + fn as_atom_view(&self) -> AtomView { self.as_view() } } -impl<'a, T: AsRef> AsAtomView<'a> for &'a T { - fn as_atom_view(self) -> AtomView<'a> { +impl> AsAtomView for T { + fn as_atom_view(&self) -> AtomView { self.as_ref().as_view() } } +impl<'a> AsAtomView for AtomOrView<'a> { + fn as_atom_view(&self) -> AtomView { + self.as_view() + } +} + impl AsRef for Atom { fn as_ref(&self) -> &Atom { self @@ -872,7 +885,7 @@ impl FunctionBuilder { } /// Add an argument to the function. - pub fn add_arg<'b, T: AsAtomView<'b>>(mut self, arg: T) -> FunctionBuilder { + pub fn add_arg(mut self, arg: T) -> FunctionBuilder { if let Atom::Fun(f) = self.handle.deref_mut() { f.add_arg(arg.as_atom_view()); } @@ -881,7 +894,7 @@ impl FunctionBuilder { } /// Add multiple arguments to the function. - pub fn add_args<'b, T: AsAtomView<'b>>(mut self, args: &[T]) -> FunctionBuilder { + pub fn add_args(mut self, args: &[T]) -> FunctionBuilder { if let Atom::Fun(f) = self.handle.deref_mut() { for a in args { f.add_arg(a.as_atom_view()); @@ -1010,7 +1023,7 @@ impl Atom { } /// Take the `self` to the power `exp`. Use [`Atom::npow()`] for a numerical power and [`Atom::rpow()`] for the reverse operation. - pub fn pow<'a, T: AsAtomView<'a>>(&self, exp: T) -> Atom { + pub fn pow(&self, exp: T) -> Atom { Workspace::get_local().with(|ws| { let mut t = ws.new_atom(); self.as_view() @@ -1022,7 +1035,7 @@ impl Atom { } /// Take `base` to the power `self`. - pub fn rpow<'a, T: AsAtomView<'a>>(&self, base: T) -> Atom { + pub fn rpow(&self, base: T) -> Atom { Workspace::get_local().with(|ws| { let mut t = ws.new_atom(); base.as_atom_view() @@ -1034,7 +1047,7 @@ impl Atom { } /// Add the atoms in `args`. - pub fn add_many<'a, T: AsAtomView<'a> + Copy>(args: &[T]) -> Atom { + pub fn add_many<'a, T: AsAtomView + Copy>(args: &[T]) -> Atom { let mut out = Atom::new(); Workspace::get_local().with(|ws| { let mut t = ws.new_atom(); @@ -1049,7 +1062,7 @@ impl Atom { } /// Multiply the atoms in `args`. - pub fn mul_many<'a, T: AsAtomView<'a> + Copy>(args: &[T]) -> Atom { + pub fn mul_many<'a, T: AsAtomView + Copy>(args: &[T]) -> Atom { let mut out = Atom::new(); Workspace::get_local().with(|ws| { let mut t = ws.new_atom(); diff --git a/src/atom/representation.rs b/src/atom/representation.rs index 82a393f8..402c0c3c 100644 --- a/src/atom/representation.rs +++ b/src/atom/representation.rs @@ -40,6 +40,7 @@ const ZERO_DATA: [u8; 3] = [NUM_ID, 1, 0]; pub type RawAtom = Vec; /// An inline variable. +#[derive(Copy, Clone)] pub struct InlineVar { data: [u8; 16], size: u8, @@ -101,6 +102,7 @@ impl From for InlineVar { } /// An inline rational number that has 64-bit components. +#[derive(Copy, Clone)] pub struct InlineNum { data: [u8; 24], size: u8, diff --git a/src/collect.rs b/src/collect.rs index cd048c3b..c3ec1cc2 100644 --- a/src/collect.rs +++ b/src/collect.rs @@ -1,5 +1,5 @@ use crate::{ - atom::{Add, AsAtomView, Atom, AtomOrView, AtomView, Symbol}, + atom::{Add, AsAtomView, Atom, AtomView, Symbol}, coefficient::{Coefficient, CoefficientView}, domains::{integer::Z, rational::Q}, poly::{factor::Factorize, polynomial::MultivariatePolynomial, Exponent}, @@ -16,13 +16,13 @@ impl Atom { /// /// Both the *key* (the quantity collected in) and its coefficient can be mapped using /// `key_map` and `coeff_map` respectively. - pub fn collect( + pub fn collect( &self, - x: &AtomOrView, + x: T, key_map: Option>, coeff_map: Option>, ) -> Atom { - self.as_view().collect::(x, key_map, coeff_map) + self.as_view().collect::(x, key_map, coeff_map) } /// Collect terms involving the same power of `x`, where `x` is a variable or function, e.g. @@ -33,23 +33,24 @@ impl Atom { /// /// Both the *key* (the quantity collected in) and its coefficient can be mapped using /// `key_map` and `coeff_map` respectively. - pub fn collect_multiple( + pub fn collect_multiple( &self, - xs: &[AtomOrView], + xs: &[T], key_map: Option>, coeff_map: Option>, ) -> Atom { - self.as_view().collect_multiple::(xs, key_map, coeff_map) + self.as_view() + .collect_multiple::(xs, key_map, coeff_map) } /// Collect terms involving the same power of `x` in `xs`, where `xs` is a list of indeterminates. /// Return the list of key-coefficient pairs - pub fn coefficient_list(&self, xs: &[AtomOrView]) -> Vec<(Atom, Atom)> { - self.as_view().coefficient_list::(xs) + pub fn coefficient_list(&self, xs: &[T]) -> Vec<(Atom, Atom)> { + self.as_view().coefficient_list::(xs) } /// Collect terms involving the literal occurrence of `x`. - pub fn coefficient<'a, T: AsAtomView<'a>>(&self, x: T) -> Atom { + pub fn coefficient(&self, x: T) -> Atom { Workspace::get_local().with(|ws| self.as_view().coefficient_with_ws(x.as_atom_view(), ws)) } @@ -92,36 +93,36 @@ impl<'a> AtomView<'a> { /// /// Both the *key* (the quantity collected in) and its coefficient can be mapped using /// `key_map` and `coeff_map` respectively. - pub fn collect( + pub fn collect( &self, - x: &AtomOrView, + x: T, key_map: Option>, coeff_map: Option>, ) -> Atom { - self.collect_multiple::(std::slice::from_ref(x), key_map, coeff_map) + self.collect_multiple::(std::slice::from_ref(&x), key_map, coeff_map) } - pub fn collect_multiple( + pub fn collect_multiple( &self, - xs: &[AtomOrView], + xs: &[T], key_map: Option>, coeff_map: Option>, ) -> Atom { let mut out = Atom::new(); Workspace::get_local() - .with(|ws| self.collect_multiple_impl::(xs, ws, key_map, coeff_map, &mut out)); + .with(|ws| self.collect_multiple_impl::(xs, ws, key_map, coeff_map, &mut out)); out } - pub fn collect_multiple_impl( + pub fn collect_multiple_impl( &self, - xs: &[AtomOrView], + xs: &[T], ws: &Workspace, key_map: Option>, coeff_map: Option>, out: &mut Atom, ) { - let r = self.coefficient_list::(xs); + let r = self.coefficient_list::(xs); let mut add_h = Atom::new(); let add = add_h.to_add(); @@ -165,10 +166,10 @@ impl<'a> AtomView<'a> { /// Collect terms involving the same powers of `x` in `xs`, where `x` is an indeterminate. /// Return the list of key-coefficient pairs. - pub fn coefficient_list(&self, xs: &[AtomOrView]) -> Vec<(Atom, Atom)> { + pub fn coefficient_list(&self, xs: &[T]) -> Vec<(Atom, Atom)> { let vars = xs .iter() - .map(|x| x.as_view().to_owned().into()) + .map(|x| x.as_atom_view().to_owned().into()) .collect::>(); let p = self.to_polynomial_in_vars::(&Arc::new(vars)); @@ -179,7 +180,7 @@ impl<'a> AtomView<'a> { for (p, v) in t.exponents.iter().zip(xs) { let mut pow = Atom::new(); - pow.to_pow(v.as_view(), Atom::new_num(p.to_i32() as i64).as_view()); + pow.to_pow(v.as_atom_view(), Atom::new_num(p.to_i32() as i64).as_view()); key = key * pow; } @@ -654,7 +655,11 @@ impl<'a> AtomView<'a> { #[cfg(test)] mod test { - use crate::{atom::Atom, fun, state::State}; + use crate::{ + atom::{representation::InlineVar, Atom}, + fun, + state::State, + }; #[test] fn collect_num() { @@ -679,7 +684,7 @@ mod test { let input = Atom::parse("v1*(1+v3)+v1*5*v2+f1(5,v1)+2+v2^2+v1^2+v1^3").unwrap(); let x = State::get_symbol("v1"); - let r = input.coefficient_list::(&[x.into()]); + let r = input.coefficient_list::(&[x.into()]); let res = vec![ ( @@ -702,7 +707,7 @@ mod test { let input = Atom::parse("v1*(1+v3)+v1*5*v2+f1(5,v1)+2+v2^2+v1^2+v1^3").unwrap(); let x = State::get_symbol("v1"); - let out = input.collect::(&x.into(), None, None); + let out = input.collect::(x.into(), None, None); let ref_out = Atom::parse("v1^2+v1^3+v2^2+f1(5,v1)+v1*(5*v2+v3+1)+2").unwrap(); assert_eq!(out, ref_out) @@ -713,7 +718,7 @@ mod test { let input = Atom::parse("(1+v1)^2*v1+(1+v2)^100").unwrap(); let x = State::get_symbol("v1"); - let out = input.collect::(&x.into(), None, None); + let out = input.collect::(x.into(), None, None); let ref_out = Atom::parse("v1+2*v1^2+v1^3+(v2+1)^100").unwrap(); assert_eq!(out, ref_out) @@ -726,8 +731,8 @@ mod test { let key = State::get_symbol("f3"); let coeff = State::get_symbol("f4"); println!("> Collect in x with wrapping:"); - let out = input.collect::( - &x.into(), + let out = input.collect::( + x.into(), Some(Box::new(move |a, out| { out.set_from_view(&a); *out = fun!(key, out); @@ -801,10 +806,10 @@ mod test { ) .unwrap(); - let out = input.as_view().coefficient_list::(&[ - State::get_symbol("v1").into(), - State::get_symbol("v2").into(), - Atom::parse("v5(1,2,3)").unwrap().into(), + let out = input.as_view().coefficient_list::(&[ + Atom::new_var(State::get_symbol("v1")), + Atom::new_var(State::get_symbol("v2")), + Atom::parse("v5(1,2,3)").unwrap(), ]); assert_eq!(out.len(), 8); diff --git a/src/id.rs b/src/id.rs index fd92a636..6fa5ab7e 100644 --- a/src/id.rs +++ b/src/id.rs @@ -118,7 +118,7 @@ impl Atom { } /// Returns true iff `self` contains `a` literally. - pub fn contains<'a, T: AsAtomView<'a>>(&self, s: T) -> bool { + pub fn contains(&self, s: T) -> bool { self.as_view().contains(s.as_atom_view()) } diff --git a/src/solve.rs b/src/solve.rs index d31850db..d27c00f4 100644 --- a/src/solve.rs +++ b/src/solve.rs @@ -1,10 +1,10 @@ use std::{ops::Neg, sync::Arc}; use crate::{ - atom::{Atom, AtomView, Symbol}, + atom::{AsAtomView, Atom, AtomView, Symbol}, domains::{ float::{FloatField, Real, SingleFloat}, - integer::{IntegerRing, Z}, + integer::Z, rational::Q, rational_polynomial::{RationalPolynomial, RationalPolynomialField}, InternalOrdering, @@ -29,8 +29,9 @@ impl Atom { /// Solve a non-linear system numerically over the reals using Newton's method. pub fn nsolve_system< N: SingleFloat + Real + PartialOrd + InternalOrdering + Eq + std::hash::Hash, + T: AsAtomView, >( - system: &[AtomView], + system: &[T], vars: &[Symbol], init: &[N], prec: N, @@ -41,11 +42,26 @@ impl Atom { /// Solve a system that is linear in `vars`, if possible. /// Each expression in `system` is understood to yield 0. - pub fn solve_linear_system( - system: &[AtomView], - vars: &[Symbol], + pub fn solve_linear_system( + system: &[T1], + vars: &[T2], ) -> Result, String> { - AtomView::solve_linear_system::(system, vars) + AtomView::solve_linear_system::(system, vars) + } + + /// Convert a system of linear equations to a matrix representation, returning the matrix + /// and the right-hand side. + pub fn system_to_matrix( + system: &[T1], + vars: &[T2], + ) -> Result< + ( + Matrix>, + Matrix>, + ), + String, + > { + AtomView::system_to_matrix::(system, vars) } } @@ -94,6 +110,20 @@ impl<'a> AtomView<'a> { /// Solve a non-linear system numerically over the reals using Newton's method. pub fn nsolve_system< N: SingleFloat + Real + PartialOrd + InternalOrdering + Eq + std::hash::Hash, + T: AsAtomView, + >( + system: &[T], + vars: &[Symbol], + init: &[N], + prec: N, + max_iterations: usize, + ) -> Result, String> { + let system = system.iter().map(|v| v.as_atom_view()).collect::>(); + AtomView::nsolve_system_impl(&system, vars, init, prec, max_iterations) + } + + fn nsolve_system_impl< + N: SingleFloat + Real + PartialOrd + InternalOrdering + Eq + std::hash::Hash, >( system: &[AtomView], vars: &[Symbol], @@ -189,18 +219,58 @@ impl<'a> AtomView<'a> { /// Solve a system that is linear in `vars`, if possible. /// Each expression in `system` is understood to yield 0. - pub fn solve_linear_system( - system: &[AtomView], - vars: &[Symbol], + pub fn solve_linear_system( + system: &[T1], + vars: &[T2], ) -> Result, String> { - let vars: Vec<_> = vars.iter().map(|v| Variable::Symbol(*v)).collect(); + let system: Vec<_> = system.iter().map(|v| v.as_atom_view()).collect(); + let vars: Vec<_> = vars + .iter() + .map(|v| v.as_atom_view().to_owned().into()) + .collect(); + + AtomView::solve_linear_system_impl::(&system, &vars) + } + + /// Convert a system of linear equations to a matrix representation, returning the matrix + /// and the right-hand side. + pub fn system_to_matrix( + system: &[T1], + vars: &[T2], + ) -> Result< + ( + Matrix>, + Matrix>, + ), + String, + > { + let system: Vec<_> = system.iter().map(|v| v.as_atom_view()).collect(); + + let vars: Vec<_> = vars + .iter() + .map(|v| v.as_atom_view().to_owned().into()) + .collect(); + + AtomView::system_to_matrix_impl::(&system, &vars) + } + + fn system_to_matrix_impl( + system: &[AtomView], + vars: &[Variable], + ) -> Result< + ( + Matrix>, + Matrix>, + ), + String, + > { let mut mat = Vec::with_capacity(system.len() * vars.len()); let mut row = vec![RationalPolynomial::<_, E>::new(&Z, Arc::new(vec![])); vars.len()]; - let mut rhs = vec![RationalPolynomial::<_, E>::new(&Z, Arc::new(vec![])); vars.len()]; + let mut rhs = vec![RationalPolynomial::<_, E>::new(&Z, Arc::new(vec![])); system.len()]; for (si, a) in system.iter().enumerate() { - let rat: RationalPolynomial = a.to_rational_polynomial(&Q, &Z, None); + let rat: RationalPolynomial = a.to_rational_polynomial(&Q, &Z, None); let poly = rat.to_polynomial(&vars, true).unwrap(); @@ -243,10 +313,19 @@ impl<'a> AtomView<'a> { let field = RationalPolynomialField::new(Z); - let nrows = (mat.len() / rhs.len()) as u32; - let m = Matrix::from_linear(mat, nrows, rhs.len() as u32, field.clone()).unwrap(); + let m = Matrix::from_linear(mat, system.len() as u32, vars.len() as u32, field.clone()) + .unwrap(); let b = Matrix::new_vec(rhs, field); + Ok((m, b)) + } + + fn solve_linear_system_impl( + system: &[AtomView], + vars: &[Variable], + ) -> Result, String> { + let (m, b) = Self::system_to_matrix_impl::(system, vars)?; + let sol = match m.solve(&b) { Ok(sol) => sol, Err(e) => Err(format!("Could not solve {:?}", e))?, @@ -268,7 +347,7 @@ mod test { use std::sync::Arc; use crate::{ - atom::{Atom, AtomView}, + atom::{representation::InlineVar, Atom, AtomView}, domains::{ float::{Real, F64}, integer::Z, @@ -282,19 +361,18 @@ mod test { #[test] fn solve() { - let x = State::get_symbol("v1"); - let y = State::get_symbol("v2"); - let z = State::get_symbol("v3"); + let x = State::get_symbol("v1").into(); + let y = State::get_symbol("v2").into(); + let z = State::get_symbol("v3").into(); let eqs = [ "v4*v1 + f1(v4)*v2 + v3 - 1", "v1 + v4*v2 + v3/v4 - 2", "(v4-1)v1 + v4*v3", ]; - let atoms: Vec<_> = eqs.iter().map(|e| Atom::parse(e).unwrap()).collect(); - let system: Vec<_> = atoms.iter().map(|x| x.as_view()).collect(); + let system: Vec<_> = eqs.iter().map(|e| Atom::parse(e).unwrap()).collect(); - let sol = AtomView::solve_linear_system::(&system, &[x, y, z]).unwrap(); + let sol = AtomView::solve_linear_system::(&system, &[x, y, z]).unwrap(); let res = [ "(v4^3-2*v4^2*f1(v4))*(v4^2-v4^3+v4^4-f1(v4)+v4*f1(v4)-v4^2*f1(v4))^-1", diff --git a/src/tensors/matrix.rs b/src/tensors/matrix.rs index 725f8671..aaffa882 100644 --- a/src/tensors/matrix.rs +++ b/src/tensors/matrix.rs @@ -1200,8 +1200,8 @@ impl Matrix { Ok(det) } - /// Write the matrix in echelon form. - fn gaussian_elimination( + /// Write the first `max_col` columns of the matrix in echelon form. + pub fn gaussian_elimination( &mut self, max_col: u32, early_return: bool, @@ -1258,7 +1258,7 @@ impl Matrix { } /// Create a row-reduced matrix from a matrix in echelon form. - fn back_substitution(&mut self, max_col: u32) { + pub fn back_substitution(&mut self, max_col: u32) { let field = self.field.clone(); for i in (0..self.nrows).rev() { if let Some(j) = (0..max_col).find(|&j| !F::is_zero(&self[(i, j)])) { @@ -1284,7 +1284,29 @@ impl Matrix { } } - /// Solves `A * x = 0` for the first `max_col` columns in x. + /// Augment the matrix with another matrix, e.g. create `[A B]` from matrix `A` and `B`. + /// + /// Returns an error when the matrices do not have the same number of rows. + pub fn augment(&self, matrix: &Matrix) -> Result, MatrixError> { + if self.nrows != matrix.nrows { + return Err(MatrixError::ShapeMismatch); + } + + let mut m = Matrix::new(self.nrows, self.ncols + matrix.ncols, self.field.clone()); + + for (r, (r1, r2)) in self.row_iter().zip(matrix.row_iter()).enumerate() { + m.data[r as usize * m.ncols as usize + ..r as usize * m.ncols as usize + self.ncols as usize] + .clone_from_slice(r1); + m.data[r as usize * m.ncols as usize + self.ncols as usize + ..r as usize * m.ncols as usize + m.ncols as usize] + .clone_from_slice(r2); + } + + Ok(m) + } + + /// Solves `A * x = 0` for the first `max_col` columns in `x`. /// The other columns are augmented. pub fn solve_subsystem(&mut self, max_col: u32) -> Result> { if self.nrows < max_col { @@ -1317,14 +1339,7 @@ impl Matrix { }); } - // create the augmented matrix - let mut m = Matrix::new(neqs, nvars + 1, self.field.clone()); - for r in 0..neqs { - for c in 0..nvars { - m[(r, c)] = self[(r, c)].clone(); - } - m[(r, nvars)] = b.data[r as usize].clone(); - } + let mut m = self.augment(b)?; let mut i = match m.solve_subsystem(nvars) { Ok(i) => i, diff --git a/src/transformer.rs b/src/transformer.rs index f9ec3204..998a923c 100644 --- a/src/transformer.rs +++ b/src/transformer.rs @@ -537,7 +537,7 @@ impl Transformer { cur_input.derivative_with_ws_into(*x, workspace, out); } Transformer::Collect(x, key_map, coeff_map) => cur_input - .collect_multiple_impl::( + .collect_multiple_impl::( x, workspace, if key_map.is_empty() { From 26ed8f3aae34a75b6bffa5ef161f6e35ec940c31 Mon Sep 17 00:00:00 2001 From: Ben Ruijl Date: Fri, 13 Dec 2024 12:54:07 +0100 Subject: [PATCH 4/7] Make more functions accept generic arguments - Remove the need for many into() calls - Add owned Replacement and BorrowedReplacement - Remove methods on Pattern --- examples/fibonacci.rs | 10 +- examples/partition.rs | 7 +- examples/pattern_match.rs | 4 +- examples/pattern_restrictions.rs | 2 +- examples/replace_all.rs | 4 +- examples/replace_once.rs | 2 +- examples/streaming.rs | 4 +- src/api/mathematica.rs | 54 ++--- src/api/python.rs | 49 ++-- src/coefficient.rs | 14 +- src/collect.rs | 4 +- src/derivative.rs | 8 +- src/expand.rs | 11 +- src/id.rs | 401 +++++++++++++++++++------------ src/streaming.rs | 15 +- src/transformer.rs | 33 +-- tests/pattern_matching.rs | 12 +- 17 files changed, 338 insertions(+), 296 deletions(-) diff --git a/examples/fibonacci.rs b/examples/fibonacci.rs index 06babe19..41df7b64 100644 --- a/examples/fibonacci.rs +++ b/examples/fibonacci.rs @@ -7,10 +7,10 @@ use symbolica::{ fn main() { // prepare all patterns let pattern = Pattern::parse("f(x_)").unwrap(); - let rhs = Pattern::parse("f(x_ - 1) + f(x_ - 2)").unwrap().into(); + let rhs = Pattern::parse("f(x_ - 1) + f(x_ - 2)").unwrap(); let lhs_zero_pat = Pattern::parse("f(0)").unwrap(); let lhs_one_pat = Pattern::parse("f(1)").unwrap(); - let rhs_one = Atom::new_num(1).into_pattern().into(); + let rhs_one = Atom::new_num(1).to_pattern(); // prepare the pattern restriction `x_ > 1` let restrictions = ( @@ -32,14 +32,14 @@ fn main() { for _ in 0..9 { let mut out = RecycledAtom::new(); - pattern.replace_all_into(target.as_view(), &rhs, Some(&restrictions), None, &mut out); + target.replace_all_into(&pattern, &rhs, Some(&restrictions), None, &mut out); let mut out2 = RecycledAtom::new(); out.expand_into(&mut out2); - lhs_zero_pat.replace_all_into(out2.as_view(), &rhs_one, None, None, &mut out); + out2.replace_all_into(&lhs_zero_pat, &rhs_one, None, None, &mut out); - lhs_one_pat.replace_all_into(out.as_view(), &rhs_one, None, None, &mut out2); + out.replace_all_into(&lhs_one_pat, &rhs_one, None, None, &mut out2); println!("\t{}", out2); diff --git a/examples/partition.rs b/examples/partition.rs index bbf64c32..9a33d6de 100644 --- a/examples/partition.rs +++ b/examples/partition.rs @@ -5,8 +5,8 @@ fn main() { let f = State::get_symbol("f"); let g = State::get_symbol("g"); - let o = Pattern::parse("f(x__)").unwrap().replace_all( - input.as_view(), + let o = input.replace_all( + &Pattern::parse("f(x__)").unwrap(), &Pattern::Transformer(Box::new(( Some(Pattern::parse("x__").unwrap()), vec![Transformer::Partition( @@ -14,8 +14,7 @@ fn main() { false, false, )], - ))) - .into(), + ))), None, None, ); diff --git a/examples/pattern_match.rs b/examples/pattern_match.rs index 7c73c237..bfeefb67 100644 --- a/examples/pattern_match.rs +++ b/examples/pattern_match.rs @@ -9,13 +9,13 @@ fn main() { let pat_expr = Atom::parse("z*x_*y___*g___(z___,x_,w___)").unwrap(); - let pattern = pat_expr.as_view().into_pattern(); + let pattern = pat_expr.to_pattern(); let conditions = Condition::default(); let settings = MatchSettings::default(); println!("> Matching pattern {} to {}:", pat_expr, expr.as_view()); - let mut it = pattern.pattern_match(expr.as_view(), &conditions, &settings); + let mut it = expr.pattern_match(&pattern, &conditions, &settings); while let Some(m) = it.next() { println!( "\t Match at location {:?} - {:?}:", diff --git a/examples/pattern_restrictions.rs b/examples/pattern_restrictions.rs index 243f44f8..cc1e1afb 100644 --- a/examples/pattern_restrictions.rs +++ b/examples/pattern_restrictions.rs @@ -58,7 +58,7 @@ fn main() { expr ); - let mut it = pattern.pattern_match(expr.as_view(), &conditions, &settings); + let mut it = expr.pattern_match(&pattern, &conditions, &settings); while let Some(m) = it.next() { println!("\tMatch at location {:?} - {:?}:", m.position, m.used_flags); for (id, v) in m.match_stack { diff --git a/examples/replace_all.rs b/examples/replace_all.rs index 66a8695c..539f9565 100644 --- a/examples/replace_all.rs +++ b/examples/replace_all.rs @@ -3,8 +3,8 @@ use symbolica::{atom::Atom, id::Pattern}; fn main() { let expr = Atom::parse(" f(1,2,x) + f(1,2,3)").unwrap(); let pat = Pattern::parse("f(1,2,y_)").unwrap(); - let rhs = Pattern::parse("f(1,2,y_+1)").unwrap().into(); + let rhs = Pattern::parse("f(1,2,y_+1)").unwrap(); - let out = pat.replace_all(expr.as_view(), &rhs, None, None); + let out = expr.replace_all(&pat, &rhs, None, None); println!("{}", out); } diff --git a/examples/replace_once.rs b/examples/replace_once.rs index c7eb9104..5d4d1ca5 100644 --- a/examples/replace_once.rs +++ b/examples/replace_once.rs @@ -23,7 +23,7 @@ fn main() { let mut replaced = Atom::new(); - let mut it = pattern.replace_iter(expr.as_view(), &rhs, &restrictions, &settings); + let mut it = expr.replace_iter(&pattern, &rhs, &restrictions, &settings); while let Some(()) = it.next(&mut replaced) { println!("\t{}", replaced); } diff --git a/examples/streaming.rs b/examples/streaming.rs index 70c113ab..f7f98dc5 100644 --- a/examples/streaming.rs +++ b/examples/streaming.rs @@ -8,7 +8,7 @@ use symbolica::{ fn main() { let input = Atom::parse("x+ f(x) + 2*f(y) + 7*f(z)").unwrap(); let pattern = Pattern::parse("f(x_)").unwrap(); - let rhs = Pattern::parse("f(x) + x").unwrap().into(); + let rhs = Pattern::parse("f(x) + x").unwrap(); let mut stream = TermStreamer::>::new(TermStreamerConfig { n_cores: 4, @@ -18,7 +18,7 @@ fn main() { stream.push(input); // map every term in the expression - stream = stream.map(|x| pattern.replace_all(x.as_view(), &rhs, None, None).expand()); + stream = stream.map(|x| x.replace_all(&pattern, &rhs, None, None).expand()); let res = stream.to_expression(); println!("\t+ {}", res); diff --git a/src/api/mathematica.rs b/src/api/mathematica.rs index 8d11745a..1b286399 100644 --- a/src/api/mathematica.rs +++ b/src/api/mathematica.rs @@ -5,14 +5,14 @@ use std::sync::{Arc, RwLock}; use smartstring::{LazyCompact, SmartString}; -use crate::domains::finite_field::{FiniteFieldCore, Zp, Zp64}; +use crate::domains::finite_field::{Zp, Zp64}; use crate::domains::integer::Z; use crate::domains::rational::Q; +use crate::domains::SelfRing; use crate::parser::Token; use crate::poly::Variable; use crate::{ - domains::rational_polynomial::RationalPolynomial, - printer::{PrintOptions, RationalPolynomialPrinter}, + domains::rational_polynomial::RationalPolynomial, printer::PrintOptions, printer::PrintState, state::State, }; use once_cell::sync::Lazy; @@ -76,16 +76,12 @@ fn simplify(input: String, prime: i64, explicit_rational_polynomial: bool) -> St ) .unwrap(); - format!( - "{}", - RationalPolynomialPrinter { - poly: &r, - opts: PrintOptions { - explicit_rational_polynomial, - ..PrintOptions::mathematica() - }, - add_parentheses: false - } + r.format_string( + &PrintOptions { + explicit_rational_polynomial, + ..PrintOptions::mathematica() + }, + PrintState::default(), ) } else { if prime >= 0 && prime <= u32::MAX as i64 { @@ -100,16 +96,12 @@ fn simplify(input: String, prime: i64, explicit_rational_polynomial: bool) -> St .unwrap(); symbolica.buffer.clear(); - format!( - "{}", - RationalPolynomialPrinter { - poly: &rf, - opts: PrintOptions { - explicit_rational_polynomial, - ..PrintOptions::mathematica() - }, - add_parentheses: false - } + rf.format_string( + &PrintOptions { + explicit_rational_polynomial, + ..PrintOptions::mathematica() + }, + PrintState::default(), ) } else { let field = Zp64::new(prime as u64); @@ -123,16 +115,12 @@ fn simplify(input: String, prime: i64, explicit_rational_polynomial: bool) -> St .unwrap(); symbolica.buffer.clear(); - format!( - "{}", - RationalPolynomialPrinter { - poly: &rf, - opts: PrintOptions { - explicit_rational_polynomial, - ..PrintOptions::mathematica() - }, - add_parentheses: false - } + rf.format_string( + &PrintOptions { + explicit_rational_polynomial, + ..PrintOptions::mathematica() + }, + PrintState::default(), ) } } diff --git a/src/api/python.rs b/src/api/python.rs index b8744dbb..bc711d97 100644 --- a/src/api/python.rs +++ b/src/api/python.rs @@ -317,7 +317,7 @@ pub enum ConvertibleToPattern { impl ConvertibleToPattern { pub fn to_pattern(self) -> PyResult { match self { - Self::Literal(l) => Ok(l.to_expression().expr.as_view().into_pattern().into()), + Self::Literal(l) => Ok(l.to_expression().expr.to_pattern().into()), Self::Pattern(e) => Ok(e), } } @@ -1548,10 +1548,7 @@ impl PythonTransformer { return append_transformer!( self, Transformer::ReplaceAllMultiple( - replacements - .into_iter() - .map(|r| (r.pattern, r.rhs, r.cond, r.settings)) - .collect() + replacements.into_iter().map(|r| r.replacement).collect() ) ); } @@ -3145,7 +3142,7 @@ impl PythonExpression { transformer_args.push(t.to_pattern()?.expr); } ExpressionOrTransformer::Expression(a) => { - transformer_args.push(a.expr.as_view().into_pattern()); + transformer_args.push(a.expr.to_pattern()); } } } @@ -3157,7 +3154,7 @@ impl PythonExpression { /// Convert the input to a transformer, on which subsequent transformations can be applied. pub fn transform(&self) -> PyResult { - Ok(Pattern::Transformer(Box::new((Some(self.expr.into_pattern()), vec![]))).into()) + Ok(Pattern::Transformer(Box::new((Some(self.expr.to_pattern()), vec![]))).into()) } /// Get the `idx`th component of the expression. @@ -3201,7 +3198,7 @@ impl PythonExpression { pub fn contains(&self, s: ConvertibleToPattern) -> PyResult { Ok(PythonCondition { condition: Condition::Yield(Relation::Contains( - self.expr.into_pattern(), + self.expr.to_pattern(), s.to_pattern()?.expr, )), }) @@ -3380,7 +3377,7 @@ impl PythonExpression { pub fn is_type(&self, atom_type: PythonAtomType) -> PythonCondition { PythonCondition { condition: Condition::Yield(Relation::IsType( - self.expr.into_pattern(), + self.expr.to_pattern(), match atom_type { PythonAtomType::Num => AtomType::Num, PythonAtomType::Var => AtomType::Var, @@ -3398,22 +3395,22 @@ impl PythonExpression { fn __richcmp__(&self, other: ConvertibleToPattern, op: CompareOp) -> PyResult { Ok(match op { CompareOp::Eq => PythonCondition { - condition: Relation::Eq(self.expr.into_pattern(), other.to_pattern()?.expr).into(), + condition: Relation::Eq(self.expr.to_pattern(), other.to_pattern()?.expr).into(), }, CompareOp::Ne => PythonCondition { - condition: Relation::Ne(self.expr.into_pattern(), other.to_pattern()?.expr).into(), + condition: Relation::Ne(self.expr.to_pattern(), other.to_pattern()?.expr).into(), }, CompareOp::Ge => PythonCondition { - condition: Relation::Ge(self.expr.into_pattern(), other.to_pattern()?.expr).into(), + condition: Relation::Ge(self.expr.to_pattern(), other.to_pattern()?.expr).into(), }, CompareOp::Gt => PythonCondition { - condition: Relation::Gt(self.expr.into_pattern(), other.to_pattern()?.expr).into(), + condition: Relation::Gt(self.expr.to_pattern(), other.to_pattern()?.expr).into(), }, CompareOp::Le => PythonCondition { - condition: Relation::Le(self.expr.into_pattern(), other.to_pattern()?.expr).into(), + condition: Relation::Le(self.expr.to_pattern(), other.to_pattern()?.expr).into(), }, CompareOp::Lt => PythonCondition { - condition: Relation::Lt(self.expr.into_pattern(), other.to_pattern()?.expr).into(), + condition: Relation::Lt(self.expr.to_pattern(), other.to_pattern()?.expr).into(), }, }) } @@ -4606,7 +4603,7 @@ impl PythonExpression { let mut out = RecycledAtom::new(); let mut out2 = RecycledAtom::new(); - while pattern.replace_all_into(expr_ref, rhs, cond.as_ref(), Some(&settings), &mut out) { + while expr_ref.replace_all_into(&pattern, rhs, cond.as_ref(), Some(&settings), &mut out) { if !repeat.unwrap_or(false) { break; } @@ -4645,11 +4642,7 @@ impl PythonExpression { ) -> PyResult { let reps = replacements .iter() - .map(|x| { - Replacement::new(&x.pattern, &x.rhs) - .with_conditions(&x.cond) - .with_settings(&x.settings) - }) + .map(|x| x.replacement.borrow()) .collect::>(); let mut expr_ref = self.expr.as_view(); @@ -5281,10 +5274,7 @@ impl PythonExpression { #[pyclass(name = "Replacement", module = "symbolica")] #[derive(Clone)] pub struct PythonReplacement { - pattern: Pattern, - rhs: PatternOrMap, - cond: Condition, - settings: MatchSettings, + replacement: Replacement, } #[pymethods] @@ -5338,13 +5328,10 @@ impl PythonReplacement { settings.rhs_cache_size = rhs_cache_size; } - let cond = cond.map(|r| r.0).unwrap_or(Condition::default()); - Ok(Self { - pattern, - rhs, - cond, - settings, + replacement: Replacement::new(pattern, rhs) + .with_conditions(cond.map(|r| r.0).unwrap_or_default()) + .with_settings(settings), }) } } diff --git a/src/coefficient.rs b/src/coefficient.rs index cfb2a331..b6982def 100644 --- a/src/coefficient.rs +++ b/src/coefficient.rs @@ -1616,21 +1616,19 @@ mod test { let a = a.set_coefficient_ring(&Arc::new(vec![])); - let expr = Atom::new_var(v2) - .into_pattern() + let expr = expr .replace_all( - expr.as_view(), - &Atom::new_num(3).into_pattern().into(), + &Atom::new_var(v2).to_pattern(), + &Atom::new_num(3).to_pattern(), None, None, ) .expand(); - let a = Atom::new_var(v2) - .into_pattern() + let a = a .replace_all( - a.as_view(), - &Atom::new_num(3).into_pattern().into(), + &Atom::new_var(v2).to_pattern(), + &Atom::new_num(3).to_pattern(), None, None, ) diff --git a/src/collect.rs b/src/collect.rs index c3ec1cc2..8803067b 100644 --- a/src/collect.rs +++ b/src/collect.rs @@ -191,8 +191,8 @@ impl<'a> AtomView<'a> { } /// Collect terms involving the literal occurrence of `x`. - pub fn coefficient(&self, x: AtomView<'_>) -> Atom { - Workspace::get_local().with(|ws| self.coefficient_with_ws(x, ws)) + pub fn coefficient(&self, x: T) -> Atom { + Workspace::get_local().with(|ws| self.coefficient_with_ws(x.as_atom_view(), ws)) } /// Collect terms involving the literal occurrence of `x`. diff --git a/src/derivative.rs b/src/derivative.rs index 6d8df1b1..413694a8 100644 --- a/src/derivative.rs +++ b/src/derivative.rs @@ -4,7 +4,7 @@ use std::{ }; use crate::{ - atom::{Atom, AtomView, FunctionBuilder, Symbol}, + atom::{AsAtomView, Atom, AtomView, FunctionBuilder, Symbol}, coefficient::{Coefficient, CoefficientView}, combinatorics::CombinationWithReplacementIterator, domains::{atom::AtomField, integer::Integer, rational::Rational}, @@ -26,15 +26,15 @@ impl Atom { } /// Series expand in `x` around `expansion_point` to depth `depth`. - pub fn series( + pub fn series( &self, x: Symbol, - expansion_point: AtomView, + expansion_point: T, depth: Rational, depth_is_absolute: bool, ) -> Result, &'static str> { self.as_view() - .series(x, expansion_point, depth, depth_is_absolute) + .series(x, expansion_point.as_atom_view(), depth, depth_is_absolute) } } diff --git a/src/expand.rs b/src/expand.rs index 32ca58bc..5c8931a7 100644 --- a/src/expand.rs +++ b/src/expand.rs @@ -3,7 +3,7 @@ use std::{ops::DerefMut, sync::Arc}; use smallvec::SmallVec; use crate::{ - atom::{representation::InlineVar, Atom, AtomView, Symbol}, + atom::{representation::InlineVar, AsAtomView, Atom, AtomView, Symbol}, coefficient::CoefficientView, combinatorics::CombinationWithReplacementIterator, domains::{integer::Integer, rational::Q}, @@ -21,13 +21,14 @@ impl Atom { /// only in the indeterminate `var`. The parameter `E` should be a numerical type /// that fits the largest exponent in the expanded expression. Often, /// `u8` or `u16` is sufficient. - pub fn expand_via_poly(&self, var: Option) -> Atom { - self.as_view().expand_via_poly::(var) + pub fn expand_via_poly(&self, var: Option) -> Atom { + self.as_view() + .expand_via_poly::(var.as_ref().map(|x| x.as_atom_view())) } /// Expand an expression in the variable `var`. The function [expand_via_poly] may be faster. - pub fn expand_in(&self, var: AtomView) -> Atom { - self.as_view().expand_in(var) + pub fn expand_in(&self, var: T) -> Atom { + self.as_view().expand_in(var.as_atom_view()) } /// Expand an expression in the variable `var`. diff --git a/src/id.rs b/src/id.rs index 6fa5ab7e..eecf80eb 100644 --- a/src/id.rs +++ b/src/id.rs @@ -61,38 +61,141 @@ impl std::fmt::Debug for PatternOrMap { } } +/// A pattern or a map from a list of matched wildcards to an atom. +/// The latter can be used for complex replacements that cannot be +/// expressed using atom transformations. +#[derive(Clone, Copy)] +pub enum BorrowedPatternOrMap<'a> { + Pattern(&'a Pattern), + Map(&'a Box), +} + +pub trait BorrowPatternOrMap { + fn borrow(&self) -> BorrowedPatternOrMap; +} + +impl BorrowPatternOrMap for &Pattern { + fn borrow(&self) -> BorrowedPatternOrMap { + BorrowedPatternOrMap::Pattern(*self) + } +} + +impl BorrowPatternOrMap for Pattern { + fn borrow(&self) -> BorrowedPatternOrMap { + BorrowedPatternOrMap::Pattern(self) + } +} + +impl BorrowPatternOrMap for Box { + fn borrow(&self) -> BorrowedPatternOrMap { + BorrowedPatternOrMap::Map(self) + } +} + +impl BorrowPatternOrMap for &Box { + fn borrow(&self) -> BorrowedPatternOrMap { + BorrowedPatternOrMap::Map(*self) + } +} + +impl BorrowPatternOrMap for PatternOrMap { + fn borrow(&self) -> BorrowedPatternOrMap { + match self { + PatternOrMap::Pattern(p) => BorrowedPatternOrMap::Pattern(p), + PatternOrMap::Map(m) => BorrowedPatternOrMap::Map(m), + } + } +} + +impl BorrowPatternOrMap for &PatternOrMap { + fn borrow(&self) -> BorrowedPatternOrMap { + match self { + PatternOrMap::Pattern(p) => BorrowedPatternOrMap::Pattern(p), + PatternOrMap::Map(m) => BorrowedPatternOrMap::Map(m), + } + } +} + +impl<'a> BorrowPatternOrMap for BorrowedPatternOrMap<'a> { + fn borrow(&self) -> BorrowedPatternOrMap { + *self + } +} + /// A replacement, specified by a pattern and the right-hand side, /// with optional conditions and settings. -pub struct Replacement<'a> { - pat: &'a Pattern, - rhs: &'a PatternOrMap, - conditions: Option<&'a Condition>, - settings: Option<&'a MatchSettings>, +#[derive(Debug, Clone)] +pub struct Replacement { + pat: Pattern, + rhs: PatternOrMap, + conditions: Option>, + settings: Option, } -impl<'a> Replacement<'a> { - pub fn new(pat: &'a Pattern, rhs: &'a PatternOrMap) -> Self { +impl Replacement { + pub fn new>(pat: Pattern, rhs: R) -> Self { Replacement { pat, - rhs, + rhs: rhs.into(), conditions: None, settings: None, } } - pub fn with_conditions(mut self, conditions: &'a Condition) -> Self { + pub fn with_conditions(mut self, conditions: Condition) -> Self { self.conditions = Some(conditions); self } - pub fn with_settings(mut self, settings: &'a MatchSettings) -> Self { + pub fn with_settings(mut self, settings: MatchSettings) -> Self { self.settings = Some(settings); self } } +/// A borrowed version of a [Replacement]. +#[derive(Clone, Copy)] +pub struct BorrowedReplacement<'a> { + pub pattern: &'a Pattern, + pub rhs: BorrowedPatternOrMap<'a>, + pub conditions: Option<&'a Condition>, + pub settings: Option<&'a MatchSettings>, +} + +pub trait BorrowReplacement { + fn borrow(&self) -> BorrowedReplacement; +} + +impl BorrowReplacement for Replacement { + fn borrow(&self) -> BorrowedReplacement { + BorrowedReplacement { + pattern: &self.pat, + rhs: self.rhs.borrow(), + conditions: self.conditions.as_ref(), + settings: self.settings.as_ref(), + } + } +} + +impl BorrowReplacement for &Replacement { + fn borrow(&self) -> BorrowedReplacement { + BorrowedReplacement { + pattern: &self.pat, + rhs: self.rhs.borrow(), + conditions: self.conditions.as_ref(), + settings: self.settings.as_ref(), + } + } +} + +impl<'a> BorrowReplacement for BorrowedReplacement<'a> { + fn borrow(&self) -> BorrowedReplacement { + *self + } +} + impl Atom { - pub fn into_pattern(&self) -> Pattern { + pub fn to_pattern(&self) -> Pattern { Pattern::from_view(self.as_view(), true) } @@ -138,10 +241,10 @@ impl Atom { } /// Replace all occurrences of the pattern. - pub fn replace_all( + pub fn replace_all( &self, pattern: &Pattern, - rhs: &PatternOrMap, + rhs: R, conditions: Option<&Condition>, settings: Option<&MatchSettings>, ) -> Atom { @@ -150,10 +253,10 @@ impl Atom { } /// Replace all occurrences of the pattern. - pub fn replace_all_into( + pub fn replace_all_into( &self, pattern: &Pattern, - rhs: &PatternOrMap, + rhs: R, conditions: Option<&Condition>, settings: Option<&MatchSettings>, out: &mut Atom, @@ -163,15 +266,15 @@ impl Atom { } /// Replace all occurrences of the patterns, where replacements are tested in the order that they are given. - pub fn replace_all_multiple(&self, replacements: &[Replacement<'_>]) -> Atom { + pub fn replace_all_multiple(&self, replacements: &[T]) -> Atom { self.as_view().replace_all_multiple(replacements) } /// Replace all occurrences of the patterns, where replacements are tested in the order that they are given. /// Returns `true` iff a match was found. - pub fn replace_all_multiple_into( + pub fn replace_all_multiple_into( &self, - replacements: &[Replacement<'_>], + replacements: &[T], out: &mut Atom, ) -> bool { self.as_view().replace_all_multiple_into(replacements, out) @@ -183,6 +286,27 @@ impl Atom { pub fn replace_map bool>(&self, m: &F) -> Atom { self.as_view().replace_map(m) } + + /// Return an iterator that replaces the pattern in the target once. + pub fn replace_iter<'a>( + &'a self, + pattern: &'a Pattern, + rhs: &'a PatternOrMap, + conditions: &'a Condition, + settings: &'a MatchSettings, + ) -> ReplaceIterator<'a, 'a> { + ReplaceIterator::new(pattern, self.as_view(), rhs, conditions, settings) + } + + /// Return an iterator over matched expressions. + pub fn pattern_match<'a>( + &'a self, + pattern: &'a Pattern, + conditions: &'a Condition, + settings: &'a MatchSettings, + ) -> PatternAtomTreeIterator<'a, 'a> { + PatternAtomTreeIterator::new(pattern, self.as_view(), conditions, settings) + } } /// The context of an atom. @@ -281,12 +405,12 @@ impl<'a> AtomView<'a> { } /// Returns true iff `self` contains `a` literally. - pub fn contains(&self, a: AtomView) -> bool { + pub fn contains(&self, a: T) -> bool { let mut stack = Vec::with_capacity(20); stack.push(*self); while let Some(c) = stack.pop() { - if a == c { + if a.as_atom_view() == c { return true; } @@ -632,30 +756,34 @@ impl<'a> AtomView<'a> { } /// Replace all occurrences of the patterns, where replacements are tested in the order that they are given. - pub fn replace_all( + pub fn replace_all( &self, pattern: &Pattern, - rhs: &PatternOrMap, + rhs: R, conditions: Option<&Condition>, settings: Option<&MatchSettings>, ) -> Atom { - pattern.replace_all(*self, rhs, conditions, settings) + let mut out = Atom::new(); + self.replace_all_into(pattern, rhs, conditions, settings, &mut out); + out } /// Replace all occurrences of the patterns, where replacements are tested in the order that they are given. - pub fn replace_all_into( + pub fn replace_all_into( &self, pattern: &Pattern, - rhs: &PatternOrMap, + rhs: R, conditions: Option<&Condition>, settings: Option<&MatchSettings>, out: &mut Atom, ) -> bool { - pattern.replace_all_into(*self, rhs, conditions, settings, out) + Workspace::get_local().with(|ws| { + self.replace_all_with_ws_into(pattern, rhs.borrow(), ws, conditions, settings, out) + }) } /// Replace all occurrences of the patterns, where replacements are tested in the order that they are given. - pub fn replace_all_multiple(&self, replacements: &[Replacement<'_>]) -> Atom { + pub fn replace_all_multiple(&self, replacements: &[T]) -> Atom { let mut out = Atom::new(); self.replace_all_multiple_into(replacements, &mut out); out @@ -663,9 +791,9 @@ impl<'a> AtomView<'a> { /// Replace all occurrences of the patterns, where replacements are tested in the order that they are given. /// Returns `true` iff a match was found. - pub fn replace_all_multiple_into( + pub fn replace_all_multiple_into( &self, - replacements: &[Replacement<'_>], + replacements: &[T], out: &mut Atom, ) -> bool { Workspace::get_local().with(|ws| { @@ -683,9 +811,9 @@ impl<'a> AtomView<'a> { } /// Replace all occurrences of the patterns in the target, without normalizing the output. - fn replace_all_no_norm( + fn replace_all_no_norm( &self, - replacements: &[Replacement<'_>], + replacements: &[T], workspace: &Workspace, tree_level: usize, fn_level: usize, @@ -694,6 +822,8 @@ impl<'a> AtomView<'a> { ) -> bool { let mut beyond_max_level = true; for (rep_id, r) in replacements.iter().enumerate() { + let r = r.borrow(); + let def_c = Condition::default(); let def_s = MatchSettings::default(); let conditions = r.conditions.unwrap_or(&def_c); @@ -715,10 +845,10 @@ impl<'a> AtomView<'a> { continue; } - if r.pat.could_match(*self) { + if r.pattern.could_match(*self) { let mut match_stack = MatchStack::new(conditions, settings); - let mut it = AtomMatchIterator::new(r.pat, *self); + let mut it = AtomMatchIterator::new(&r.pattern, *self); if let Some((_, used_flags)) = it.next(&mut match_stack) { let mut rhs_subs = workspace.new_atom(); @@ -730,8 +860,8 @@ impl<'a> AtomView<'a> { } else { match_stack.stack = key.1; - match r.rhs { - PatternOrMap::Pattern(rhs) => { + match &r.rhs.borrow() { + BorrowedPatternOrMap::Pattern(rhs) => { rhs.substitute_wildcards( workspace, &mut rhs_subs, @@ -740,14 +870,14 @@ impl<'a> AtomView<'a> { ) .unwrap(); // TODO: escalate? } - PatternOrMap::Map(f) => { + BorrowedPatternOrMap::Map(f) => { let mut rhs = f(&match_stack); std::mem::swap(rhs_subs.deref_mut(), &mut rhs); } } if rhs_cache.len() < settings.rhs_cache_size - && !matches!(r.rhs, PatternOrMap::Pattern(Pattern::Literal(_))) + && !matches!(r.rhs, BorrowedPatternOrMap::Pattern(Pattern::Literal(_))) { rhs_cache.insert( (rep_id, match_stack.stack.clone()), @@ -900,6 +1030,63 @@ impl<'a> AtomView<'a> { submatch } + + /// Return an iterator that replaces the pattern in the target once. + pub fn replace_iter( + &self, + pattern: &'a Pattern, + rhs: &'a PatternOrMap, + conditions: &'a Condition, + settings: &'a MatchSettings, + ) -> ReplaceIterator<'a, 'a> { + ReplaceIterator::new(pattern, *self, rhs, conditions, settings) + } + + pub fn pattern_match( + &self, + pattern: &'a Pattern, + conditions: &'a Condition, + settings: &'a MatchSettings, + ) -> PatternAtomTreeIterator<'a, 'a> { + PatternAtomTreeIterator::new(pattern, *self, conditions, settings) + } + + /// Replace all occurrences of the pattern in the target, returning `true` iff a match was found. + /// For every matched atom, the first canonical match is used and then the atom is skipped. + pub fn replace_all_with_ws_into( + &self, + pattern: &Pattern, + rhs: BorrowedPatternOrMap, + workspace: &Workspace, + conditions: Option<&Condition>, + settings: Option<&MatchSettings>, + out: &mut Atom, + ) -> bool { + let rep = BorrowedReplacement { + pattern, + rhs, + conditions, + settings, + }; + + let mut rhs_cache = HashMap::default(); + let matched = self.replace_all_no_norm( + std::slice::from_ref(&rep), + workspace, + 0, + 0, + &mut rhs_cache, + out, + ); + + if matched { + let mut norm = workspace.new_atom(); + out.as_view().normalize(workspace, &mut norm); + std::mem::swap(out, &mut norm); + } + + matched + } } impl FromStr for Pattern { @@ -914,7 +1101,7 @@ impl FromStr for Pattern { impl Pattern { pub fn parse(input: &str) -> Result { // TODO: use workspace instead of owned atom - Ok(Atom::parse(input)?.into_pattern()) + Ok(Atom::parse(input)?.to_pattern()) } /// Convert the pattern to an atom, if there are not transformers present. @@ -1511,99 +1698,6 @@ impl Pattern { Ok(()) } - - /// Return an iterator that replaces the pattern in the target once. - pub fn replace_iter<'a>( - &'a self, - target: AtomView<'a>, - rhs: &'a PatternOrMap, - conditions: &'a Condition, - settings: &'a MatchSettings, - ) -> ReplaceIterator<'a, 'a> { - ReplaceIterator::new(self, target, rhs, conditions, settings) - } - - /// Replace all occurrences of the pattern in the target - /// For every matched atom, the first canonical match is used and then the atom is skipped. - pub fn replace_all( - &self, - target: AtomView<'_>, - rhs: &PatternOrMap, - conditions: Option<&Condition>, - settings: Option<&MatchSettings>, - ) -> Atom { - Workspace::get_local().with(|ws| { - let mut out = ws.new_atom(); - self.replace_all_with_ws_into(target, rhs, ws, conditions, settings, &mut out); - out.into_inner() - }) - } - - /// Replace all occurrences of the pattern in the target, returning `true` iff a match was found. - /// For every matched atom, the first canonical match is used and then the atom is skipped. - pub fn replace_all_into( - &self, - target: AtomView<'_>, - rhs: &PatternOrMap, - conditions: Option<&Condition>, - settings: Option<&MatchSettings>, - out: &mut Atom, - ) -> bool { - Workspace::get_local() - .with(|ws| self.replace_all_with_ws_into(target, rhs, ws, conditions, settings, out)) - } - - /// Replace all occurrences of the pattern in the target, returning `true` iff a match was found. - /// For every matched atom, the first canonical match is used and then the atom is skipped. - pub fn replace_all_with_ws_into( - &self, - target: AtomView<'_>, - rhs: &PatternOrMap, - workspace: &Workspace, - conditions: Option<&Condition>, - settings: Option<&MatchSettings>, - out: &mut Atom, - ) -> bool { - let mut rep = Replacement::new(self, rhs); - if let Some(c) = conditions { - rep = rep.with_conditions(c); - } - if let Some(s) = settings { - rep = rep.with_settings(s); - } - - let mut rhs_cache = HashMap::default(); - let matched = target.replace_all_no_norm( - std::slice::from_ref(&rep), - workspace, - 0, - 0, - &mut rhs_cache, - out, - ); - - if matched { - let mut norm = workspace.new_atom(); - out.as_view().normalize(workspace, &mut norm); - std::mem::swap(out, &mut norm); - } - - matched - } - - /// Replace all occurrences in `target`, where replacements are tested in the order that they are given. - pub fn replace_all_multiple(target: AtomView, replacements: &[Replacement<'_>]) -> Atom { - target.replace_all_multiple(replacements) - } - - pub fn pattern_match<'a: 'b, 'b>( - &'b self, - target: AtomView<'a>, - conditions: &'b Condition, - settings: &'b MatchSettings, - ) -> PatternAtomTreeIterator<'a, 'b> { - PatternAtomTreeIterator::new(self, target, conditions, settings) - } } impl std::fmt::Debug for Pattern { @@ -3640,7 +3734,7 @@ mod test { let p = Pattern::parse("v2+v2^v1_").unwrap(); let rhs = Pattern::parse("v2*(1+v2^(v1_-1))").unwrap(); - let r = p.replace_all(a.as_view(), &rhs.into(), None, None); + let r = a.replace_all(&p, &rhs, None, None); let res = Atom::parse("v1*(v2+v2^2+1)+v2*(v2+1)").unwrap(); assert_eq!(r, res); } @@ -3651,9 +3745,9 @@ mod test { let p = Pattern::parse("v1").unwrap(); let rhs = Pattern::parse("1").unwrap(); - let r = p.replace_all( - a.as_view(), - &rhs.into(), + let r = a.replace_all( + &p, + &rhs, None, Some(&MatchSettings { level_range: (1, Some(1)), @@ -3667,15 +3761,10 @@ mod test { #[test] fn multiple() { let a = Atom::parse("f(v1,v2)").unwrap(); - let p1 = Pattern::parse("v1").unwrap(); - let rhs1 = Pattern::parse("v2").unwrap(); - - let p2 = Pattern::parse("v2").unwrap(); - let rhs2 = Pattern::parse("v1").unwrap(); let r = a.replace_all_multiple(&[ - Replacement::new(&p1, &rhs1.into()), - Replacement::new(&p2, &rhs2.into()), + Replacement::new(Pattern::parse("v1").unwrap(), Pattern::parse("v2").unwrap()), + Replacement::new(Pattern::parse("v2").unwrap(), Pattern::parse("v1").unwrap()), ]); let res = Atom::parse("f(v2,v1)").unwrap(); @@ -3701,7 +3790,7 @@ mod test { .unwrap() })); - let r = p.replace_all(a.as_view(), &rhs, None, None); + let r = a.replace_all(&p, &rhs, None, None); let res = Atom::parse("v1(mu2)*v2(mu3)").unwrap(); assert_eq!(r, res); } @@ -3710,7 +3799,7 @@ mod test { fn repeat_replace() { let mut a = Atom::parse("f(10)").unwrap(); let p1 = Pattern::parse("f(v1_)").unwrap(); - let rhs1 = Pattern::parse("f(v1_ - 1)").unwrap().into(); + let rhs1 = Pattern::parse("f(v1_ - 1)").unwrap(); let rest = ( State::get_symbol("v1_"), @@ -3735,7 +3824,7 @@ mod test { fn match_stack_filter() { let a = Atom::parse("f(1,2,3,4)").unwrap(); let p1 = Pattern::parse("f(v1_,v2_,v3_,v4_)").unwrap(); - let rhs1 = Pattern::parse("f(v4_,v3_,v2_,v1_)").unwrap().into(); + let rhs1 = Pattern::parse("f(v4_,v3_,v2_,v1_)").unwrap(); let rest = PatternRestriction::MatchStack(Box::new(|m| { for x in m.get_matches().windows(2) { @@ -3763,14 +3852,10 @@ mod test { #[test] fn match_cache() { - let mut expr = Atom::parse("f1(1)*f1(2)+f1(1)*f1(2)*f2").unwrap(); + let expr = Atom::parse("f1(1)*f1(2)+f1(1)*f1(2)*f2").unwrap(); + let pat = Pattern::parse("v1_(id1_)*v2_(id2_)").unwrap(); - expr = Pattern::parse("v1_(id1_)*v2_(id2_)").unwrap().replace_all( - expr.as_view(), - &Pattern::parse("f1(id1_)").unwrap().into(), - None, - None, - ); + let expr = expr.replace_all(&pat, &Pattern::parse("f1(id1_)").unwrap(), None, None); let res = Atom::parse("f1(1)+f2*f1(1)").unwrap(); assert_eq!(expr, res); @@ -3778,35 +3863,35 @@ mod test { #[test] fn match_cyclic() { - let rhs = Pattern::parse("1").unwrap().into(); + let rhs = Pattern::parse("1").unwrap(); // literal wrap let expr = Atom::parse("fc1(1,2,3)").unwrap(); let p = Pattern::parse("fc1(v1__,v1_,1)").unwrap(); - let expr = p.replace_all(expr.as_view(), &rhs, None, None); + let expr = expr.replace_all(&p, &rhs, None, None); assert_eq!(expr, Atom::new_num(1)); // multiple wildcard wrap let expr = Atom::parse("fc1(1,2,3)").unwrap(); let p = Pattern::parse("fc1(v1__,2)").unwrap(); - let expr = p.replace_all(expr.as_view(), &rhs, None, None); + let expr = expr.replace_all(&p, &rhs, None, None); assert_eq!(expr, Atom::new_num(1)); // wildcard wrap let expr = Atom::parse("fc1(1,2,3)").unwrap(); let p = Pattern::parse("fc1(v1__,v1_,2)").unwrap(); - let expr = p.replace_all(expr.as_view(), &rhs, None, None); + let expr = expr.replace_all(&p, &rhs, None, None); assert_eq!(expr, Atom::new_num(1)); let expr = Atom::parse("fc1(v1,4,3,5,4)").unwrap(); let p = Pattern::parse("fc1(v1__,v1_,v2_,v1_)").unwrap(); - let expr = p.replace_all(expr.as_view(), &rhs, None, None); + let expr = expr.replace_all(&p, &rhs, None, None); assert_eq!(expr, Atom::new_num(1)); // function shift let expr = Atom::parse("fc1(f1(1),f1(2),f1(3))").unwrap(); let p = Pattern::parse("fc1(f1(v1_),f1(2),f1(3))").unwrap(); - let expr = p.replace_all(expr.as_view(), &rhs, None, None); + let expr = expr.replace_all(&p, &rhs, None, None); assert_eq!(expr, Atom::new_num(1)); } diff --git a/src/streaming.rs b/src/streaming.rs index 1cca619c..693deaa2 100644 --- a/src/streaming.rs +++ b/src/streaming.rs @@ -618,9 +618,9 @@ mod test { streamer = streamer.map(|f| f); let pattern = Pattern::parse("f1(x_)").unwrap(); - let rhs = Pattern::parse("f1(v1) + v1").unwrap().into(); + let rhs = Pattern::parse("f1(v1) + v1").unwrap(); - streamer = streamer.map(|x| pattern.replace_all(x.as_view(), &rhs, None, None).expand()); + streamer = streamer.map(|x| x.replace_all(&pattern, &rhs, None, None).expand()); streamer.normalize(); @@ -643,12 +643,11 @@ mod test { streamer.push(input); let pattern = Pattern::parse("v1_").unwrap(); - let rhs = Pattern::parse("v1").unwrap().into(); + let rhs = Pattern::parse("v1").unwrap(); streamer = streamer.map(|x| { - pattern - .replace_all( - x.as_view(), + x.replace_all( + &pattern, &rhs, Some( &( @@ -674,13 +673,13 @@ mod test { fn memory_stream() { let input = Atom::parse("v1 + f1(v1) + 2*f1(v2) + 7*f1(v3)").unwrap(); let pattern = Pattern::parse("f1(x_)").unwrap(); - let rhs = Pattern::parse("f1(v1) + v1").unwrap().into(); + let rhs = Pattern::parse("f1(v1) + v1").unwrap(); let mut stream = TermStreamer::>::new(TermStreamerConfig::default()); stream.push(input); // map every term in the expression - stream = stream.map(|x| pattern.replace_all(x.as_view(), &rhs, None, None).expand()); + stream = stream.map(|x| x.replace_all(&pattern, &rhs, None, None).expand()); let r = stream.to_expression(); diff --git a/src/transformer.rs b/src/transformer.rs index 998a923c..63b9eed8 100644 --- a/src/transformer.rs +++ b/src/transformer.rs @@ -1,13 +1,13 @@ use std::{ops::ControlFlow, sync::Arc, time::Instant}; use crate::{ - atom::{representation::FunView, Atom, AtomOrView, AtomView, Fun, Symbol}, + atom::{representation::FunView, Atom, AtomView, Fun, Symbol}, coefficient::{Coefficient, CoefficientView}, combinatorics::{partitions, unique_permutations}, domains::rational::Rational, id::{ - Condition, Evaluate, MatchSettings, Pattern, PatternOrMap, PatternRestriction, Relation, - Replacement, + BorrowPatternOrMap, Condition, Evaluate, MatchSettings, Pattern, PatternOrMap, + PatternRestriction, Relation, Replacement, }, printer::{AtomPrinter, PrintOptions}, state::{RecycledAtom, Workspace}, @@ -121,7 +121,7 @@ pub enum Transformer { /// Perform a series expansion. Series(Symbol, Atom, Rational, bool), ///Collect all terms in powers of a variable. - Collect(Vec>, Vec, Vec), + Collect(Vec, Vec, Vec), /// Collect numbers. CollectNum, /// Apply find-and-replace on the lhs. @@ -132,14 +132,7 @@ pub enum Transformer { MatchSettings, ), /// Apply multiple find-and-replace on the lhs. - ReplaceAllMultiple( - Vec<( - Pattern, - PatternOrMap, - Condition, - MatchSettings, - )>, - ), + ReplaceAllMultiple(Vec), /// Take the product of a list of arguments in the rhs. Product, /// Take the sum of a list of arguments in the rhs. @@ -576,9 +569,9 @@ impl Transformer { } } Transformer::ReplaceAll(pat, rhs, cond, settings) => { - pat.replace_all_with_ws_into( - cur_input, - rhs, + cur_input.replace_all_with_ws_into( + pat, + rhs.borrow(), workspace, cond.into(), settings.into(), @@ -586,15 +579,7 @@ impl Transformer { ); } Transformer::ReplaceAllMultiple(replacements) => { - let reps = replacements - .iter() - .map(|(pat, rhs, cond, settings)| { - Replacement::new(pat, rhs) - .with_conditions(cond) - .with_settings(settings) - }) - .collect::>(); - cur_input.replace_all_multiple_into(&reps, out); + cur_input.replace_all_multiple_into(&replacements, out); } Transformer::Product => { if let AtomView::Fun(f) = cur_input { diff --git a/tests/pattern_matching.rs b/tests/pattern_matching.rs index b419506e..e511ed2a 100644 --- a/tests/pattern_matching.rs +++ b/tests/pattern_matching.rs @@ -8,10 +8,10 @@ use symbolica::{ fn fibonacci() { // prepare all patterns let pattern = Pattern::parse("f(x_)").unwrap(); - let rhs = Pattern::parse("f(x_ - 1) + f(x_ - 2)").unwrap().into(); + let rhs = Pattern::parse("f(x_ - 1) + f(x_ - 2)").unwrap(); let lhs_zero_pat = Pattern::parse("f(0)").unwrap(); let lhs_one_pat = Pattern::parse("f(1)").unwrap(); - let rhs_one = Atom::new_num(1).into_pattern().into(); + let rhs_one = Atom::new_num(1).to_pattern(); // prepare the pattern restriction `x_ > 1` let restrictions = ( @@ -28,14 +28,14 @@ fn fibonacci() { for _ in 0..9 { let mut out = RecycledAtom::new(); - pattern.replace_all_into(target.as_view(), &rhs, Some(&restrictions), None, &mut out); + target.replace_all_into(&pattern, &rhs, Some(&restrictions), None, &mut out); let mut out2 = RecycledAtom::new(); out.expand_into(&mut out2); - lhs_zero_pat.replace_all_into(out2.as_view(), &rhs_one, None, None, &mut out); + out2.replace_all_into(&lhs_zero_pat, &rhs_one, None, None, &mut out); - lhs_one_pat.replace_all_into(out.as_view(), &rhs_one, None, None, &mut out2); + out.replace_all_into(&lhs_one_pat, &rhs_one, None, None, &mut out2); target = out2; } @@ -57,7 +57,7 @@ fn replace_once() { let mut replaced = Atom::new(); - let mut it = pattern.replace_iter(expr.as_view(), &rhs, &restrictions, &settings); + let mut it = expr.replace_iter(&pattern, &rhs, &restrictions, &settings); let mut r = vec![]; while let Some(()) = it.next(&mut replaced) { r.push(replaced.clone()); From 17bfd3b3944ec3e665de2a102427e864784d1ede Mon Sep 17 00:00:00 2001 From: Ben Ruijl Date: Fri, 13 Dec 2024 14:50:14 +0100 Subject: [PATCH 5/7] Refactor immutable Atom(View) methods into AtomCore trait --- examples/coefficient_ring.rs | 5 +- examples/collect.rs | 6 +- examples/derivative.rs | 5 +- examples/evaluate.rs | 1 + examples/expansion.rs | 2 +- examples/factorization.rs | 2 +- examples/fibonacci.rs | 4 +- examples/groebner_basis.rs | 2 +- examples/nested_evaluation.rs | 2 +- examples/optimize.rs | 2 +- examples/optimize_multiple.rs | 2 +- examples/partition.rs | 7 +- examples/pattern_match.rs | 2 +- examples/pattern_restrictions.rs | 4 +- examples/polynomial_gcd.rs | 2 +- examples/rational_polynomial.rs | 2 +- examples/replace_all.rs | 5 +- examples/replace_once.rs | 6 +- examples/series.rs | 5 +- examples/solve_linear_system.rs | 2 +- examples/streaming.rs | 2 +- examples/tree_replace.rs | 4 +- src/api/python.rs | 2 +- src/atom.rs | 77 ++-- src/atom/core.rs | 628 +++++++++++++++++++++++++++++ src/coefficient.rs | 66 +-- src/collect.rs | 98 +---- src/derivative.rs | 39 +- src/domains/algebraic_number.rs | 2 +- src/domains/atom.rs | 2 +- src/domains/finite_field.rs | 2 +- src/domains/rational.rs | 2 +- src/domains/rational_polynomial.rs | 1 + src/evaluate.rs | 74 +--- src/expand.rs | 58 +-- src/id.rs | 175 +------- src/lib.rs | 2 +- src/normalize.rs | 2 +- src/parser.rs | 2 +- src/poly.rs | 83 +--- src/poly/evaluate.rs | 14 +- src/poly/factor.rs | 2 +- src/poly/groebner.rs | 2 +- src/poly/polynomial.rs | 6 +- src/poly/resultant.rs | 2 +- src/poly/series.rs | 2 +- src/poly/univariate.rs | 2 +- src/printer.rs | 22 +- src/solve.rs | 65 +-- src/streaming.rs | 46 +-- src/tensors.rs | 43 +- tests/pattern_matching.rs | 8 +- tests/rational_polynomial.rs | 8 +- 53 files changed, 835 insertions(+), 776 deletions(-) create mode 100644 src/atom/core.rs diff --git a/examples/coefficient_ring.rs b/examples/coefficient_ring.rs index 146598e6..77feb0ec 100644 --- a/examples/coefficient_ring.rs +++ b/examples/coefficient_ring.rs @@ -1,6 +1,9 @@ use std::sync::Arc; -use symbolica::{atom::Atom, state::State}; +use symbolica::{ + atom::{Atom, AtomCore}, + state::State, +}; fn main() { let expr = Atom::parse("x*z+x*(y+2)^-1*(y+z+1)").unwrap(); diff --git a/examples/collect.rs b/examples/collect.rs index 01d7133c..91a75cc3 100644 --- a/examples/collect.rs +++ b/examples/collect.rs @@ -1,4 +1,8 @@ -use symbolica::{atom::Atom, fun, state::State}; +use symbolica::{ + atom::{Atom, AtomCore}, + fun, + state::State, +}; fn main() { let input = Atom::parse("x*(1+a)+x*5*y+f(5,x)+2+y^2+x^2 + x^3").unwrap(); diff --git a/examples/derivative.rs b/examples/derivative.rs index 1a41f2a2..91afb0b1 100644 --- a/examples/derivative.rs +++ b/examples/derivative.rs @@ -1,4 +1,7 @@ -use symbolica::{atom::Atom, state::State}; +use symbolica::{ + atom::{Atom, AtomCore}, + state::State, +}; fn main() { let x = State::get_symbol("x"); diff --git a/examples/evaluate.rs b/examples/evaluate.rs index 280c3b3f..06a1977a 100644 --- a/examples/evaluate.rs +++ b/examples/evaluate.rs @@ -1,4 +1,5 @@ use ahash::HashMap; +use symbolica::atom::AtomCore; use symbolica::evaluate::EvaluationFn; use symbolica::{atom::Atom, state::State}; diff --git a/examples/expansion.rs b/examples/expansion.rs index 45f2f543..fac3fd6b 100644 --- a/examples/expansion.rs +++ b/examples/expansion.rs @@ -1,4 +1,4 @@ -use symbolica::atom::Atom; +use symbolica::atom::{Atom, AtomCore}; fn main() { let input = Atom::parse("(1+x)^3").unwrap(); diff --git a/examples/factorization.rs b/examples/factorization.rs index 19641ccb..a795a62e 100644 --- a/examples/factorization.rs +++ b/examples/factorization.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use symbolica::{ - atom::Atom, + atom::{Atom, AtomCore}, domains::{finite_field::Zp, integer::Z}, poly::{factor::Factorize, polynomial::MultivariatePolynomial, Variable}, state::State, diff --git a/examples/fibonacci.rs b/examples/fibonacci.rs index 41df7b64..83a33ddd 100644 --- a/examples/fibonacci.rs +++ b/examples/fibonacci.rs @@ -1,5 +1,5 @@ use symbolica::{ - atom::{Atom, AtomView}, + atom::{Atom, AtomCore, AtomView}, id::{Match, Pattern, WildcardRestriction}, state::{RecycledAtom, State}, }; @@ -35,7 +35,7 @@ fn main() { target.replace_all_into(&pattern, &rhs, Some(&restrictions), None, &mut out); let mut out2 = RecycledAtom::new(); - out.expand_into(&mut out2); + out.expand_into(None, &mut out2); out2.replace_all_into(&lhs_zero_pat, &rhs_one, None, None, &mut out); diff --git a/examples/groebner_basis.rs b/examples/groebner_basis.rs index cd87f8a3..524b46d5 100644 --- a/examples/groebner_basis.rs +++ b/examples/groebner_basis.rs @@ -1,5 +1,5 @@ use symbolica::{ - atom::Atom, + atom::{Atom, AtomCore}, domains::finite_field::Zp, poly::{groebner::GroebnerBasis, polynomial::MultivariatePolynomial, GrevLexOrder}, state::State, diff --git a/examples/nested_evaluation.rs b/examples/nested_evaluation.rs index bbf58154..e2638859 100644 --- a/examples/nested_evaluation.rs +++ b/examples/nested_evaluation.rs @@ -1,5 +1,5 @@ use symbolica::{ - atom::Atom, + atom::{Atom, AtomCore}, domains::rational::Rational, evaluate::{CompileOptions, FunctionMap, InlineASM, OptimizationSettings}, state::State, diff --git a/examples/optimize.rs b/examples/optimize.rs index de4996b4..14335946 100644 --- a/examples/optimize.rs +++ b/examples/optimize.rs @@ -1,5 +1,5 @@ use symbolica::{ - atom::Atom, + atom::{Atom, AtomCore}, domains::rational::Q, poly::evaluate::{BorrowedHornerScheme, InstructionSetPrinter}, }; diff --git a/examples/optimize_multiple.rs b/examples/optimize_multiple.rs index 478fc98a..06d6c0b1 100644 --- a/examples/optimize_multiple.rs +++ b/examples/optimize_multiple.rs @@ -1,5 +1,5 @@ use symbolica::{ - atom::Atom, + atom::{Atom, AtomCore}, domains::rational::Q, poly::evaluate::{HornerScheme, InstructionSetPrinter}, }; diff --git a/examples/partition.rs b/examples/partition.rs index 9a33d6de..eb1969f2 100644 --- a/examples/partition.rs +++ b/examples/partition.rs @@ -1,4 +1,9 @@ -use symbolica::{atom::Atom, id::Pattern, state::State, transformer::Transformer}; +use symbolica::{ + atom::{Atom, AtomCore}, + id::Pattern, + state::State, + transformer::Transformer, +}; fn main() { let input = Atom::parse("f(1,3,2,3,1)").unwrap(); diff --git a/examples/pattern_match.rs b/examples/pattern_match.rs index bfeefb67..88665e89 100644 --- a/examples/pattern_match.rs +++ b/examples/pattern_match.rs @@ -1,5 +1,5 @@ use symbolica::{ - atom::Atom, + atom::{Atom, AtomCore}, id::{Condition, Match, MatchSettings}, state::State, }; diff --git a/examples/pattern_restrictions.rs b/examples/pattern_restrictions.rs index cc1e1afb..68dfebd7 100644 --- a/examples/pattern_restrictions.rs +++ b/examples/pattern_restrictions.rs @@ -1,5 +1,5 @@ use symbolica::{ - atom::{Atom, AtomView}, + atom::{Atom, AtomCore, AtomView}, coefficient::CoefficientView, domains::finite_field, id::{Condition, Match, MatchSettings, WildcardRestriction}, @@ -9,7 +9,7 @@ fn main() { let expr = Atom::parse("f(1,2,3,4,5,6,7)").unwrap(); let pat_expr = Atom::parse("f(x__,y__,z__,w__)").unwrap(); - let pattern = pat_expr.as_view().into_pattern(); + let pattern = pat_expr.as_view().to_pattern(); let x = State::get_symbol("x__"); let y = State::get_symbol("y__"); diff --git a/examples/polynomial_gcd.rs b/examples/polynomial_gcd.rs index 8604d091..c679a9a4 100644 --- a/examples/polynomial_gcd.rs +++ b/examples/polynomial_gcd.rs @@ -1,5 +1,5 @@ use symbolica::{ - atom::Atom, + atom::{Atom, AtomCore}, domains::{integer::Z, rational::Q}, }; use tracing_subscriber::{fmt, prelude::*, util::SubscriberInitExt, EnvFilter}; diff --git a/examples/rational_polynomial.rs b/examples/rational_polynomial.rs index 5cad80c4..f7de5a21 100644 --- a/examples/rational_polynomial.rs +++ b/examples/rational_polynomial.rs @@ -1,5 +1,5 @@ use symbolica::{ - atom::Atom, + atom::{Atom, AtomCore}, domains::{integer::Z, rational_polynomial::RationalPolynomial}, }; diff --git a/examples/replace_all.rs b/examples/replace_all.rs index 539f9565..b852afb7 100644 --- a/examples/replace_all.rs +++ b/examples/replace_all.rs @@ -1,4 +1,7 @@ -use symbolica::{atom::Atom, id::Pattern}; +use symbolica::{ + atom::{Atom, AtomCore}, + id::Pattern, +}; fn main() { let expr = Atom::parse(" f(1,2,x) + f(1,2,3)").unwrap(); diff --git a/examples/replace_once.rs b/examples/replace_once.rs index 5d4d1ca5..b390aa12 100644 --- a/examples/replace_once.rs +++ b/examples/replace_once.rs @@ -1,5 +1,5 @@ use symbolica::{ - atom::Atom, + atom::{Atom, AtomCore}, id::{Condition, MatchSettings}, }; @@ -8,9 +8,9 @@ fn main() { let pat_expr = Atom::parse("f(x_)").unwrap(); let rhs_expr = Atom::parse("g(x_)").unwrap(); - let rhs = rhs_expr.as_view().into_pattern().into(); + let rhs = rhs_expr.as_view().to_pattern().into(); - let pattern = pat_expr.as_view().into_pattern(); + let pattern = pat_expr.as_view().to_pattern(); let restrictions = Condition::default(); let settings = MatchSettings::default(); diff --git a/examples/series.rs b/examples/series.rs index 9e4dca4e..ae128b48 100644 --- a/examples/series.rs +++ b/examples/series.rs @@ -1,4 +1,7 @@ -use symbolica::{atom::Atom, state::State}; +use symbolica::{ + atom::{Atom, AtomCore}, + state::State, +}; fn main() { let x = State::get_symbol("x"); diff --git a/examples/solve_linear_system.rs b/examples/solve_linear_system.rs index b08b5ae9..4512ded6 100644 --- a/examples/solve_linear_system.rs +++ b/examples/solve_linear_system.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use symbolica::{ - atom::{representation::InlineVar, Atom, AtomView}, + atom::{representation::InlineVar, Atom, AtomCore, AtomView}, domains::{ integer::Z, rational::Q, diff --git a/examples/streaming.rs b/examples/streaming.rs index f7f98dc5..4d5360c2 100644 --- a/examples/streaming.rs +++ b/examples/streaming.rs @@ -1,6 +1,6 @@ use brotli::CompressorWriter; use symbolica::{ - atom::Atom, + atom::{Atom, AtomCore}, id::Pattern, streaming::{TermStreamer, TermStreamerConfig}, }; diff --git a/examples/tree_replace.rs b/examples/tree_replace.rs index 55bd10f8..50408f7e 100644 --- a/examples/tree_replace.rs +++ b/examples/tree_replace.rs @@ -1,5 +1,5 @@ use symbolica::{ - atom::Atom, + atom::{Atom, AtomCore}, id::{Condition, Match, MatchSettings, PatternAtomTreeIterator}, state::State, }; @@ -8,7 +8,7 @@ fn main() { let expr = Atom::parse("f(z)*f(f(x))*f(y)").unwrap(); let pat_expr = Atom::parse("f(x_)").unwrap(); - let pattern = pat_expr.as_view().into_pattern(); + let pattern = pat_expr.to_pattern(); let restrictions = Condition::default(); let settings = MatchSettings::default(); diff --git a/src/api/python.rs b/src/api/python.rs index bc711d97..f9e03964 100644 --- a/src/api/python.rs +++ b/src/api/python.rs @@ -32,7 +32,7 @@ use smartstring::{LazyCompact, SmartString}; use pyo3::pymodule; use crate::{ - atom::{Atom, AtomType, AtomView, ListIterator, Symbol}, + atom::{Atom, AtomCore, AtomType, AtomView, ListIterator, Symbol}, coefficient::CoefficientView, domains::{ algebraic_number::AlgebraicExtension, diff --git a/src/atom.rs b/src/atom.rs index 76e26e53..433efd2b 100644 --- a/src/atom.rs +++ b/src/atom.rs @@ -1,7 +1,8 @@ mod coefficient; +mod core; pub mod representation; -use representation::{InlineNum, InlineVar}; +use representation::InlineVar; use crate::{ coefficient::Coefficient, @@ -12,6 +13,7 @@ use crate::{ }; use std::{cmp::Ordering, hash::Hash, ops::DerefMut, str::FromStr}; +pub use self::core::AtomCore; pub use self::representation::{ Add, AddView, Fun, ListIterator, ListSlice, Mul, MulView, Num, NumView, Pow, PowView, Var, VarView, @@ -20,7 +22,7 @@ use self::representation::{FunView, RawAtom}; /// A symbol, for example the name of a variable or the name of a function, /// together with its properties. -/// Should be created using `get_symbol` of `State`. +/// Should be created using [State::get_symbol]. #[derive(Copy, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] pub struct Symbol { id: u32, @@ -312,49 +314,6 @@ impl<'a> AtomOrView<'a> { } } -/// A trait for any type that can be converted into an `AtomView`. -/// To be used for functions that accept any argument that can be -/// converted to an `AtomView`. -pub trait AsAtomView { - fn as_atom_view(&self) -> AtomView; -} - -impl<'a> AsAtomView for AtomView<'a> { - fn as_atom_view(&self) -> AtomView<'a> { - *self - } -} - -impl AsAtomView for InlineVar { - fn as_atom_view(&self) -> AtomView { - self.as_view() - } -} - -impl AsAtomView for InlineNum { - fn as_atom_view(&self) -> AtomView { - self.as_view() - } -} - -impl> AsAtomView for T { - fn as_atom_view(&self) -> AtomView { - self.as_ref().as_view() - } -} - -impl<'a> AsAtomView for AtomOrView<'a> { - fn as_atom_view(&self) -> AtomView { - self.as_view() - } -} - -impl AsRef for Atom { - fn as_ref(&self) -> &Atom { - self - } -} - impl<'a> AtomView<'a> { pub fn to_owned(&self) -> Atom { let mut a = Atom::default(); @@ -852,12 +811,12 @@ impl Atom { } } -/// A constructor of a function. Consider using the [`fun!`] macro instead. +/// A constructor of a function. Consider using the [crate::fun!] macro instead. /// /// For example: /// ``` /// # use symbolica::{ -/// # atom::{Atom, AsAtomView, FunctionBuilder}, +/// # atom::{Atom, AtomCore, FunctionBuilder}, /// # state::{FunctionAttribute, State}, /// # }; /// # fn main() { @@ -885,7 +844,7 @@ impl FunctionBuilder { } /// Add an argument to the function. - pub fn add_arg(mut self, arg: T) -> FunctionBuilder { + pub fn add_arg(mut self, arg: T) -> FunctionBuilder { if let Atom::Fun(f) = self.handle.deref_mut() { f.add_arg(arg.as_atom_view()); } @@ -894,7 +853,7 @@ impl FunctionBuilder { } /// Add multiple arguments to the function. - pub fn add_args(mut self, args: &[T]) -> FunctionBuilder { + pub fn add_args(mut self, args: &[T]) -> FunctionBuilder { if let Atom::Fun(f) = self.handle.deref_mut() { for a in args { f.add_arg(a.as_atom_view()); @@ -1023,7 +982,7 @@ impl Atom { } /// Take the `self` to the power `exp`. Use [`Atom::npow()`] for a numerical power and [`Atom::rpow()`] for the reverse operation. - pub fn pow(&self, exp: T) -> Atom { + pub fn pow(&self, exp: T) -> Atom { Workspace::get_local().with(|ws| { let mut t = ws.new_atom(); self.as_view() @@ -1035,7 +994,7 @@ impl Atom { } /// Take `base` to the power `self`. - pub fn rpow(&self, base: T) -> Atom { + pub fn rpow(&self, base: T) -> Atom { Workspace::get_local().with(|ws| { let mut t = ws.new_atom(); base.as_atom_view() @@ -1047,7 +1006,7 @@ impl Atom { } /// Add the atoms in `args`. - pub fn add_many<'a, T: AsAtomView + Copy>(args: &[T]) -> Atom { + pub fn add_many<'a, T: AtomCore + Copy>(args: &[T]) -> Atom { let mut out = Atom::new(); Workspace::get_local().with(|ws| { let mut t = ws.new_atom(); @@ -1062,7 +1021,7 @@ impl Atom { } /// Multiply the atoms in `args`. - pub fn mul_many<'a, T: AsAtomView + Copy>(args: &[T]) -> Atom { + pub fn mul_many<'a, T: AtomCore + Copy>(args: &[T]) -> Atom { let mut out = Atom::new(); Workspace::get_local().with(|ws| { let mut t = ws.new_atom(); @@ -1571,9 +1530,19 @@ impl> std::ops::Div for Atom { } } +impl AsRef for Atom { + fn as_ref(&self) -> &Atom { + self + } +} + #[cfg(test)] mod test { - use crate::{atom::Atom, fun, state::State}; + use crate::{ + atom::{Atom, AtomCore}, + fun, + state::State, + }; #[test] fn debug() { diff --git a/src/atom/core.rs b/src/atom/core.rs new file mode 100644 index 00000000..759b8478 --- /dev/null +++ b/src/atom/core.rs @@ -0,0 +1,628 @@ +use ahash::{HashMap, HashSet}; +use rayon::ThreadPool; + +use crate::{ + coefficient::{Coefficient, CoefficientView, ConvertToRing}, + domains::{ + atom::AtomField, + factorized_rational_polynomial::{ + FactorizedRationalPolynomial, FromNumeratorAndFactorizedDenominator, + }, + float::{Real, SingleFloat}, + integer::Z, + rational::Rational, + rational_polynomial::{ + FromNumeratorAndDenominator, RationalPolynomial, RationalPolynomialField, + }, + EuclideanDomain, InternalOrdering, + }, + evaluate::{EvalTree, EvaluationFn, ExpressionEvaluator, FunctionMap, OptimizationSettings}, + id::{ + BorrowPatternOrMap, BorrowReplacement, Condition, ConditionResult, Context, MatchSettings, + Pattern, PatternAtomTreeIterator, PatternOrMap, PatternRestriction, ReplaceIterator, + }, + poly::{ + factor::Factorize, gcd::PolynomialGCD, polynomial::MultivariatePolynomial, series::Series, + Exponent, PositiveExponent, Variable, + }, + printer::{AtomPrinter, PrintOptions, PrintState}, + state::Workspace, + tensors::matrix::Matrix, +}; +use std::sync::Arc; + +use super::{ + representation::{InlineNum, InlineVar}, + Atom, AtomOrView, AtomView, Symbol, +}; + +/// All core features of expressions, such as expansion and +/// pattern matching that leave the expression unchanged. +pub trait AtomCore { + /// Take a view of the atom. + fn as_atom_view(&self) -> AtomView; + + /// Collect terms involving the same power of `x`, where `x` is a variable or function, e.g. + /// + /// ```math + /// collect(x + x * y + x^2, x) = x * (1+y) + x^2 + /// ``` + /// + /// Both the *key* (the quantity collected in) and its coefficient can be mapped using + /// `key_map` and `coeff_map` respectively. + fn collect( + &self, + x: T, + key_map: Option>, + coeff_map: Option>, + ) -> Atom { + self.as_atom_view().collect::(x, key_map, coeff_map) + } + + /// Collect terms involving the same power of `x`, where `x` is a variable or function, e.g. + /// + /// ```math + /// collect(x + x * y + x^2, x) = x * (1+y) + x^2 + /// ``` + /// + /// Both the *key* (the quantity collected in) and its coefficient can be mapped using + /// `key_map` and `coeff_map` respectively. + fn collect_multiple( + &self, + xs: &[T], + key_map: Option>, + coeff_map: Option>, + ) -> Atom { + self.as_atom_view() + .collect_multiple::(xs, key_map, coeff_map) + } + + /// Collect terms involving the same power of `x` in `xs`, where `xs` is a list of indeterminates. + /// Return the list of key-coefficient pairs + fn coefficient_list(&self, xs: &[T]) -> Vec<(Atom, Atom)> { + self.as_atom_view().coefficient_list::(xs) + } + + /// Collect terms involving the literal occurrence of `x`. + fn coefficient(&self, x: T) -> Atom { + Workspace::get_local().with(|ws| { + self.as_atom_view() + .coefficient_with_ws(x.as_atom_view(), ws) + }) + } + + /// Write the expression over a common denominator. + fn together(&self) -> Atom { + self.as_atom_view().together() + } + + /// Write the expression as a sum of terms with minimal denominators. + fn apart(&self, x: Symbol) -> Atom { + self.as_atom_view().apart(x) + } + + /// Cancel all common factors between numerators and denominators. + /// Any non-canceling parts of the expression will not be rewritten. + fn cancel(&self) -> Atom { + self.as_atom_view().cancel() + } + + /// Factor the expression over the rationals. + fn factor(&self) -> Atom { + self.as_atom_view().factor() + } + + /// Collect numerical factors by removing the numerical content from additions. + /// For example, `-2*x + 4*x^2 + 6*x^3` will be transformed into `-2*(x - 2*x^2 - 3*x^3)`. + /// + /// The first argument of the addition is normalized to a positive quantity. + fn collect_num(&self) -> Atom { + self.as_atom_view().collect_num() + } + + /// Expand an expression. The function [AtomCore::expand_via_poly] may be faster. + fn expand(&self) -> Atom { + self.as_atom_view().expand() + } + + /// Expand the expression by converting it to a polynomial, optionally + /// only in the indeterminate `var`. The parameter `E` should be a numerical type + /// that fits the largest exponent in the expanded expression. Often, + /// `u8` or `u16` is sufficient. + fn expand_via_poly(&self, var: Option) -> Atom { + self.as_atom_view() + .expand_via_poly::(var.as_ref().map(|x| x.as_atom_view())) + } + + /// Expand an expression in the variable `var`. The function [AtomCore::expand_via_poly] may be faster. + fn expand_in(&self, var: T) -> Atom { + self.as_atom_view().expand_in(var.as_atom_view()) + } + + /// Expand an expression in the variable `var`. + fn expand_in_symbol(&self, var: Symbol) -> Atom { + self.as_atom_view() + .expand_in(InlineVar::from(var).as_view()) + } + + /// Expand an expression, returning `true` iff the expression changed. + fn expand_into(&self, var: Option, out: &mut Atom) -> bool { + self.as_atom_view().expand_into(var, out) + } + + /// Distribute numbers in the expression, for example: + /// `2*(x+y)` -> `2*x+2*y`. + fn expand_num(&self) -> Atom { + self.as_atom_view().expand_num() + } + + /// Take a derivative of the expression with respect to `x`. + fn derivative(&self, x: Symbol) -> Atom { + self.as_atom_view().derivative(x) + } + + /// Take a derivative of the expression with respect to `x` and + /// write the result in `out`. + /// Returns `true` if the derivative is non-zero. + fn derivative_into(&self, x: Symbol, out: &mut Atom) -> bool { + self.as_atom_view().derivative_into(x, out) + } + + /// Series expand in `x` around `expansion_point` to depth `depth`. + fn series( + &self, + x: Symbol, + expansion_point: T, + depth: Rational, + depth_is_absolute: bool, + ) -> Result, &'static str> { + self.as_atom_view() + .series(x, expansion_point.as_atom_view(), depth, depth_is_absolute) + } + + /// Find the root of a function in `x` numerically over the reals using Newton's method. + fn nsolve( + &self, + x: Symbol, + init: N, + prec: N, + max_iterations: usize, + ) -> Result { + self.as_atom_view().nsolve(x, init, prec, max_iterations) + } + + /// Solve a non-linear system numerically over the reals using Newton's method. + fn nsolve_system< + N: SingleFloat + Real + PartialOrd + InternalOrdering + Eq + std::hash::Hash, + T: AtomCore, + >( + system: &[T], + vars: &[Symbol], + init: &[N], + prec: N, + max_iterations: usize, + ) -> Result, String> { + AtomView::nsolve_system(system, vars, init, prec, max_iterations) + } + + /// Solve a system that is linear in `vars`, if possible. + /// Each expression in `system` is understood to yield 0. + fn solve_linear_system( + system: &[T1], + vars: &[T2], + ) -> Result, String> { + AtomView::solve_linear_system::(system, vars) + } + + /// Convert a system of linear equations to a matrix representation, returning the matrix + /// and the right-hand side. + fn system_to_matrix( + system: &[T1], + vars: &[T2], + ) -> Result< + ( + Matrix>, + Matrix>, + ), + String, + > { + AtomView::system_to_matrix::(system, vars) + } + + /// Evaluate a (nested) expression a single time. + /// For repeated evaluations, use [Self::evaluator()] and convert + /// to an optimized version or generate a compiled version of your expression. + /// + /// All variables and all user functions in the expression must occur in the map. + fn evaluate<'b, T: Real, F: Fn(&Rational) -> T + Copy>( + &'b self, + coeff_map: F, + const_map: &HashMap, T>, + function_map: &HashMap>, + cache: &mut HashMap, T>, + ) -> Result { + self.as_atom_view() + .evaluate(coeff_map, const_map, function_map, cache) + } + + /// Convert nested expressions to a tree suitable for repeated evaluations with + /// different values for `params`. + /// All variables and all user functions in the expression must occur in the map. + fn to_evaluation_tree<'a>( + &'a self, + fn_map: &FunctionMap<'a, Rational>, + params: &[Atom], + ) -> Result, String> { + self.as_atom_view().to_evaluation_tree(fn_map, params) + } + + /// Create an efficient evaluator for a (nested) expression. + /// All free parameters must appear in `params` and all other variables + /// and user functions in the expression must occur in the function map. + /// The function map may have nested expressions. + fn evaluator<'a>( + &'a self, + fn_map: &FunctionMap<'a, Rational>, + params: &[Atom], + optimization_settings: OptimizationSettings, + ) -> Result, String> { + let mut tree = self.to_evaluation_tree(fn_map, params)?; + Ok(tree.optimize( + optimization_settings.horner_iterations, + optimization_settings.n_cores, + optimization_settings.hot_start.clone(), + optimization_settings.verbose, + )) + } + + /// Convert nested expressions to a tree suitable for repeated evaluations with + /// different values for `params`. + /// All variables and all user functions in the expression must occur in the map. + fn evaluator_multiple<'a>( + exprs: &[AtomView<'a>], + fn_map: &FunctionMap<'a, Rational>, + params: &[Atom], + optimization_settings: OptimizationSettings, + ) -> Result, String> { + let mut tree = AtomView::to_eval_tree_multiple(exprs, fn_map, params)?; + Ok(tree.optimize( + optimization_settings.horner_iterations, + optimization_settings.n_cores, + optimization_settings.hot_start.clone(), + optimization_settings.verbose, + )) + } + + /// Check if the expression could be 0, using (potentially) numerical sampling with + /// a given tolerance and number of iterations. + fn zero_test(&self, iterations: usize, tolerance: f64) -> ConditionResult { + self.as_atom_view().zero_test(iterations, tolerance) + } + + /// Set the coefficient ring to the multivariate rational polynomial with `vars` variables. + fn set_coefficient_ring(&self, vars: &Arc>) -> Atom { + self.as_atom_view().set_coefficient_ring(vars) + } + + /// Convert all coefficients to floats with a given precision `decimal_prec``. + /// The precision of floating point coefficients in the input will be truncated to `decimal_prec`. + fn coefficients_to_float(&self, decimal_prec: u32) -> Atom { + let mut a = Atom::new(); + self.as_atom_view() + .coefficients_to_float_into(decimal_prec, &mut a); + a + } + + /// Convert all coefficients to floats with a given precision `decimal_prec``. + /// The precision of floating point coefficients in the input will be truncated to `decimal_prec`. + fn coefficients_to_float_into(&self, decimal_prec: u32, out: &mut Atom) { + self.as_atom_view() + .coefficients_to_float_into(decimal_prec, out); + } + + /// Map all coefficients using a given function. + fn map_coefficient Coefficient + Copy>(&self, f: F) -> Atom { + self.as_atom_view().map_coefficient(f) + } + + /// Map all coefficients using a given function. + fn map_coefficient_into Coefficient + Copy>( + &self, + f: F, + out: &mut Atom, + ) { + self.as_atom_view().map_coefficient_into(f, out); + } + + /// Map all floating point and rational coefficients to the best rational approximation + /// in the interval `[self*(1-relative_error),self*(1+relative_error)]`. + fn rationalize_coefficients(&self, relative_error: &Rational) -> Atom { + self.as_atom_view().rationalize_coefficients(relative_error) + } + + /// Convert the atom to a polynomial, optionally in the variable ordering + /// specified by `var_map`. If new variables are encountered, they are + /// added to the variable map. Similarly, non-polynomial parts are automatically + /// defined as a new independent variable in the polynomial. + fn to_polynomial( + &self, + field: &R, + var_map: Option>>, + ) -> MultivariatePolynomial { + self.as_atom_view().to_polynomial(field, var_map) + } + + /// Convert the atom to a polynomial in specific variables. + /// All other parts will be collected into the coefficient, which + /// is a general expression. + /// + /// This routine does not perform expansions. + fn to_polynomial_in_vars( + &self, + var_map: &Arc>, + ) -> MultivariatePolynomial { + self.as_atom_view().to_polynomial_in_vars(var_map) + } + + /// Convert the atom to a rational polynomial, optionally in the variable ordering + /// specified by `var_map`. If new variables are encountered, they are + /// added to the variable map. Similarly, non-rational polynomial parts are automatically + /// defined as a new independent variable in the rational polynomial. + fn to_rational_polynomial< + R: EuclideanDomain + ConvertToRing, + RO: EuclideanDomain + PolynomialGCD, + E: PositiveExponent, + >( + &self, + field: &R, + out_field: &RO, + var_map: Option>>, + ) -> RationalPolynomial + where + RationalPolynomial: + FromNumeratorAndDenominator + FromNumeratorAndDenominator, + { + self.as_atom_view() + .to_rational_polynomial(field, out_field, var_map) + } + + /// Convert the atom to a rational polynomial with factorized denominators, optionally in the variable ordering + /// specified by `var_map`. If new variables are encountered, they are + /// added to the variable map. Similarly, non-rational polynomial parts are automatically + /// defined as a new independent variable in the rational polynomial. + fn to_factorized_rational_polynomial< + R: EuclideanDomain + ConvertToRing, + RO: EuclideanDomain + PolynomialGCD, + E: PositiveExponent, + >( + &self, + field: &R, + out_field: &RO, + var_map: Option>>, + ) -> FactorizedRationalPolynomial + where + FactorizedRationalPolynomial: FromNumeratorAndFactorizedDenominator + + FromNumeratorAndFactorizedDenominator, + MultivariatePolynomial: Factorize, + { + self.as_atom_view() + .to_factorized_rational_polynomial(field, out_field, var_map) + } + + // Format the atom. + fn format( + &self, + fmt: &mut W, + opts: &PrintOptions, + print_state: PrintState, + ) -> Result { + self.as_atom_view().format(fmt, opts, print_state) + } + + /// Construct a printer for the atom with special options. + fn printer<'a>(&'a self, opts: PrintOptions) -> AtomPrinter<'a> { + AtomPrinter::new_with_options(self.as_atom_view(), opts) + } + + /// Print the atom in a form that is unique and independent of any implementation details. + /// + /// Anti-symmetric functions are not supported. + fn to_canonical_string(&self) -> String { + self.as_atom_view().to_canonical_string() + } + + /// Map the function `f` over all terms. + fn map_terms_single_core(&self, f: impl Fn(AtomView) -> Atom) -> Atom { + self.as_atom_view().map_terms_single_core(f) + } + + /// Map the function `f` over all terms, using parallel execution with `n_cores` cores. + fn map_terms(&self, f: impl Fn(AtomView) -> Atom + Send + Sync, n_cores: usize) -> Atom { + self.as_atom_view().map_terms(f, n_cores) + } + + /// Map the function `f` over all terms, using parallel execution with `n_cores` cores. + fn map_terms_with_pool( + &self, + f: impl Fn(AtomView) -> Atom + Send + Sync, + p: &ThreadPool, + ) -> Atom { + self.as_atom_view().map_terms_with_pool(f, p) + } + + /// Canonize (products of) tensors in the expression by relabeling repeated indices. + /// The tensors must be written as functions, with its indices are the arguments. + /// The repeated indices should be provided in `contracted_indices`. + /// + /// If the contracted indices are distinguishable (for example in their dimension), + /// you can provide an optional group marker for each index using `index_group`. + /// This makes sure that an index will not be renamed to an index from a different group. + /// + /// Example + /// ------- + /// ``` + /// # use symbolica::{atom::{Atom, AtomCore}, state::{FunctionAttribute, State}}; + /// # + /// # fn main() { + /// let _ = State::get_symbol_with_attributes("fs", &[FunctionAttribute::Symmetric]).unwrap(); + /// let _ = State::get_symbol_with_attributes("fc", &[FunctionAttribute::Cyclesymmetric]).unwrap(); + /// let a = Atom::parse("fs(mu2,mu3)*fc(mu4,mu2,k1,mu4,k1,mu3)").unwrap(); + /// + /// let mu1 = Atom::parse("mu1").unwrap(); + /// let mu2 = Atom::parse("mu2").unwrap(); + /// let mu3 = Atom::parse("mu3").unwrap(); + /// let mu4 = Atom::parse("mu4").unwrap(); + /// + /// let r = a.canonize_tensors(&[mu1.as_view(), mu2.as_view(), mu3.as_view(), mu4.as_view()], None).unwrap(); + /// println!("{}", r); + /// # } + /// ``` + /// yields `fs(mu1,mu2)*fc(mu1,k1,mu3,k1,mu2,mu3)`. + fn canonize_tensors( + &self, + contracted_indices: &[AtomView], + index_group: Option<&[AtomView]>, + ) -> Result { + self.as_atom_view() + .canonize_tensors(contracted_indices, index_group) + } + + fn to_pattern(&self) -> Pattern { + Pattern::from_view(self.as_atom_view(), true) + } + + /// Get all symbols in the expression, optionally including function symbols. + fn get_all_symbols(&self, include_function_symbols: bool) -> HashSet { + self.as_atom_view() + .get_all_symbols(include_function_symbols) + } + + /// Get all variables and functions in the expression. + fn get_all_indeterminates<'a>(&'a self, enter_functions: bool) -> HashSet> { + self.as_atom_view().get_all_indeterminates(enter_functions) + } + + /// Returns true iff `self` contains the symbol `s`. + fn contains_symbol(&self, s: Symbol) -> bool { + self.as_atom_view().contains_symbol(s) + } + + /// Returns true iff `self` contains `a` literally. + fn contains(&self, s: T) -> bool { + self.as_atom_view().contains(s.as_atom_view()) + } + + /// Check if the expression can be considered a polynomial in some variables, including + /// redefinitions. For example `f(x)+y` is considered a polynomial in `f(x)` and `y`, whereas + /// `f(x)+x` is not a polynomial. + /// + /// Rational powers or powers in variables are not rewritten, e.g. `x^(2y)` is not considered + /// polynomial in `x^y`. + fn is_polynomial( + &self, + allow_not_expanded: bool, + allow_negative_powers: bool, + ) -> Option>> { + self.as_atom_view() + .is_polynomial(allow_not_expanded, allow_negative_powers) + } + + /// Replace all occurrences of the pattern. + fn replace_all( + &self, + pattern: &Pattern, + rhs: R, + conditions: Option<&Condition>, + settings: Option<&MatchSettings>, + ) -> Atom { + self.as_atom_view() + .replace_all(pattern, rhs, conditions, settings) + } + + /// Replace all occurrences of the pattern. + fn replace_all_into( + &self, + pattern: &Pattern, + rhs: R, + conditions: Option<&Condition>, + settings: Option<&MatchSettings>, + out: &mut Atom, + ) -> bool { + self.as_atom_view() + .replace_all_into(pattern, rhs, conditions, settings, out) + } + + /// Replace all occurrences of the patterns, where replacements are tested in the order that they are given. + fn replace_all_multiple(&self, replacements: &[T]) -> Atom { + self.as_atom_view().replace_all_multiple(replacements) + } + + /// Replace all occurrences of the patterns, where replacements are tested in the order that they are given. + /// Returns `true` iff a match was found. + fn replace_all_multiple_into( + &self, + replacements: &[T], + out: &mut Atom, + ) -> bool { + self.as_atom_view() + .replace_all_multiple_into(replacements, out) + } + + /// Replace part of an expression by calling the map `m` on each subexpression. + /// The function `m` must return `true` if the expression was replaced and must write the new expression to `out`. + /// A [Context] object is passed to the function, which contains information about the current position in the expression. + fn replace_map bool>(&self, m: &F) -> Atom { + self.as_atom_view().replace_map(m) + } + + /// Return an iterator that replaces the pattern in the target once. + fn replace_iter<'a>( + &'a self, + pattern: &'a Pattern, + rhs: &'a PatternOrMap, + conditions: &'a Condition, + settings: &'a MatchSettings, + ) -> ReplaceIterator<'a, 'a> { + ReplaceIterator::new(pattern, self.as_atom_view(), rhs, conditions, settings) + } + + /// Return an iterator over matched expressions. + fn pattern_match<'a>( + &'a self, + pattern: &'a Pattern, + conditions: &'a Condition, + settings: &'a MatchSettings, + ) -> PatternAtomTreeIterator<'a, 'a> { + PatternAtomTreeIterator::new(pattern, self.as_atom_view(), conditions, settings) + } +} + +impl<'a> AtomCore for AtomView<'a> { + fn as_atom_view(&self) -> AtomView<'a> { + *self + } +} + +impl AtomCore for InlineVar { + fn as_atom_view(&self) -> AtomView { + self.as_view() + } +} + +impl AtomCore for InlineNum { + fn as_atom_view(&self) -> AtomView { + self.as_view() + } +} + +impl> AtomCore for T { + fn as_atom_view(&self) -> AtomView { + self.as_ref().as_view() + } +} + +impl<'a> AtomCore for AtomOrView<'a> { + fn as_atom_view(&self) -> AtomView { + self.as_view() + } +} diff --git a/src/coefficient.rs b/src/coefficient.rs index b6982def..06fb5aec 100644 --- a/src/coefficient.rs +++ b/src/coefficient.rs @@ -1093,51 +1093,9 @@ impl<'a> TryFrom> for Float { } } -impl Atom { - /// Set the coefficient ring to the multivariate rational polynomial with `vars` variables. - pub fn set_coefficient_ring(&self, vars: &Arc>) -> Atom { - self.as_view().set_coefficient_ring(vars) - } - - /// Convert all coefficients to floats with a given precision `decimal_prec``. - /// The precision of floating point coefficients in the input will be truncated to `decimal_prec`. - pub fn coefficients_to_float(&self, decimal_prec: u32) -> Atom { - let mut a = Atom::new(); - self.as_view() - .coefficients_to_float_into(decimal_prec, &mut a); - a - } - - /// Convert all coefficients to floats with a given precision `decimal_prec``. - /// The precision of floating point coefficients in the input will be truncated to `decimal_prec`. - pub fn coefficients_to_float_into(&self, decimal_prec: u32, out: &mut Atom) { - self.as_view().coefficients_to_float_into(decimal_prec, out); - } - - /// Map all coefficients using a given function. - pub fn map_coefficient Coefficient + Copy>(&self, f: F) -> Atom { - self.as_view().map_coefficient(f) - } - - /// Map all coefficients using a given function. - pub fn map_coefficient_into Coefficient + Copy>( - &self, - f: F, - out: &mut Atom, - ) { - self.as_view().map_coefficient_into(f, out); - } - - /// Map all floating point and rational coefficients to the best rational approximation - /// in the interval `[self*(1-relative_error),self*(1+relative_error)]`. - pub fn rationalize_coefficients(&self, relative_error: &Rational) -> Atom { - self.as_view().rationalize_coefficients(relative_error) - } -} - impl<'a> AtomView<'a> { /// Set the coefficient ring to the multivariate rational polynomial with `vars` variables. - pub fn set_coefficient_ring(&self, vars: &Arc>) -> Atom { + pub(crate) fn set_coefficient_ring(&self, vars: &Arc>) -> Atom { Workspace::get_local().with(|ws| { let mut out = ws.new_atom(); self.set_coefficient_ring_with_ws_into(vars, ws, &mut out); @@ -1146,7 +1104,7 @@ impl<'a> AtomView<'a> { } /// Set the coefficient ring to the multivariate rational polynomial with `vars` variables. - pub fn set_coefficient_ring_with_ws_into( + pub(crate) fn set_coefficient_ring_with_ws_into( &self, vars: &Arc>, workspace: &Workspace, @@ -1312,14 +1270,7 @@ impl<'a> AtomView<'a> { /// Convert all coefficients to floats with a given precision `decimal_prec``. /// The precision of floating point coefficients in the input will be truncated to `decimal_prec`. - pub fn coefficients_to_float(&self, decimal_prec: u32) -> Atom { - let mut a = Atom::new(); - self.coefficients_to_float_into(decimal_prec, &mut a); - a - } - /// Convert all coefficients to floats with a given precision `decimal_prec``. - /// The precision of floating point coefficients in the input will be truncated to `decimal_prec`. - pub fn coefficients_to_float_into(&self, decimal_prec: u32, out: &mut Atom) { + pub(crate) fn coefficients_to_float_into(&self, decimal_prec: u32, out: &mut Atom) { let binary_prec = (decimal_prec as f64 * LOG2_10).ceil() as u32; Workspace::get_local().with(|ws| self.to_float_impl(binary_prec, true, false, ws, out)) @@ -1442,7 +1393,7 @@ impl<'a> AtomView<'a> { /// Map all floating point and rational coefficients to the best rational approximation /// in the interval `[self*(1-relative_error),self*(1+relative_error)]`. - pub fn rationalize_coefficients(&self, relative_error: &Rational) -> Atom { + pub(crate) fn rationalize_coefficients(&self, relative_error: &Rational) -> Atom { let mut a = Atom::new(); Workspace::get_local().with(|ws| { self.map_coefficient_impl( @@ -1467,14 +1418,17 @@ impl<'a> AtomView<'a> { } /// Map all coefficients using a given function. - pub fn map_coefficient Coefficient + Copy>(&self, f: F) -> Atom { + pub(crate) fn map_coefficient Coefficient + Copy>( + &self, + f: F, + ) -> Atom { let mut a = Atom::new(); self.map_coefficient_into(f, &mut a); a } /// Map all coefficients using a given function. - pub fn map_coefficient_into Coefficient + Copy>( + pub(crate) fn map_coefficient_into Coefficient + Copy>( &self, f: F, out: &mut Atom, @@ -1574,7 +1528,7 @@ mod test { use std::sync::Arc; use crate::{ - atom::Atom, + atom::{Atom, AtomCore}, domains::float::Float, printer::{AtomPrinter, PrintOptions}, state::State, diff --git a/src/collect.rs b/src/collect.rs index 8803067b..719f4c28 100644 --- a/src/collect.rs +++ b/src/collect.rs @@ -1,5 +1,5 @@ use crate::{ - atom::{Add, AsAtomView, Atom, AtomView, Symbol}, + atom::{Add, Atom, AtomCore, AtomView, Symbol}, coefficient::{Coefficient, CoefficientView}, domains::{integer::Z, rational::Q}, poly::{factor::Factorize, polynomial::MultivariatePolynomial, Exponent}, @@ -7,83 +7,6 @@ use crate::{ }; use std::sync::Arc; -impl Atom { - /// Collect terms involving the same power of `x`, where `x` is a variable or function, e.g. - /// - /// ```math - /// collect(x + x * y + x^2, x) = x * (1+y) + x^2 - /// ``` - /// - /// Both the *key* (the quantity collected in) and its coefficient can be mapped using - /// `key_map` and `coeff_map` respectively. - pub fn collect( - &self, - x: T, - key_map: Option>, - coeff_map: Option>, - ) -> Atom { - self.as_view().collect::(x, key_map, coeff_map) - } - - /// Collect terms involving the same power of `x`, where `x` is a variable or function, e.g. - /// - /// ```math - /// collect(x + x * y + x^2, x) = x * (1+y) + x^2 - /// ``` - /// - /// Both the *key* (the quantity collected in) and its coefficient can be mapped using - /// `key_map` and `coeff_map` respectively. - pub fn collect_multiple( - &self, - xs: &[T], - key_map: Option>, - coeff_map: Option>, - ) -> Atom { - self.as_view() - .collect_multiple::(xs, key_map, coeff_map) - } - - /// Collect terms involving the same power of `x` in `xs`, where `xs` is a list of indeterminates. - /// Return the list of key-coefficient pairs - pub fn coefficient_list(&self, xs: &[T]) -> Vec<(Atom, Atom)> { - self.as_view().coefficient_list::(xs) - } - - /// Collect terms involving the literal occurrence of `x`. - pub fn coefficient(&self, x: T) -> Atom { - Workspace::get_local().with(|ws| self.as_view().coefficient_with_ws(x.as_atom_view(), ws)) - } - - /// Write the expression over a common denominator. - pub fn together(&self) -> Atom { - self.as_view().together() - } - - /// Write the expression as a sum of terms with minimal denominators. - pub fn apart(&self, x: Symbol) -> Atom { - self.as_view().apart(x) - } - - /// Cancel all common factors between numerators and denominators. - /// Any non-canceling parts of the expression will not be rewritten. - pub fn cancel(&self) -> Atom { - self.as_view().cancel() - } - - /// Factor the expression over the rationals. - pub fn factor(&self) -> Atom { - self.as_view().factor() - } - - /// Collect numerical factors by removing the numerical content from additions. - /// For example, `-2*x + 4*x^2 + 6*x^3` will be transformed into `-2*(x - 2*x^2 - 3*x^3)`. - /// - /// The first argument of the addition is normalized to a positive quantity. - pub fn collect_num(&self) -> Atom { - self.as_view().collect_num() - } -} - impl<'a> AtomView<'a> { /// Collect terms involving the same power of `x`, where `x` is an indeterminate, e.g. /// @@ -93,7 +16,7 @@ impl<'a> AtomView<'a> { /// /// Both the *key* (the quantity collected in) and its coefficient can be mapped using /// `key_map` and `coeff_map` respectively. - pub fn collect( + pub(crate) fn collect( &self, x: T, key_map: Option>, @@ -102,7 +25,7 @@ impl<'a> AtomView<'a> { self.collect_multiple::(std::slice::from_ref(&x), key_map, coeff_map) } - pub fn collect_multiple( + pub(crate) fn collect_multiple( &self, xs: &[T], key_map: Option>, @@ -114,7 +37,7 @@ impl<'a> AtomView<'a> { out } - pub fn collect_multiple_impl( + pub(crate) fn collect_multiple_impl( &self, xs: &[T], ws: &Workspace, @@ -166,7 +89,7 @@ impl<'a> AtomView<'a> { /// Collect terms involving the same powers of `x` in `xs`, where `x` is an indeterminate. /// Return the list of key-coefficient pairs. - pub fn coefficient_list(&self, xs: &[T]) -> Vec<(Atom, Atom)> { + pub(crate) fn coefficient_list(&self, xs: &[T]) -> Vec<(Atom, Atom)> { let vars = xs .iter() .map(|x| x.as_atom_view().to_owned().into()) @@ -190,11 +113,6 @@ impl<'a> AtomView<'a> { coeffs } - /// Collect terms involving the literal occurrence of `x`. - pub fn coefficient(&self, x: T) -> Atom { - Workspace::get_local().with(|ws| self.coefficient_with_ws(x.as_atom_view(), ws)) - } - /// Collect terms involving the literal occurrence of `x`. pub fn coefficient_with_ws(&self, x: AtomView<'_>, workspace: &Workspace) -> Atom { let mut coeffs = workspace.new_atom(); @@ -312,7 +230,7 @@ impl<'a> AtomView<'a> { /// Cancel all common factors between numerators and denominators. /// Any non-canceling parts of the expression will not be rewritten. - pub fn cancel(&self) -> Atom { + pub(crate) fn cancel(&self) -> Atom { let mut out = Atom::new(); self.cancel_into(&mut out); out @@ -320,7 +238,7 @@ impl<'a> AtomView<'a> { /// Cancel all common factors between numerators and denominators. /// Any non-canceling parts of the expression will not be rewritten. - pub fn cancel_into(&self, out: &mut Atom) { + pub(crate) fn cancel_into(&self, out: &mut Atom) { Workspace::get_local().with(|ws| { self.cancel_with_ws_into(ws, out); }); @@ -656,7 +574,7 @@ impl<'a> AtomView<'a> { #[cfg(test)] mod test { use crate::{ - atom::{representation::InlineVar, Atom}, + atom::{representation::InlineVar, Atom, AtomCore}, fun, state::State, }; diff --git a/src/derivative.rs b/src/derivative.rs index 413694a8..95005064 100644 --- a/src/derivative.rs +++ b/src/derivative.rs @@ -4,7 +4,7 @@ use std::{ }; use crate::{ - atom::{AsAtomView, Atom, AtomView, FunctionBuilder, Symbol}, + atom::{Atom, AtomView, FunctionBuilder, Symbol}, coefficient::{Coefficient, CoefficientView}, combinatorics::CombinationWithReplacementIterator, domains::{atom::AtomField, integer::Integer, rational::Rational}, @@ -12,35 +12,9 @@ use crate::{ state::Workspace, }; -impl Atom { - /// Take a derivative of the expression with respect to `x`. - pub fn derivative(&self, x: Symbol) -> Atom { - self.as_view().derivative(x) - } - - /// Take a derivative of the expression with respect to `x` and - /// write the result in `out`. - /// Returns `true` if the derivative is non-zero. - pub fn derivative_into(&self, x: Symbol, out: &mut Atom) -> bool { - self.as_view().derivative_into(x, out) - } - - /// Series expand in `x` around `expansion_point` to depth `depth`. - pub fn series( - &self, - x: Symbol, - expansion_point: T, - depth: Rational, - depth_is_absolute: bool, - ) -> Result, &'static str> { - self.as_view() - .series(x, expansion_point.as_atom_view(), depth, depth_is_absolute) - } -} - impl<'a> AtomView<'a> { /// Take a derivative of the expression with respect to `x`. - pub fn derivative(&self, x: Symbol) -> Atom { + pub(crate) fn derivative(&self, x: Symbol) -> Atom { Workspace::get_local().with(|ws| { let mut out = ws.new_atom(); self.derivative_with_ws_into(x, ws, &mut out); @@ -51,14 +25,14 @@ impl<'a> AtomView<'a> { /// Take a derivative of the expression with respect to `x` and /// write the result in `out`. /// Returns `true` if the derivative is non-zero. - pub fn derivative_into(&self, x: Symbol, out: &mut Atom) -> bool { + pub(crate) fn derivative_into(&self, x: Symbol, out: &mut Atom) -> bool { Workspace::get_local().with(|ws| self.derivative_with_ws_into(x, ws, out)) } /// Take a derivative of the expression with respect to `x` and /// write the result in `out`. /// Returns `true` if the derivative is non-zero. - pub fn derivative_with_ws_into( + pub(crate) fn derivative_with_ws_into( &self, x: Symbol, workspace: &Workspace, @@ -762,7 +736,10 @@ impl Sub<&Atom> for &Series { #[cfg(test)] mod test { - use crate::{atom::Atom, state::State}; + use crate::{ + atom::{Atom, AtomCore}, + state::State, + }; #[test] fn derivative() { diff --git a/src/domains/algebraic_number.rs b/src/domains/algebraic_number.rs index a13116ce..753db213 100644 --- a/src/domains/algebraic_number.rs +++ b/src/domains/algebraic_number.rs @@ -644,7 +644,7 @@ impl, E: PositiveExponent> #[cfg(test)] mod tests { - use crate::atom::Atom; + use crate::atom::{Atom, AtomCore}; use crate::domains::algebraic_number::AlgebraicExtension; use crate::domains::finite_field::{PrimeIteratorU64, Zp, Z2}; use crate::domains::rational::Q; diff --git a/src/domains/atom.rs b/src/domains/atom.rs index 71384881..ec3af278 100644 --- a/src/domains/atom.rs +++ b/src/domains/atom.rs @@ -1,5 +1,5 @@ use crate::{ - atom::{Atom, AtomView}, + atom::{Atom, AtomCore, AtomView}, poly::Variable, }; diff --git a/src/domains/finite_field.rs b/src/domains/finite_field.rs index fe1cf82a..b86db16e 100644 --- a/src/domains/finite_field.rs +++ b/src/domains/finite_field.rs @@ -172,7 +172,7 @@ pub trait FiniteFieldCore: Field { /// `m` will be a prime, and the domain will be a field. /// /// [Zp] ([`FiniteField`]) and [Zp64] ([`FiniteField`]) use Montgomery modular arithmetic -/// to increase the performance of the multiplication operator. For the prime `2`, use [Z2] instead. +/// to increase the performance of the multiplication operator. For the prime `2`, use [type@Z2] instead. /// /// For `m` larger than `2^64`, use [`FiniteField`]. /// diff --git a/src/domains/rational.rs b/src/domains/rational.rs index bb5b5ef3..e190c84d 100644 --- a/src/domains/rational.rs +++ b/src/domains/rational.rs @@ -1147,7 +1147,7 @@ impl<'a> std::iter::Sum<&'a Self> for Rational { #[cfg(test)] mod test { use crate::{ - atom::Atom, + atom::{Atom, AtomCore}, domains::{ integer::Z, rational::{FractionField, Rational, Q}, diff --git a/src/domains/rational_polynomial.rs b/src/domains/rational_polynomial.rs index 2c7e91bc..34143d8b 100644 --- a/src/domains/rational_polynomial.rs +++ b/src/domains/rational_polynomial.rs @@ -1283,6 +1283,7 @@ mod test { use std::sync::Arc; use crate::{ + atom::AtomCore, domains::{integer::Z, rational::Q, rational_polynomial::RationalPolynomial, Ring}, state::State, }; diff --git a/src/evaluate.rs b/src/evaluate.rs index 48ad50dd..036a5e8b 100644 --- a/src/evaluate.rs +++ b/src/evaluate.rs @@ -176,78 +176,6 @@ impl Default for OptimizationSettings { } } -impl Atom { - /// Evaluate a (nested) expression a single time. - /// For repeated evaluations, use [Self::evaluator()] and convert - /// to an optimized version or generate a compiled version of your expression. - /// - /// All variables and all user functions in the expression must occur in the map. - pub fn evaluate<'b, T: Real, F: Fn(&Rational) -> T + Copy>( - &'b self, - coeff_map: F, - const_map: &HashMap, T>, - function_map: &HashMap>, - cache: &mut HashMap, T>, - ) -> Result { - self.as_view() - .evaluate(coeff_map, const_map, function_map, cache) - } - - /// Convert nested expressions to a tree suitable for repeated evaluations with - /// different values for `params`. - /// All variables and all user functions in the expression must occur in the map. - pub fn to_evaluation_tree<'a>( - &'a self, - fn_map: &FunctionMap<'a, Rational>, - params: &[Atom], - ) -> Result, String> { - self.as_view().to_evaluation_tree(fn_map, params) - } - - /// Create an efficient evaluator for a (nested) expression. - /// All free parameters must appear in `params` and all other variables - /// and user functions in the expression must occur in the function map. - /// The function map may have nested expressions. - pub fn evaluator<'a>( - &'a self, - fn_map: &FunctionMap<'a, Rational>, - params: &[Atom], - optimization_settings: OptimizationSettings, - ) -> Result, String> { - let mut tree = self.to_evaluation_tree(fn_map, params)?; - Ok(tree.optimize( - optimization_settings.horner_iterations, - optimization_settings.n_cores, - optimization_settings.hot_start.clone(), - optimization_settings.verbose, - )) - } - - /// Convert nested expressions to a tree suitable for repeated evaluations with - /// different values for `params`. - /// All variables and all user functions in the expression must occur in the map. - pub fn evaluator_multiple<'a>( - exprs: &[AtomView<'a>], - fn_map: &FunctionMap<'a, Rational>, - params: &[Atom], - optimization_settings: OptimizationSettings, - ) -> Result, String> { - let mut tree = AtomView::to_eval_tree_multiple(exprs, fn_map, params)?; - Ok(tree.optimize( - optimization_settings.horner_iterations, - optimization_settings.n_cores, - optimization_settings.hot_start.clone(), - optimization_settings.verbose, - )) - } - - /// Check if the expression could be 0, using (potentially) numerical sampling with - /// a given tolerance and number of iterations. - pub fn zero_test(&self, iterations: usize, tolerance: f64) -> ConditionResult { - self.as_view().zero_test(iterations, tolerance) - } -} - #[derive(Debug, Clone)] pub struct SplitExpression { pub tree: Vec>, @@ -4329,7 +4257,7 @@ mod test { use ahash::HashMap; use crate::{ - atom::Atom, + atom::{Atom, AtomCore}, domains::{float::Float, rational::Rational}, evaluate::{EvaluationFn, FunctionMap, OptimizationSettings}, id::ConditionResult, diff --git a/src/expand.rs b/src/expand.rs index 5c8931a7..6130c09c 100644 --- a/src/expand.rs +++ b/src/expand.rs @@ -3,7 +3,7 @@ use std::{ops::DerefMut, sync::Arc}; use smallvec::SmallVec; use crate::{ - atom::{representation::InlineVar, AsAtomView, Atom, AtomView, Symbol}, + atom::{Atom, AtomView}, coefficient::CoefficientView, combinatorics::CombinationWithReplacementIterator, domains::{integer::Integer, rational::Q}, @@ -11,46 +11,9 @@ use crate::{ state::{RecycledAtom, Workspace}, }; -impl Atom { - /// Expand an expression. The function [expand_via_poly] may be faster. - pub fn expand(&self) -> Atom { - self.as_view().expand() - } - - /// Expand the expression by converting it to a polynomial, optionally - /// only in the indeterminate `var`. The parameter `E` should be a numerical type - /// that fits the largest exponent in the expanded expression. Often, - /// `u8` or `u16` is sufficient. - pub fn expand_via_poly(&self, var: Option) -> Atom { - self.as_view() - .expand_via_poly::(var.as_ref().map(|x| x.as_atom_view())) - } - - /// Expand an expression in the variable `var`. The function [expand_via_poly] may be faster. - pub fn expand_in(&self, var: T) -> Atom { - self.as_view().expand_in(var.as_atom_view()) - } - - /// Expand an expression in the variable `var`. - pub fn expand_in_symbol(&self, var: Symbol) -> Atom { - self.as_view().expand_in(InlineVar::from(var).as_view()) - } - - /// Expand an expression, returning `true` iff the expression changed. - pub fn expand_into(&self, out: &mut Atom) -> bool { - self.as_view().expand_into(None, out) - } - - /// Distribute numbers in the expression, for example: - /// `2*(x+y)` -> `2*x+2*y`. - pub fn expand_num(&self) -> Atom { - self.as_view().expand_num() - } -} - impl<'a> AtomView<'a> { /// Expand an expression. The function [expand_via_poly] may be faster. - pub fn expand(&self) -> Atom { + pub(crate) fn expand(&self) -> Atom { Workspace::get_local().with(|ws| { let mut a = ws.new_atom(); self.expand_with_ws_into(ws, None, &mut a); @@ -59,7 +22,7 @@ impl<'a> AtomView<'a> { } /// Expand an expression. The function [expand_via_poly] may be faster. - pub fn expand_in(&self, var: AtomView) -> Atom { + pub(crate) fn expand_in(&self, var: AtomView) -> Atom { Workspace::get_local().with(|ws| { let mut a = ws.new_atom(); self.expand_with_ws_into(ws, Some(var), &mut a); @@ -68,7 +31,7 @@ impl<'a> AtomView<'a> { } /// Expand an expression, returning `true` iff the expression changed. - pub fn expand_into(&self, var: Option, out: &mut Atom) -> bool { + pub(crate) fn expand_into(&self, var: Option, out: &mut Atom) -> bool { Workspace::get_local().with(|ws| self.expand_with_ws_into(ws, var, out)) } @@ -141,7 +104,7 @@ impl<'a> AtomView<'a> { /// only in the indeterminate `var`. The parameter `E` should be a numerical type /// that fits the largest exponent in the expanded expression. Often, /// `u8` or `u16` is sufficient. - pub fn expand_via_poly(&self, var: Option) -> Atom { + pub(crate) fn expand_via_poly(&self, var: Option) -> Atom { let var_map = var.map(|v| Arc::new(vec![v.to_owned().into()])); let mut out = Atom::new(); @@ -432,7 +395,7 @@ impl<'a> AtomView<'a> { /// Distribute numbers in the expression, for example: /// `2*(x+y)` -> `2*x+2*y`. - pub fn expand_num(&self) -> Atom { + pub(crate) fn expand_num(&self) -> Atom { let mut a = Atom::new(); Workspace::get_local().with(|ws| { self.expand_num_impl(ws, &mut a); @@ -440,13 +403,13 @@ impl<'a> AtomView<'a> { a } - pub fn expand_num_into(&self, out: &mut Atom) { + pub(crate) fn expand_num_into(&self, out: &mut Atom) { Workspace::get_local().with(|ws| { self.expand_with_ws_into(ws, None, out); }) } - pub fn expand_num_impl(&self, ws: &Workspace, out: &mut Atom) -> bool { + pub(crate) fn expand_num_impl(&self, ws: &Workspace, out: &mut Atom) -> bool { match self { AtomView::Num(_) | AtomView::Var(_) | AtomView::Fun(_) => { out.set_from_view(self); @@ -549,7 +512,10 @@ impl<'a> AtomView<'a> { #[cfg(test)] mod test { - use crate::{atom::Atom, state::State}; + use crate::{ + atom::{Atom, AtomCore}, + state::State, + }; #[test] fn expand_num() { diff --git a/src/id.rs b/src/id.rs index eecf80eb..f225e1d1 100644 --- a/src/id.rs +++ b/src/id.rs @@ -6,7 +6,7 @@ use dyn_clone::DynClone; use crate::{ atom::{ representation::{InlineVar, ListSlice}, - AsAtomView, Atom, AtomType, AtomView, Num, SliceType, Symbol, + Atom, AtomCore, AtomType, AtomView, Num, SliceType, Symbol, }, state::Workspace, transformer::{Transformer, TransformerError}, @@ -194,121 +194,6 @@ impl<'a> BorrowReplacement for BorrowedReplacement<'a> { } } -impl Atom { - pub fn to_pattern(&self) -> Pattern { - Pattern::from_view(self.as_view(), true) - } - - /// Get all symbols in the expression, optionally including function symbols. - pub fn get_all_symbols(&self, include_function_symbols: bool) -> HashSet { - let mut out = HashSet::default(); - self.as_view() - .get_all_symbols_impl(include_function_symbols, &mut out); - out - } - - /// Get all variables and functions in the expression. - pub fn get_all_indeterminates<'a>(&'a self, enter_functions: bool) -> HashSet> { - let mut out = HashSet::default(); - self.as_view() - .get_all_indeterminates_impl(enter_functions, &mut out); - out - } - - /// Returns true iff `self` contains the symbol `s`. - pub fn contains_symbol(&self, s: Symbol) -> bool { - self.as_view().contains_symbol(s) - } - - /// Returns true iff `self` contains `a` literally. - pub fn contains(&self, s: T) -> bool { - self.as_view().contains(s.as_atom_view()) - } - - /// Check if the expression can be considered a polynomial in some variables, including - /// redefinitions. For example `f(x)+y` is considered a polynomial in `f(x)` and `y`, whereas - /// `f(x)+x` is not a polynomial. - /// - /// Rational powers or powers in variables are not rewritten, e.g. `x^(2y)` is not considered - /// polynomial in `x^y`. - pub fn is_polynomial( - &self, - allow_not_expanded: bool, - allow_negative_powers: bool, - ) -> Option>> { - self.as_view() - .is_polynomial(allow_not_expanded, allow_negative_powers) - } - - /// Replace all occurrences of the pattern. - pub fn replace_all( - &self, - pattern: &Pattern, - rhs: R, - conditions: Option<&Condition>, - settings: Option<&MatchSettings>, - ) -> Atom { - self.as_view() - .replace_all(pattern, rhs, conditions, settings) - } - - /// Replace all occurrences of the pattern. - pub fn replace_all_into( - &self, - pattern: &Pattern, - rhs: R, - conditions: Option<&Condition>, - settings: Option<&MatchSettings>, - out: &mut Atom, - ) -> bool { - self.as_view() - .replace_all_into(pattern, rhs, conditions, settings, out) - } - - /// Replace all occurrences of the patterns, where replacements are tested in the order that they are given. - pub fn replace_all_multiple(&self, replacements: &[T]) -> Atom { - self.as_view().replace_all_multiple(replacements) - } - - /// Replace all occurrences of the patterns, where replacements are tested in the order that they are given. - /// Returns `true` iff a match was found. - pub fn replace_all_multiple_into( - &self, - replacements: &[T], - out: &mut Atom, - ) -> bool { - self.as_view().replace_all_multiple_into(replacements, out) - } - - /// Replace part of an expression by calling the map `m` on each subexpression. - /// The function `m` must return `true` if the expression was replaced and must write the new expression to `out`. - /// A [Context] object is passed to the function, which contains information about the current position in the expression. - pub fn replace_map bool>(&self, m: &F) -> Atom { - self.as_view().replace_map(m) - } - - /// Return an iterator that replaces the pattern in the target once. - pub fn replace_iter<'a>( - &'a self, - pattern: &'a Pattern, - rhs: &'a PatternOrMap, - conditions: &'a Condition, - settings: &'a MatchSettings, - ) -> ReplaceIterator<'a, 'a> { - ReplaceIterator::new(pattern, self.as_view(), rhs, conditions, settings) - } - - /// Return an iterator over matched expressions. - pub fn pattern_match<'a>( - &'a self, - pattern: &'a Pattern, - conditions: &'a Condition, - settings: &'a MatchSettings, - ) -> PatternAtomTreeIterator<'a, 'a> { - PatternAtomTreeIterator::new(pattern, self.as_view(), conditions, settings) - } -} - /// The context of an atom. #[derive(Clone, Copy, Debug)] pub struct Context { @@ -321,18 +206,22 @@ pub struct Context { } impl<'a> AtomView<'a> { - pub fn into_pattern(self) -> Pattern { + pub(crate) fn to_pattern(self) -> Pattern { Pattern::from_view(self, true) } /// Get all symbols in the expression, optionally including function symbols. - pub fn get_all_symbols(&self, include_function_symbols: bool) -> HashSet { + pub(crate) fn get_all_symbols(&self, include_function_symbols: bool) -> HashSet { let mut out = HashSet::default(); self.get_all_symbols_impl(include_function_symbols, &mut out); out } - fn get_all_symbols_impl(&self, include_function_symbols: bool, out: &mut HashSet) { + pub(crate) fn get_all_symbols_impl( + &self, + include_function_symbols: bool, + out: &mut HashSet, + ) { match self { AtomView::Num(_) => {} AtomView::Var(v) => { @@ -365,7 +254,7 @@ impl<'a> AtomView<'a> { } /// Get all variables and functions in the expression. - pub fn get_all_indeterminates(&self, enter_functions: bool) -> HashSet> { + pub(crate) fn get_all_indeterminates(&self, enter_functions: bool) -> HashSet> { let mut out = HashSet::default(); self.get_all_indeterminates_impl(enter_functions, &mut out); out @@ -405,7 +294,7 @@ impl<'a> AtomView<'a> { } /// Returns true iff `self` contains `a` literally. - pub fn contains(&self, a: T) -> bool { + pub(crate) fn contains(&self, a: T) -> bool { let mut stack = Vec::with_capacity(20); stack.push(*self); @@ -443,7 +332,7 @@ impl<'a> AtomView<'a> { } /// Returns true iff `self` contains the symbol `s`. - pub fn contains_symbol(&self, s: Symbol) -> bool { + pub(crate) fn contains_symbol(&self, s: Symbol) -> bool { let mut stack = Vec::with_capacity(20); stack.push(*self); while let Some(c) = stack.pop() { @@ -489,7 +378,7 @@ impl<'a> AtomView<'a> { /// /// Rational powers or powers in variables are not rewritten, e.g. `x^(2y)` is not considered /// polynomial in `x^y`. - pub fn is_polynomial( + pub(crate) fn is_polynomial( &self, allow_not_expanded: bool, allow_negative_powers: bool, @@ -630,7 +519,7 @@ impl<'a> AtomView<'a> { /// Replace part of an expression by calling the map `m` on each subexpression. /// The function `m` must return `true` if the expression was replaced and must write the new expression to `out`. /// A [Context] object is passed to the function, which contains information about the current position in the expression. - pub fn replace_map bool>(&self, m: &F) -> Atom { + pub(crate) fn replace_map bool>(&self, m: &F) -> Atom { let mut out = Atom::new(); self.replace_map_into(m, &mut out); out @@ -639,7 +528,7 @@ impl<'a> AtomView<'a> { /// Replace part of an expression by calling the map `m` on each subexpression. /// The function `m` must return `true` if the expression was replaced and must write the new expression to `out`. /// A [Context] object is passed to the function, which contains information about the current position in the expression. - pub fn replace_map_into bool>( + pub(crate) fn replace_map_into bool>( &self, m: &F, out: &mut Atom, @@ -756,7 +645,7 @@ impl<'a> AtomView<'a> { } /// Replace all occurrences of the patterns, where replacements are tested in the order that they are given. - pub fn replace_all( + pub(crate) fn replace_all( &self, pattern: &Pattern, rhs: R, @@ -769,7 +658,7 @@ impl<'a> AtomView<'a> { } /// Replace all occurrences of the patterns, where replacements are tested in the order that they are given. - pub fn replace_all_into( + pub(crate) fn replace_all_into( &self, pattern: &Pattern, rhs: R, @@ -783,7 +672,7 @@ impl<'a> AtomView<'a> { } /// Replace all occurrences of the patterns, where replacements are tested in the order that they are given. - pub fn replace_all_multiple(&self, replacements: &[T]) -> Atom { + pub(crate) fn replace_all_multiple(&self, replacements: &[T]) -> Atom { let mut out = Atom::new(); self.replace_all_multiple_into(replacements, &mut out); out @@ -791,7 +680,7 @@ impl<'a> AtomView<'a> { /// Replace all occurrences of the patterns, where replacements are tested in the order that they are given. /// Returns `true` iff a match was found. - pub fn replace_all_multiple_into( + pub(crate) fn replace_all_multiple_into( &self, replacements: &[T], out: &mut Atom, @@ -1031,29 +920,9 @@ impl<'a> AtomView<'a> { submatch } - /// Return an iterator that replaces the pattern in the target once. - pub fn replace_iter( - &self, - pattern: &'a Pattern, - rhs: &'a PatternOrMap, - conditions: &'a Condition, - settings: &'a MatchSettings, - ) -> ReplaceIterator<'a, 'a> { - ReplaceIterator::new(pattern, *self, rhs, conditions, settings) - } - - pub fn pattern_match( - &self, - pattern: &'a Pattern, - conditions: &'a Condition, - settings: &'a MatchSettings, - ) -> PatternAtomTreeIterator<'a, 'a> { - PatternAtomTreeIterator::new(pattern, *self, conditions, settings) - } - /// Replace all occurrences of the pattern in the target, returning `true` iff a match was found. /// For every matched atom, the first canonical match is used and then the atom is skipped. - pub fn replace_all_with_ws_into( + pub(crate) fn replace_all_with_ws_into( &self, pattern: &Pattern, rhs: BorrowedPatternOrMap, @@ -1388,7 +1257,7 @@ impl Pattern { } /// Create a pattern from an atom view. - fn from_view(atom: AtomView<'_>, is_top_layer: bool) -> Pattern { + pub(crate) fn from_view(atom: AtomView<'_>, is_top_layer: bool) -> Pattern { // split up Add and Mul for literal patterns as well so that x+y can match to x+y+z if Self::has_wildcard(atom) || is_top_layer && matches!(atom, AtomView::Mul(_) | AtomView::Add(_)) @@ -1956,7 +1825,7 @@ impl Evaluate for Relation { let c = Condition::default(); let s = MatchSettings::default(); let m = MatchStack::new(&c, &s); - let pat = state.map(|x| x.into_pattern()); + let pat = state.map(|x| x.to_pattern()); Ok(match self { Relation::Eq(a, b) @@ -3702,7 +3571,7 @@ impl<'a: 'b, 'b> ReplaceIterator<'a, 'b> { #[cfg(test)] mod test { use crate::{ - atom::Atom, + atom::{Atom, AtomCore}, id::{ ConditionResult, MatchSettings, PatternOrMap, PatternRestriction, Replacement, WildcardRestriction, diff --git a/src/lib.rs b/src/lib.rs index 6e62af7d..cc5a01cf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,7 +7,7 @@ //! For example: //! //! ``` -//! use symbolica::{atom::Atom, state::State}; +//! use symbolica::{atom::Atom, atom::AtomCore, state::State}; //! //! fn main() { //! let input = Atom::parse("x^2*log(2*x + y) + exp(3*x)").unwrap(); diff --git a/src/normalize.rs b/src/normalize.rs index cfb2653c..06e13ac8 100644 --- a/src/normalize.rs +++ b/src/normalize.rs @@ -512,7 +512,7 @@ impl Atom { /// Merge two terms if possible. If this function returns `true`, `self` /// will have been updated by the merge from `other` and `other` should be discarded. /// If the function return `false`, no merge was possible and no modifications were made. - pub fn merge_terms(&mut self, other: AtomView, helper: &mut Self) -> bool { + pub(crate) fn merge_terms(&mut self, other: AtomView, helper: &mut Self) -> bool { if let Atom::Num(n1) = self { if let AtomView::Num(n2) = other { n1.add(&n2); diff --git a/src/parser.rs b/src/parser.rs index 322e223e..9be03c6e 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -1298,7 +1298,7 @@ mod test { use std::sync::Arc; use crate::{ - atom::Atom, + atom::{Atom, AtomCore}, domains::integer::Z, parser::Token, printer::{AtomPrinter, PrintOptions}, diff --git a/src/poly.rs b/src/poly.rs index 1832ce56..1841ef52 100644 --- a/src/poly.rs +++ b/src/poly.rs @@ -19,7 +19,7 @@ use ahash::HashMap; use smallvec::{smallvec, SmallVec}; use smartstring::{LazyCompact, SmartString}; -use crate::atom::{Atom, AtomView, Symbol}; +use crate::atom::{Atom, AtomCore, AtomView, Symbol}; use crate::coefficient::{Coefficient, CoefficientView, ConvertToRing}; use crate::domains::atom::AtomField; use crate::domains::factorized_rational_polynomial::{ @@ -735,77 +735,6 @@ impl Variable { } } -impl Atom { - /// Convert the atom to a polynomial, optionally in the variable ordering - /// specified by `var_map`. If new variables are encountered, they are - /// added to the variable map. Similarly, non-polynomial parts are automatically - /// defined as a new independent variable in the polynomial. - pub fn to_polynomial( - &self, - field: &R, - var_map: Option>>, - ) -> MultivariatePolynomial { - self.as_view().to_polynomial(field, var_map) - } - - /// Convert the atom to a polynomial in specific variables. - /// All other parts will be collected into the coefficient, which - /// is a general expression. - /// - /// This routine does not perform expansions. - pub fn to_polynomial_in_vars( - &self, - var_map: &Arc>, - ) -> MultivariatePolynomial { - self.as_view().to_polynomial_in_vars(var_map) - } - - /// Convert the atom to a rational polynomial, optionally in the variable ordering - /// specified by `var_map`. If new variables are encountered, they are - /// added to the variable map. Similarly, non-rational polynomial parts are automatically - /// defined as a new independent variable in the rational polynomial. - pub fn to_rational_polynomial< - R: EuclideanDomain + ConvertToRing, - RO: EuclideanDomain + PolynomialGCD, - E: PositiveExponent, - >( - &self, - field: &R, - out_field: &RO, - var_map: Option>>, - ) -> RationalPolynomial - where - RationalPolynomial: - FromNumeratorAndDenominator + FromNumeratorAndDenominator, - { - self.as_view() - .to_rational_polynomial(field, out_field, var_map) - } - - /// Convert the atom to a rational polynomial with factorized denominators, optionally in the variable ordering - /// specified by `var_map`. If new variables are encountered, they are - /// added to the variable map. Similarly, non-rational polynomial parts are automatically - /// defined as a new independent variable in the rational polynomial. - pub fn to_factorized_rational_polynomial< - R: EuclideanDomain + ConvertToRing, - RO: EuclideanDomain + PolynomialGCD, - E: PositiveExponent, - >( - &self, - field: &R, - out_field: &RO, - var_map: Option>>, - ) -> FactorizedRationalPolynomial - where - FactorizedRationalPolynomial: FromNumeratorAndFactorizedDenominator - + FromNumeratorAndFactorizedDenominator, - MultivariatePolynomial: Factorize, - { - self.as_view() - .to_factorized_rational_polynomial(field, out_field, var_map) - } -} - impl<'a> AtomView<'a> { /// Convert an expanded expression to a polynomial. fn to_polynomial_expanded( @@ -1013,7 +942,7 @@ impl<'a> AtomView<'a> { /// specified by `var_map`. If new variables are encountered, they are /// added to the variable map. Similarly, non-polynomial parts are automatically /// defined as a new independent variable in the polynomial. - pub fn to_polynomial( + pub(crate) fn to_polynomial( &self, field: &R, var_map: Option>>, @@ -1021,7 +950,7 @@ impl<'a> AtomView<'a> { self.to_polynomial_impl(field, var_map.as_ref().unwrap_or(&Arc::new(Vec::new()))) } - pub fn to_polynomial_impl( + pub(crate) fn to_polynomial_impl( &self, field: &R, var_map: &Arc>, @@ -1155,7 +1084,7 @@ impl<'a> AtomView<'a> { /// is a general expression. /// /// This routine does not perform expansions. - pub fn to_polynomial_in_vars( + pub(crate) fn to_polynomial_in_vars( &self, var_map: &Arc>, ) -> MultivariatePolynomial { @@ -1260,7 +1189,7 @@ impl<'a> AtomView<'a> { /// specified by `var_map`. If new variables are encountered, they are /// added to the variable map. Similarly, non-rational polynomial parts are automatically /// defined as a new independent variable in the rational polynomial. - pub fn to_rational_polynomial< + pub(crate) fn to_rational_polynomial< R: EuclideanDomain + ConvertToRing, RO: EuclideanDomain + PolynomialGCD, E: PositiveExponent, @@ -1403,7 +1332,7 @@ impl<'a> AtomView<'a> { /// specified by `var_map`. If new variables are encountered, they are /// added to the variable map. Similarly, non-rational polynomial parts are automatically /// defined as a new independent variable in the rational polynomial. - pub fn to_factorized_rational_polynomial< + pub(crate) fn to_factorized_rational_polynomial< R: EuclideanDomain + ConvertToRing, RO: EuclideanDomain + PolynomialGCD, E: PositiveExponent, diff --git a/src/poly/evaluate.rs b/src/poly/evaluate.rs index e9388656..55425fbb 100644 --- a/src/poly/evaluate.rs +++ b/src/poly/evaluate.rs @@ -7,7 +7,12 @@ use ahash::{AHasher, HashMap, HashSet, HashSetExt}; use rand::{thread_rng, Rng}; use crate::{ - atom::Symbol, + atom::{Atom, AtomView}, + domains::{float::Real, Ring}, + evaluate::EvaluationFn, +}; +use crate::{ + atom::{AtomCore, Symbol}, coefficient::CoefficientView, domains::{ float::NumericalFloatLike, @@ -16,11 +21,6 @@ use crate::{ }, state::Workspace, }; -use crate::{ - atom::{Atom, AtomView}, - domains::{float::Real, Ring}, - evaluate::EvaluationFn, -}; use super::{polynomial::MultivariatePolynomial, PositiveExponent}; @@ -1961,7 +1961,7 @@ auto 𝑖 = 1i;\n", #[cfg(test)] mod test { use crate::{ - atom::Atom, + atom::{Atom, AtomCore}, domains::{float::Complex, rational::Q}, poly::{ evaluate::{BorrowedHornerScheme, InstructionSetPrinter}, diff --git a/src/poly/factor.rs b/src/poly/factor.rs index aaf9de88..4a981f3a 100644 --- a/src/poly/factor.rs +++ b/src/poly/factor.rs @@ -3398,7 +3398,7 @@ mod test { use std::sync::Arc; use crate::{ - atom::Atom, + atom::{Atom, AtomCore}, domains::{ algebraic_number::AlgebraicExtension, finite_field::{Zp, Z2}, diff --git a/src/poly/groebner.rs b/src/poly/groebner.rs index 95c242b1..5f82a737 100644 --- a/src/poly/groebner.rs +++ b/src/poly/groebner.rs @@ -934,7 +934,7 @@ echelonize_impl!(AlgebraicExtension); #[cfg(test)] mod test { use crate::{ - atom::Atom, + atom::{Atom, AtomCore}, domains::finite_field::Zp, poly::{groebner::GroebnerBasis, polynomial::MultivariatePolynomial, GrevLexOrder}, }; diff --git a/src/poly/polynomial.rs b/src/poly/polynomial.rs index 4af53e42..02606936 100755 --- a/src/poly/polynomial.rs +++ b/src/poly/polynomial.rs @@ -3740,7 +3740,11 @@ impl<'a, F: Ring, E: Exponent, O: MonomialOrder> IntoIterator #[cfg(test)] mod test { - use crate::{atom::Atom, domains::integer::Z, state::State}; + use crate::{ + atom::{Atom, AtomCore}, + domains::integer::Z, + state::State, + }; #[test] fn mul_packed() { diff --git a/src/poly/resultant.rs b/src/poly/resultant.rs index 72cb574d..ede08e95 100644 --- a/src/poly/resultant.rs +++ b/src/poly/resultant.rs @@ -206,7 +206,7 @@ impl UnivariatePolynomial { mod test { use std::sync::Arc; - use crate::atom::Atom; + use crate::atom::{Atom, AtomCore}; use crate::domains::integer::Z; use crate::domains::rational::Q; use crate::domains::rational_polynomial::{ diff --git a/src/poly/series.rs b/src/poly/series.rs index b161ca4a..0ffdf237 100644 --- a/src/poly/series.rs +++ b/src/poly/series.rs @@ -6,7 +6,7 @@ use std::{ }; use crate::{ - atom::{Atom, AtomView, FunctionBuilder}, + atom::{Atom, AtomCore, AtomView, FunctionBuilder}, coefficient::CoefficientView, domains::{ atom::AtomField, diff --git a/src/poly/univariate.rs b/src/poly/univariate.rs index fe49da20..09eedc7f 100644 --- a/src/poly/univariate.rs +++ b/src/poly/univariate.rs @@ -1769,7 +1769,7 @@ impl UnivariatePolynomial> { #[cfg(test)] mod test { use crate::{ - atom::Atom, + atom::{Atom, AtomCore}, domains::{float::F64, rational::Q}, }; diff --git a/src/printer.rs b/src/printer.rs index 42c34a00..f1c79cbc 100644 --- a/src/printer.rs +++ b/src/printer.rs @@ -251,20 +251,6 @@ impl std::fmt::Display for Symbol { } } -impl Atom { - /// Construct a printer for the atom with special options. - pub fn printer<'a>(&'a self, opts: PrintOptions) -> AtomPrinter<'a> { - AtomPrinter::new_with_options(self.as_view(), opts) - } - - /// Print the atom in a form that is unique and independent of any implementation details. - /// - /// Anti-symmetric functions are not supported. - pub fn to_canonical_string(&self) -> String { - self.as_view().to_canonical_string() - } -} - impl<'a> AtomView<'a> { fn fmt_debug(&self, fmt: &mut fmt::Formatter) -> fmt::Result { match self { @@ -277,7 +263,7 @@ impl<'a> AtomView<'a> { } } - pub fn format( + pub(crate) fn format( &self, fmt: &mut W, opts: &PrintOptions, @@ -294,14 +280,14 @@ impl<'a> AtomView<'a> { } /// Construct a printer for the atom with special options. - pub fn printer(&self, opts: PrintOptions) -> AtomPrinter { + pub(crate) fn printer(&self, opts: PrintOptions) -> AtomPrinter { AtomPrinter::new_with_options(*self, opts) } /// Print the atom in a form that is unique and independent of any implementation details. /// /// Anti-symmetric functions are not supported. - pub fn to_canonical_string(&self) -> String { + pub(crate) fn to_canonical_string(&self) -> String { let mut s = String::new(); self.to_canonical_view_impl(&mut s); s @@ -955,7 +941,7 @@ mod test { use colored::control::ShouldColorize; use crate::{ - atom::Atom, + atom::{Atom, AtomCore}, domains::{finite_field::Zp, integer::Z, SelfRing}, printer::{AtomPrinter, PrintOptions, PrintState}, state::{FunctionAttribute, State}, diff --git a/src/solve.rs b/src/solve.rs index d27c00f4..336c1dc9 100644 --- a/src/solve.rs +++ b/src/solve.rs @@ -1,7 +1,7 @@ use std::{ops::Neg, sync::Arc}; use crate::{ - atom::{AsAtomView, Atom, AtomView, Symbol}, + atom::{Atom, AtomCore, AtomView, Symbol}, domains::{ float::{FloatField, Real, SingleFloat}, integer::Z, @@ -14,60 +14,9 @@ use crate::{ tensors::matrix::Matrix, }; -impl Atom { - /// Find the root of a function in `x` numerically over the reals using Newton's method. - pub fn nsolve( - &self, - x: Symbol, - init: N, - prec: N, - max_iterations: usize, - ) -> Result { - self.as_view().nsolve(x, init, prec, max_iterations) - } - - /// Solve a non-linear system numerically over the reals using Newton's method. - pub fn nsolve_system< - N: SingleFloat + Real + PartialOrd + InternalOrdering + Eq + std::hash::Hash, - T: AsAtomView, - >( - system: &[T], - vars: &[Symbol], - init: &[N], - prec: N, - max_iterations: usize, - ) -> Result, String> { - AtomView::nsolve_system(system, vars, init, prec, max_iterations) - } - - /// Solve a system that is linear in `vars`, if possible. - /// Each expression in `system` is understood to yield 0. - pub fn solve_linear_system( - system: &[T1], - vars: &[T2], - ) -> Result, String> { - AtomView::solve_linear_system::(system, vars) - } - - /// Convert a system of linear equations to a matrix representation, returning the matrix - /// and the right-hand side. - pub fn system_to_matrix( - system: &[T1], - vars: &[T2], - ) -> Result< - ( - Matrix>, - Matrix>, - ), - String, - > { - AtomView::system_to_matrix::(system, vars) - } -} - impl<'a> AtomView<'a> { /// Find the root of a function in `x` numerically over the reals using Newton's method. - pub fn nsolve( + pub(crate) fn nsolve( &self, x: Symbol, init: N, @@ -108,9 +57,9 @@ impl<'a> AtomView<'a> { } /// Solve a non-linear system numerically over the reals using Newton's method. - pub fn nsolve_system< + pub(crate) fn nsolve_system< N: SingleFloat + Real + PartialOrd + InternalOrdering + Eq + std::hash::Hash, - T: AsAtomView, + T: AtomCore, >( system: &[T], vars: &[Symbol], @@ -219,7 +168,7 @@ impl<'a> AtomView<'a> { /// Solve a system that is linear in `vars`, if possible. /// Each expression in `system` is understood to yield 0. - pub fn solve_linear_system( + pub(crate) fn solve_linear_system( system: &[T1], vars: &[T2], ) -> Result, String> { @@ -235,7 +184,7 @@ impl<'a> AtomView<'a> { /// Convert a system of linear equations to a matrix representation, returning the matrix /// and the right-hand side. - pub fn system_to_matrix( + pub(crate) fn system_to_matrix( system: &[T1], vars: &[T2], ) -> Result< @@ -347,7 +296,7 @@ mod test { use std::sync::Arc; use crate::{ - atom::{representation::InlineVar, Atom, AtomView}, + atom::{representation::InlineVar, Atom, AtomCore, AtomView}, domains::{ float::{Real, F64}, integer::Z, diff --git a/src/streaming.rs b/src/streaming.rs index 693deaa2..65fff2e4 100644 --- a/src/streaming.rs +++ b/src/streaming.rs @@ -495,21 +495,9 @@ impl TermStreamer { } } -impl Atom { - /// Map the function `f` over all terms. - pub fn map_terms_single_core(&self, f: impl Fn(AtomView) -> Atom) -> Atom { - self.as_view().map_terms_single_core(f) - } - - /// Map the function `f` over all terms, using parallel execution with `n_cores` cores. - pub fn map_terms(&self, f: impl Fn(AtomView) -> Atom + Send + Sync, n_cores: usize) -> Atom { - self.as_view().map_terms(f, n_cores) - } -} - impl<'a> AtomView<'a> { /// Map the function `f` over all terms. - pub fn map_terms_single_core(&self, f: impl Fn(AtomView) -> Atom) -> Atom { + pub(crate) fn map_terms_single_core(&self, f: impl Fn(AtomView) -> Atom) -> Atom { if let AtomView::Add(aa) = self { return Workspace::get_local().with(|ws| { let mut r = ws.new_atom(); @@ -527,7 +515,11 @@ impl<'a> AtomView<'a> { } /// Map the function `f` over all terms, using parallel execution with `n_cores` cores. - pub fn map_terms(&self, f: impl Fn(AtomView) -> Atom + Send + Sync, n_cores: usize) -> Atom { + pub(crate) fn map_terms( + &self, + f: impl Fn(AtomView) -> Atom + Send + Sync, + n_cores: usize, + ) -> Atom { if n_cores < 2 || !LicenseManager::is_licensed() { return self.map_terms_single_core(f); } @@ -545,7 +537,7 @@ impl<'a> AtomView<'a> { } /// Map the function `f` over all terms, using parallel execution with `n_cores` cores. - pub fn map_terms_with_pool( + pub(crate) fn map_terms_with_pool( &self, f: impl Fn(AtomView) -> Atom + Send + Sync, p: &ThreadPool, @@ -593,7 +585,7 @@ mod test { use brotli::CompressorWriter; use crate::{ - atom::{Atom, AtomType}, + atom::{Atom, AtomCore, AtomType}, id::{Pattern, WildcardRestriction}, state::State, streaming::{TermStreamer, TermStreamerConfig}, @@ -648,17 +640,17 @@ mod test { streamer = streamer.map(|x| { x.replace_all( &pattern, - &rhs, - Some( - &( - State::get_symbol("v1_"), - WildcardRestriction::IsAtomType(AtomType::Var), - ) - .into(), - ), - None, - ) - .expand() + &rhs, + Some( + &( + State::get_symbol("v1_"), + WildcardRestriction::IsAtomType(AtomType::Var), + ) + .into(), + ), + None, + ) + .expand() }); streamer.normalize(); diff --git a/src/tensors.rs b/src/tensors.rs index 8d0d5bbe..a20bf082 100644 --- a/src/tensors.rs +++ b/src/tensors.rs @@ -6,45 +6,6 @@ use crate::{ pub mod matrix; -impl Atom { - /// Canonize (products of) tensors in the expression by relabeling repeated indices. - /// The tensors must be written as functions, with its indices are the arguments. - /// The repeated indices should be provided in `contracted_indices`. - /// - /// If the contracted indices are distinguishable (for example in their dimension), - /// you can provide an optional group marker for each index using `index_group`. - /// This makes sure that an index will not be renamed to an index from a different group. - /// - /// Example - /// ------- - /// ``` - /// # use symbolica::{atom::Atom, state::{FunctionAttribute, State}}; - /// # - /// # fn main() { - /// let _ = State::get_symbol_with_attributes("fs", &[FunctionAttribute::Symmetric]).unwrap(); - /// let _ = State::get_symbol_with_attributes("fc", &[FunctionAttribute::Cyclesymmetric]).unwrap(); - /// let a = Atom::parse("fs(mu2,mu3)*fc(mu4,mu2,k1,mu4,k1,mu3)").unwrap(); - /// - /// let mu1 = Atom::parse("mu1").unwrap(); - /// let mu2 = Atom::parse("mu2").unwrap(); - /// let mu3 = Atom::parse("mu3").unwrap(); - /// let mu4 = Atom::parse("mu4").unwrap(); - /// - /// let r = a.canonize_tensors(&[mu1.as_view(), mu2.as_view(), mu3.as_view(), mu4.as_view()], None).unwrap(); - /// println!("{}", r); - /// # } - /// ``` - /// yields `fs(mu1,mu2)*fc(mu1,k1,mu3,k1,mu2,mu3)`. - pub fn canonize_tensors( - &self, - contracted_indices: &[AtomView], - index_group: Option<&[AtomView]>, - ) -> Result { - self.as_view() - .canonize_tensors(contracted_indices, index_group) - } -} - impl<'a> AtomView<'a> { /// Canonize (products of) tensors in the expression by relabeling repeated indices. /// The tensors must be written as functions, with its indices are the arguments. @@ -53,7 +14,7 @@ impl<'a> AtomView<'a> { /// If the contracted indices are distinguishable (for example in their dimension), /// you can provide an optional group marker for each index using `index_group`. /// This makes sure that an index will not be renamed to an index from a different group. - pub fn canonize_tensors( + pub(crate) fn canonize_tensors( &self, contracted_indices: &[AtomView], index_group: Option<&[AtomView]>, @@ -461,7 +422,7 @@ impl<'a> AtomView<'a> { #[cfg(test)] mod test { use crate::{ - atom::{representation::InlineVar, Atom}, + atom::{representation::InlineVar, Atom, AtomCore}, state::State, }; diff --git a/tests/pattern_matching.rs b/tests/pattern_matching.rs index e511ed2a..2ff70a18 100644 --- a/tests/pattern_matching.rs +++ b/tests/pattern_matching.rs @@ -1,5 +1,5 @@ use symbolica::{ - atom::{Atom, AtomView}, + atom::{Atom, AtomCore, AtomView}, id::{Condition, Match, MatchSettings, Pattern, WildcardRestriction}, state::{RecycledAtom, State}, }; @@ -31,7 +31,7 @@ fn fibonacci() { target.replace_all_into(&pattern, &rhs, Some(&restrictions), None, &mut out); let mut out2 = RecycledAtom::new(); - out.expand_into(&mut out2); + out.expand_into(None, &mut out2); out2.replace_all_into(&lhs_zero_pat, &rhs_one, None, None, &mut out); @@ -49,9 +49,9 @@ fn replace_once() { let pat_expr = Atom::parse("f(x_)").unwrap(); let rhs_expr = Atom::parse("g(x_)").unwrap(); - let rhs = rhs_expr.as_view().into_pattern().into(); + let rhs = rhs_expr.as_view().to_pattern().into(); - let pattern = pat_expr.as_view().into_pattern(); + let pattern = pat_expr.as_view().to_pattern(); let restrictions = Condition::default(); let settings = MatchSettings::default(); diff --git a/tests/rational_polynomial.rs b/tests/rational_polynomial.rs index ceddb17d..835f7517 100644 --- a/tests/rational_polynomial.rs +++ b/tests/rational_polynomial.rs @@ -1,6 +1,12 @@ use std::sync::Arc; -use symbolica::{atom::Atom, domains::integer::Z, parser::Token, poly::Variable, state::State}; +use symbolica::{ + atom::{Atom, AtomCore}, + domains::integer::Z, + parser::Token, + poly::Variable, + state::State, +}; #[test] fn large_gcd_single_scale() { From 93003d98f6419ac3c8db73ad083f2624ad8d6942 Mon Sep 17 00:00:00 2001 From: Ben Ruijl Date: Fri, 13 Dec 2024 17:26:50 +0100 Subject: [PATCH 6/7] Refactor pattern matching API - ReplaceIterator and PatternAtomTreeIterator now implement Iterator - Unify call signatures --- examples/pattern_match.rs | 8 ++-- examples/pattern_restrictions.rs | 4 +- examples/replace_once.rs | 16 ++----- examples/tree_replace.rs | 16 +++++-- src/api/python.rs | 36 +++++++-------- src/atom/core.rs | 31 +++++++++---- src/evaluate.rs | 2 +- src/expand.rs | 4 +- src/id.rs | 79 +++++++++++++++++++++++--------- tests/pattern_matching.rs | 17 +++---- 10 files changed, 126 insertions(+), 87 deletions(-) diff --git a/examples/pattern_match.rs b/examples/pattern_match.rs index 88665e89..6d81d0ef 100644 --- a/examples/pattern_match.rs +++ b/examples/pattern_match.rs @@ -1,6 +1,6 @@ use symbolica::{ atom::{Atom, AtomCore}, - id::{Condition, Match, MatchSettings}, + id::Match, state::State, }; @@ -10,13 +10,11 @@ fn main() { let pat_expr = Atom::parse("z*x_*y___*g___(z___,x_,w___)").unwrap(); let pattern = pat_expr.to_pattern(); - let conditions = Condition::default(); - let settings = MatchSettings::default(); println!("> Matching pattern {} to {}:", pat_expr, expr.as_view()); - let mut it = expr.pattern_match(&pattern, &conditions, &settings); - while let Some(m) = it.next() { + let mut it = expr.pattern_match(&pattern, None, None); + while let Some(m) = it.next_detailed() { println!( "\t Match at location {:?} - {:?}:", m.position, m.used_flags diff --git a/examples/pattern_restrictions.rs b/examples/pattern_restrictions.rs index 68dfebd7..2f3fb40c 100644 --- a/examples/pattern_restrictions.rs +++ b/examples/pattern_restrictions.rs @@ -58,8 +58,8 @@ fn main() { expr ); - let mut it = expr.pattern_match(&pattern, &conditions, &settings); - while let Some(m) = it.next() { + let mut it = expr.pattern_match(&pattern, Some(&conditions), Some(&settings)); + while let Some(m) = it.next_detailed() { println!("\tMatch at location {:?} - {:?}:", m.position, m.used_flags); for (id, v) in m.match_stack { print!("\t\t{} = ", State::get_name(*id)); diff --git a/examples/replace_once.rs b/examples/replace_once.rs index b390aa12..9c22849b 100644 --- a/examples/replace_once.rs +++ b/examples/replace_once.rs @@ -1,18 +1,13 @@ -use symbolica::{ - atom::{Atom, AtomCore}, - id::{Condition, MatchSettings}, -}; +use symbolica::atom::{Atom, AtomCore}; fn main() { let expr = Atom::parse("f(z)*f(f(x))*f(y)").unwrap(); let pat_expr = Atom::parse("f(x_)").unwrap(); let rhs_expr = Atom::parse("g(x_)").unwrap(); - let rhs = rhs_expr.as_view().to_pattern().into(); + let rhs = rhs_expr.as_view().to_pattern(); let pattern = pat_expr.as_view().to_pattern(); - let restrictions = Condition::default(); - let settings = MatchSettings::default(); println!( "> Replace once {}={} in {}:", @@ -21,10 +16,7 @@ fn main() { expr.as_view() ); - let mut replaced = Atom::new(); - - let mut it = expr.replace_iter(&pattern, &rhs, &restrictions, &settings); - while let Some(()) = it.next(&mut replaced) { - println!("\t{}", replaced); + for x in expr.replace_iter(&pattern, &rhs, None, None) { + println!("\t{}", x); } } diff --git a/examples/tree_replace.rs b/examples/tree_replace.rs index 50408f7e..2a378fb9 100644 --- a/examples/tree_replace.rs +++ b/examples/tree_replace.rs @@ -1,7 +1,8 @@ use symbolica::{ atom::{Atom, AtomCore}, - id::{Condition, Match, MatchSettings, PatternAtomTreeIterator}, + id::Match, state::State, + symb, }; fn main() { @@ -9,13 +10,18 @@ fn main() { let pat_expr = Atom::parse("f(x_)").unwrap(); let pattern = pat_expr.to_pattern(); - let restrictions = Condition::default(); - let settings = MatchSettings::default(); println!("> Matching pattern {} to {}:", pat_expr, expr); - let mut it = PatternAtomTreeIterator::new(&pattern, expr.as_view(), &restrictions, &settings); - while let Some(m) = it.next() { + for x in expr.pattern_match(&pattern, None, None) { + println!("\t x_ = {}", x.get(&symb!("x_")).unwrap().to_atom()); + } + + println!("> Matching pattern {} to {}:", pat_expr, expr); + + // use next_detailed for detailed information + let mut it = expr.pattern_match(&pattern, None, None); + while let Some(m) = it.next_detailed() { println!("\tMatch at location {:?} - {:?}:", m.position, m.used_flags); for (id, v) in m.match_stack { print!("\t\t{} = ", State::get_name(*id)); diff --git a/src/api/python.rs b/src/api/python.rs index f9e03964..e488565d 100644 --- a/src/api/python.rs +++ b/src/api/python.rs @@ -4416,7 +4416,7 @@ impl PythonExpression { settings, ), move |(lhs, target, res, settings)| { - PatternAtomTreeIterator::new(lhs, target.as_view(), res, settings) + PatternAtomTreeIterator::new(lhs, target.as_view(), Some(res), Some(settings)) }, )) } @@ -4448,11 +4448,14 @@ impl PythonExpression { ..MatchSettings::default() }; - Ok( - PatternAtomTreeIterator::new(&pat, self.expr.as_view(), &conditions, &settings) - .next() - .is_some(), + Ok(PatternAtomTreeIterator::new( + &pat, + self.expr.as_view(), + Some(&conditions), + Some(&settings), ) + .next() + .is_some()) } /// Return an iterator over the replacement of the pattern `self` on `lhs` by `rhs`. @@ -4505,7 +4508,13 @@ impl PythonExpression { settings, ), move |(lhs, target, rhs, res, settings)| { - ReplaceIterator::new(lhs, target.as_view(), rhs, res, settings) + ReplaceIterator::new( + lhs, + target.as_view(), + crate::id::BorrowPatternOrMap::borrow(rhs), + Some(res), + Some(settings), + ) }, )) } @@ -5843,9 +5852,8 @@ impl PythonMatchIterator { fn __next__(&mut self) -> Option> { self.with_dependent_mut(|_, i| { i.next().map(|m| { - m.match_stack - .into_iter() - .map(|m| (Atom::new_var(m.0).into(), { m.1.to_atom().into() })) + m.into_iter() + .map(|(k, v)| (Atom::new_var(k).into(), { v.to_atom().into() })) .collect() }) }) @@ -5880,15 +5888,7 @@ impl PythonReplaceIterator { /// Return the next replacement. fn __next__(&mut self) -> PyResult> { - self.with_dependent_mut(|_, i| { - let mut out = Atom::default(); - - if i.next(&mut out).is_none() { - Ok(None) - } else { - Ok::<_, PyErr>(Some(out.into())) - } - }) + self.with_dependent_mut(|_, i| Ok(i.next().map(|x| x.into()))) } } diff --git a/src/atom/core.rs b/src/atom/core.rs index 759b8478..3c8dc620 100644 --- a/src/atom/core.rs +++ b/src/atom/core.rs @@ -19,7 +19,7 @@ use crate::{ evaluate::{EvalTree, EvaluationFn, ExpressionEvaluator, FunctionMap, OptimizationSettings}, id::{ BorrowPatternOrMap, BorrowReplacement, Condition, ConditionResult, Context, MatchSettings, - Pattern, PatternAtomTreeIterator, PatternOrMap, PatternRestriction, ReplaceIterator, + Pattern, PatternAtomTreeIterator, PatternRestriction, ReplaceIterator, }, poly::{ factor::Factorize, gcd::PolynomialGCD, polynomial::MultivariatePolynomial, series::Series, @@ -156,6 +156,11 @@ pub trait AtomCore { self.as_atom_view().expand_num() } + /// Check if the expression is expanded, optionally in only the variable or function `var`. + fn is_expanded(&self, var: Option) -> bool { + self.as_atom_view().is_expanded(var) + } + /// Take a derivative of the expression with respect to `x`. fn derivative(&self, x: Symbol) -> Atom { self.as_atom_view().derivative(x) @@ -304,7 +309,7 @@ pub trait AtomCore { self.as_atom_view().set_coefficient_ring(vars) } - /// Convert all coefficients to floats with a given precision `decimal_prec``. + /// Convert all coefficients to floats with a given precision `decimal_prec`. /// The precision of floating point coefficients in the input will be truncated to `decimal_prec`. fn coefficients_to_float(&self, decimal_prec: u32) -> Atom { let mut a = Atom::new(); @@ -313,7 +318,7 @@ pub trait AtomCore { a } - /// Convert all coefficients to floats with a given precision `decimal_prec``. + /// Convert all coefficients to floats with a given precision `decimal_prec`. /// The precision of floating point coefficients in the input will be truncated to `decimal_prec`. fn coefficients_to_float_into(&self, decimal_prec: u32, out: &mut Atom) { self.as_atom_view() @@ -576,22 +581,28 @@ pub trait AtomCore { } /// Return an iterator that replaces the pattern in the target once. - fn replace_iter<'a>( + fn replace_iter<'a, R: BorrowPatternOrMap>( &'a self, pattern: &'a Pattern, - rhs: &'a PatternOrMap, - conditions: &'a Condition, - settings: &'a MatchSettings, + rhs: &'a R, + conditions: Option<&'a Condition>, + settings: Option<&'a MatchSettings>, ) -> ReplaceIterator<'a, 'a> { - ReplaceIterator::new(pattern, self.as_atom_view(), rhs, conditions, settings) + ReplaceIterator::new( + pattern, + self.as_atom_view(), + rhs.borrow(), + conditions, + settings, + ) } /// Return an iterator over matched expressions. fn pattern_match<'a>( &'a self, pattern: &'a Pattern, - conditions: &'a Condition, - settings: &'a MatchSettings, + conditions: Option<&'a Condition>, + settings: Option<&'a MatchSettings>, ) -> PatternAtomTreeIterator<'a, 'a> { PatternAtomTreeIterator::new(pattern, self.as_atom_view(), conditions, settings) } diff --git a/src/evaluate.rs b/src/evaluate.rs index 036a5e8b..07b65877 100644 --- a/src/evaluate.rs +++ b/src/evaluate.rs @@ -3975,7 +3975,7 @@ impl<'a> AtomView<'a> { /// a variable or a function with fixed arguments. /// /// All variables and all user functions in the expression must occur in the map. - pub fn evaluate T + Copy>( + pub(crate) fn evaluate T + Copy>( &self, coeff_map: F, const_map: &HashMap, T>, diff --git a/src/expand.rs b/src/expand.rs index 6130c09c..8390250f 100644 --- a/src/expand.rs +++ b/src/expand.rs @@ -36,7 +36,7 @@ impl<'a> AtomView<'a> { } /// Expand an expression, returning `true` iff the expression changed. - pub fn expand_with_ws_into( + pub(crate) fn expand_with_ws_into( &self, workspace: &Workspace, var: Option, @@ -54,7 +54,7 @@ impl<'a> AtomView<'a> { } /// Check if the expression is expanded, optionally in only the variable or function `var`. - pub fn is_expanded(&self, var: Option) -> bool { + pub(crate) fn is_expanded(&self, var: Option) -> bool { match self { AtomView::Num(_) | AtomView::Var(_) | AtomView::Fun(_) => true, AtomView::Pow(pow_view) => { diff --git a/src/id.rs b/src/id.rs index f225e1d1..1525f6bc 100644 --- a/src/id.rs +++ b/src/id.rs @@ -1642,6 +1642,8 @@ impl From for Condition { } } +static DEFAULT_PATTERN_CONDITION: Condition = Condition::True; + /// A logical expression. #[derive(Clone, Debug, Default)] pub enum Condition { @@ -2275,7 +2277,19 @@ pub struct MatchSettings { pub rhs_cache_size: usize, } +static DEFAULT_MATCH_SETTINGS: MatchSettings = MatchSettings::new(); + impl MatchSettings { + pub const fn new() -> Self { + Self { + non_greedy_wildcards: Vec::new(), + level_range: (0, None), + level_is_tree_depth: false, + allow_new_wildcards_on_rhs: false, + rhs_cache_size: 0, + } + } + /// Create default match settings, but enable caching of the rhs. pub fn cached() -> Self { Self { @@ -2291,13 +2305,7 @@ impl MatchSettings { impl Default for MatchSettings { /// Create default match settings. Use [`MatchSettings::cached`] to enable caching. fn default() -> Self { - Self { - non_greedy_wildcards: Vec::new(), - level_range: (0, None), - level_is_tree_depth: false, - allow_new_wildcards_on_rhs: false, - rhs_cache_size: 0, - } + MatchSettings::new() } } @@ -3335,22 +3343,29 @@ impl<'a: 'b, 'b> PatternAtomTreeIterator<'a, 'b> { pub fn new( pattern: &'b Pattern, target: AtomView<'a>, - conditions: &'b Condition, - settings: &'b MatchSettings, + conditions: Option<&'b Condition>, + settings: Option<&'b MatchSettings>, ) -> PatternAtomTreeIterator<'a, 'b> { PatternAtomTreeIterator { pattern, - atom_tree_iterator: AtomTreeIterator::new(target, settings.clone()), + atom_tree_iterator: AtomTreeIterator::new( + target, + settings.unwrap_or(&DEFAULT_MATCH_SETTINGS).clone(), + ), current_target: None, pattern_iter: None, - match_stack: MatchStack::new(conditions, settings), + match_stack: MatchStack::new( + conditions.unwrap_or(&DEFAULT_PATTERN_CONDITION), + settings.unwrap_or(&DEFAULT_MATCH_SETTINGS), + ), tree_pos: Vec::new(), first_match: false, } } - /// Generate the next match if it exists. - pub fn next(&mut self) -> Option> { + /// Generate the next match if it exists, with detailed information about the + /// matched position. Use the iterator `Self::next` to a map of wildcard matches. + pub fn next_detailed(&mut self) -> Option> { loop { if let Some(ct) = self.current_target { if let Some(it) = self.pattern_iter.as_mut() { @@ -3389,10 +3404,23 @@ impl<'a: 'b, 'b> PatternAtomTreeIterator<'a, 'b> { } } +impl<'a: 'b, 'b> Iterator for PatternAtomTreeIterator<'a, 'b> { + type Item = HashMap>; + + /// Get the match map. Use `[PatternAtomTreeIterator::next_detailed]` to get more information. + fn next(&mut self) -> Option>> { + if let Some(_) = self.next_detailed() { + Some(self.match_stack.get_matches().iter().cloned().collect()) + } else { + None + } + } +} + /// Replace a pattern in the target once. Every call to `next`, /// will return a new match and replacement until the options are exhausted. pub struct ReplaceIterator<'a, 'b> { - rhs: &'b PatternOrMap, + rhs: BorrowedPatternOrMap<'b>, pattern_tree_iterator: PatternAtomTreeIterator<'a, 'b>, target: AtomView<'a>, } @@ -3401,9 +3429,9 @@ impl<'a: 'b, 'b> ReplaceIterator<'a, 'b> { pub fn new( pattern: &'b Pattern, target: AtomView<'a>, - rhs: &'b PatternOrMap, - conditions: &'a Condition, - settings: &'a MatchSettings, + rhs: BorrowedPatternOrMap<'b>, + conditions: Option<&'a Condition>, + settings: Option<&'a MatchSettings>, ) -> ReplaceIterator<'a, 'b> { ReplaceIterator { pattern_tree_iterator: PatternAtomTreeIterator::new( @@ -3533,17 +3561,17 @@ impl<'a: 'b, 'b> ReplaceIterator<'a, 'b> { } /// Return the next replacement. - pub fn next(&mut self, out: &mut Atom) -> Option<()> { - if let Some(pattern_match) = self.pattern_tree_iterator.next() { + pub fn next_into(&mut self, out: &mut Atom) -> Option<()> { + if let Some(pattern_match) = self.pattern_tree_iterator.next_detailed() { Workspace::get_local().with(|ws| { let mut new_rhs = ws.new_atom(); match self.rhs { - PatternOrMap::Pattern(p) => { + BorrowedPatternOrMap::Pattern(p) => { p.substitute_wildcards(ws, &mut new_rhs, pattern_match.match_stack, None) .unwrap(); // TODO: escalate? } - PatternOrMap::Map(f) => { + BorrowedPatternOrMap::Map(f) => { let mut new_atom = f(&pattern_match.match_stack); std::mem::swap(&mut new_atom, &mut new_rhs); } @@ -3568,6 +3596,15 @@ impl<'a: 'b, 'b> ReplaceIterator<'a, 'b> { } } +impl<'a: 'b, 'b> Iterator for ReplaceIterator<'a, 'b> { + type Item = Atom; + + fn next(&mut self) -> Option { + let mut out = Atom::new(); + self.next_into(&mut out).map(|_| out) + } +} + #[cfg(test)] mod test { use crate::{ diff --git a/tests/pattern_matching.rs b/tests/pattern_matching.rs index 2ff70a18..600fdcac 100644 --- a/tests/pattern_matching.rs +++ b/tests/pattern_matching.rs @@ -1,6 +1,6 @@ use symbolica::{ atom::{Atom, AtomCore, AtomView}, - id::{Condition, Match, MatchSettings, Pattern, WildcardRestriction}, + id::{Match, Pattern, WildcardRestriction}, state::{RecycledAtom, State}, }; @@ -49,19 +49,14 @@ fn replace_once() { let pat_expr = Atom::parse("f(x_)").unwrap(); let rhs_expr = Atom::parse("g(x_)").unwrap(); - let rhs = rhs_expr.as_view().to_pattern().into(); + let rhs = rhs_expr.as_view().to_pattern(); let pattern = pat_expr.as_view().to_pattern(); - let restrictions = Condition::default(); - let settings = MatchSettings::default(); - let mut replaced = Atom::new(); - - let mut it = expr.replace_iter(&pattern, &rhs, &restrictions, &settings); - let mut r = vec![]; - while let Some(()) = it.next(&mut replaced) { - r.push(replaced.clone()); - } + let r: Vec<_> = expr + .replace_iter(&pattern, &rhs, None, None) + .into_iter() + .collect(); let res = [ "g(z)*f(y)*f(f(x))", From 5e9279b87773ab07f50f81e5eaa116e8d07cd322 Mon Sep 17 00:00:00 2001 From: Ben Ruijl Date: Fri, 13 Dec 2024 20:10:44 +0100 Subject: [PATCH 7/7] Refactor evaluation API - Make the FunctionMap owned - Implement Borrow<[u8]> on Atom so that atoms that are keys in maps can be retrieved by an AtomView's get_data --- examples/evaluate.rs | 11 +-- examples/nested_evaluation.rs | 10 +- src/api/python.rs | 33 +++---- src/atom.rs | 4 +- src/atom/core.rs | 35 ++++--- src/atom/representation.rs | 21 +++++ src/evaluate.rs | 170 ++++++++++++++++------------------ src/poly/evaluate.rs | 14 ++- 8 files changed, 146 insertions(+), 152 deletions(-) diff --git a/examples/evaluate.rs b/examples/evaluate.rs index 06a1977a..f9e64956 100644 --- a/examples/evaluate.rs +++ b/examples/evaluate.rs @@ -11,13 +11,11 @@ fn main() { let a = Atom::parse("x*cos(x) + f(x, 1)^2 + g(g(x)) + p(0)").unwrap(); let mut const_map = HashMap::default(); - let mut fn_map: HashMap<_, EvaluationFn<_>> = HashMap::default(); - let mut cache = HashMap::default(); + let mut fn_map: HashMap<_, _> = HashMap::default(); // x = 6 and p(0) = 7 - let v = Atom::new_var(x); - const_map.insert(v.as_view(), 6.); - const_map.insert(p0.as_view(), 7.); + const_map.insert(Atom::new_var(x), 6.); + const_map.insert(p0, 7.); // f(x, y) = x^2 + y fn_map.insert( @@ -37,7 +35,6 @@ fn main() { println!( "Result for x = 6.: {}", - a.evaluate::(|x| x.into(), &const_map, &fn_map, &mut cache) - .unwrap() + a.evaluate(|x| x.into(), &const_map, &fn_map).unwrap() ); } diff --git a/examples/nested_evaluation.rs b/examples/nested_evaluation.rs index e2638859..6f54e154 100644 --- a/examples/nested_evaluation.rs +++ b/examples/nested_evaluation.rs @@ -26,7 +26,7 @@ fn main() { vec![Atom::new_num(1).into()], "p1".to_string(), vec![State::get_symbol("z")], - p1.as_view(), + p1, ) .unwrap(); fn_map @@ -34,7 +34,7 @@ fn main() { State::get_symbol("f"), "f".to_string(), vec![State::get_symbol("y"), State::get_symbol("z")], - f.as_view(), + f, ) .unwrap(); fn_map @@ -42,7 +42,7 @@ fn main() { State::get_symbol("g"), "g".to_string(), vec![State::get_symbol("y")], - g.as_view(), + g, ) .unwrap(); fn_map @@ -50,7 +50,7 @@ fn main() { State::get_symbol("h"), "h".to_string(), vec![State::get_symbol("y")], - h.as_view(), + h, ) .unwrap(); fn_map @@ -58,7 +58,7 @@ fn main() { State::get_symbol("i"), "i".to_string(), vec![State::get_symbol("y")], - i.as_view(), + i, ) .unwrap(); diff --git a/src/api/python.rs b/src/api/python.rs index e488565d..0543c60f 100644 --- a/src/api/python.rs +++ b/src/api/python.rs @@ -4848,8 +4848,6 @@ impl PythonExpression { constants: HashMap, functions: HashMap, ) -> PyResult { - let mut cache = HashMap::default(); - let constants = constants .iter() .map(|(k, v)| (k.expr.as_view(), *v)) @@ -4882,7 +4880,7 @@ impl PythonExpression { .collect::>()?; self.expr - .evaluate(|x| x.into(), &constants, &functions, &mut cache) + .evaluate(|x| x.into(), &constants, &functions) .map_err(|e| { exceptions::PyValueError::new_err(format!("Could not evaluate expression: {}", e)) }) @@ -4911,8 +4909,6 @@ impl PythonExpression { ) -> PyResult { let prec = (decimal_digit_precision as f64 * std::f64::consts::LOG2_10).ceil() as u32; - let mut cache = HashMap::default(); - let constants: HashMap = constants .iter() .map(|(k, v)| { @@ -4963,12 +4959,7 @@ impl PythonExpression { let a: PythonMultiPrecisionFloat = self .expr - .evaluate( - |x| x.to_multi_prec_float(prec), - &constants, - &functions, - &mut cache, - ) + .evaluate(|x| x.to_multi_prec_float(prec), &constants, &functions) .map_err(|e| { exceptions::PyValueError::new_err(format!("Could not evaluate expression: {}", e)) })? @@ -4992,8 +4983,6 @@ impl PythonExpression { constants: HashMap>, functions: HashMap, ) -> PyResult> { - let mut cache = HashMap::default(); - let constants = constants .iter() .map(|(k, v)| (k.expr.as_view(), *v)) @@ -5034,7 +5023,7 @@ impl PythonExpression { let r = self .expr - .evaluate(|x| x.into(), &constants, &functions, &mut cache) + .evaluate(|x| x.into(), &constants, &functions) .map_err(|e| { exceptions::PyValueError::new_err(format!("Could not evaluate expression: {}", e)) })?; @@ -5080,9 +5069,9 @@ impl PythonExpression { ) -> PyResult { let mut fn_map = FunctionMap::new(); - for (k, v) in &constants { + for (k, v) in constants { if let Ok(r) = v.expr.clone().try_into() { - fn_map.add_constant(k.expr.as_view(), r); + fn_map.add_constant(k.expr, r); } else { Err(exceptions::PyValueError::new_err(format!( "Constants must be rationals. If this is not possible, pass the value as a parameter", @@ -5090,7 +5079,7 @@ impl PythonExpression { } } - for ((symbol, rename, args), body) in &functions { + for ((symbol, rename, args), body) in functions { let symbol = symbol .to_id() .ok_or(exceptions::PyValueError::new_err(format!( @@ -5108,7 +5097,7 @@ impl PythonExpression { .collect::>()?; fn_map - .add_function(symbol, rename.clone(), args, body.expr.as_view()) + .add_function(symbol, rename.clone(), args, body.expr) .map_err(|e| { exceptions::PyValueError::new_err(format!("Could not add function: {}", e)) })?; @@ -5169,9 +5158,9 @@ impl PythonExpression { ) -> PyResult { let mut fn_map = FunctionMap::new(); - for (k, v) in &constants { + for (k, v) in constants { if let Ok(r) = v.expr.clone().try_into() { - fn_map.add_constant(k.expr.as_view(), r); + fn_map.add_constant(k.expr, r); } else { Err(exceptions::PyValueError::new_err(format!( "Constants must be rationals. If this is not possible, pass the value as a parameter", @@ -5179,7 +5168,7 @@ impl PythonExpression { } } - for ((symbol, rename, args), body) in &functions { + for ((symbol, rename, args), body) in functions { let symbol = symbol .to_id() .ok_or(exceptions::PyValueError::new_err(format!( @@ -5197,7 +5186,7 @@ impl PythonExpression { .collect::>()?; fn_map - .add_function(symbol, rename.clone(), args, body.expr.as_view()) + .add_function(symbol, rename.clone(), args, body.expr) .map_err(|e| { exceptions::PyValueError::new_err(format!("Could not add function: {}", e)) })?; diff --git a/src/atom.rs b/src/atom.rs index 433efd2b..bce70af3 100644 --- a/src/atom.rs +++ b/src/atom.rs @@ -15,8 +15,8 @@ use std::{cmp::Ordering, hash::Hash, ops::DerefMut, str::FromStr}; pub use self::core::AtomCore; pub use self::representation::{ - Add, AddView, Fun, ListIterator, ListSlice, Mul, MulView, Num, NumView, Pow, PowView, Var, - VarView, + Add, AddView, Fun, KeyLookup, ListIterator, ListSlice, Mul, MulView, Num, NumView, Pow, + PowView, Var, VarView, }; use self::representation::{FunView, RawAtom}; diff --git a/src/atom/core.rs b/src/atom/core.rs index 3c8dc620..951225bf 100644 --- a/src/atom/core.rs +++ b/src/atom/core.rs @@ -33,7 +33,7 @@ use std::sync::Arc; use super::{ representation::{InlineNum, InlineVar}, - Atom, AtomOrView, AtomView, Symbol, + Atom, AtomOrView, AtomView, KeyLookup, Symbol, }; /// All core features of expressions, such as expansion and @@ -239,23 +239,22 @@ pub trait AtomCore { /// to an optimized version or generate a compiled version of your expression. /// /// All variables and all user functions in the expression must occur in the map. - fn evaluate<'b, T: Real, F: Fn(&Rational) -> T + Copy>( - &'b self, + fn evaluate T + Copy>( + &self, coeff_map: F, - const_map: &HashMap, T>, - function_map: &HashMap>, - cache: &mut HashMap, T>, + const_map: &HashMap, + function_map: &HashMap>, ) -> Result { self.as_atom_view() - .evaluate(coeff_map, const_map, function_map, cache) + .evaluate(coeff_map, const_map, function_map) } /// Convert nested expressions to a tree suitable for repeated evaluations with /// different values for `params`. /// All variables and all user functions in the expression must occur in the map. - fn to_evaluation_tree<'a>( - &'a self, - fn_map: &FunctionMap<'a, Rational>, + fn to_evaluation_tree( + &self, + fn_map: &FunctionMap, params: &[Atom], ) -> Result, String> { self.as_atom_view().to_evaluation_tree(fn_map, params) @@ -265,9 +264,9 @@ pub trait AtomCore { /// All free parameters must appear in `params` and all other variables /// and user functions in the expression must occur in the function map. /// The function map may have nested expressions. - fn evaluator<'a>( - &'a self, - fn_map: &FunctionMap<'a, Rational>, + fn evaluator( + &self, + fn_map: &FunctionMap, params: &[Atom], optimization_settings: OptimizationSettings, ) -> Result, String> { @@ -283,9 +282,9 @@ pub trait AtomCore { /// Convert nested expressions to a tree suitable for repeated evaluations with /// different values for `params`. /// All variables and all user functions in the expression must occur in the map. - fn evaluator_multiple<'a>( - exprs: &[AtomView<'a>], - fn_map: &FunctionMap<'a, Rational>, + fn evaluator_multiple( + exprs: &[A], + fn_map: &FunctionMap, params: &[Atom], optimization_settings: OptimizationSettings, ) -> Result, String> { @@ -425,7 +424,7 @@ pub trait AtomCore { } /// Construct a printer for the atom with special options. - fn printer<'a>(&'a self, opts: PrintOptions) -> AtomPrinter<'a> { + fn printer(&self, opts: PrintOptions) -> AtomPrinter { AtomPrinter::new_with_options(self.as_atom_view(), opts) } @@ -503,7 +502,7 @@ pub trait AtomCore { } /// Get all variables and functions in the expression. - fn get_all_indeterminates<'a>(&'a self, enter_functions: bool) -> HashSet> { + fn get_all_indeterminates(&self, enter_functions: bool) -> HashSet { self.as_atom_view().get_all_indeterminates(enter_functions) } diff --git a/src/atom/representation.rs b/src/atom/representation.rs index 402c0c3c..b75ff221 100644 --- a/src/atom/representation.rs +++ b/src/atom/representation.rs @@ -2,7 +2,9 @@ use byteorder::{LittleEndian, WriteBytesExt}; use bytes::{Buf, BufMut}; use smartstring::alias::String; use std::{ + borrow::Borrow, cmp::Ordering, + hash::Hash, io::{Read, Write}, }; @@ -37,8 +39,27 @@ const MUL_HAS_COEFF_FLAG: u8 = 0b01000000; const ZERO_DATA: [u8; 3] = [NUM_ID, 1, 0]; +pub type BorrowedRawAtom = [u8]; pub type RawAtom = Vec; +impl Borrow for Atom { + fn borrow(&self) -> &BorrowedRawAtom { + &self.as_view().get_data() + } +} + +impl<'a> Borrow for AtomView<'a> { + fn borrow(&self) -> &BorrowedRawAtom { + &self.get_data() + } +} + +/// Allows the atom to be used as a key and looked up through a mapping to `&[u8]`. +pub trait KeyLookup: Borrow + Eq + Hash {} + +impl KeyLookup for Atom {} +impl<'a> KeyLookup for AtomView<'a> {} + /// An inline variable. #[derive(Copy, Clone)] pub struct InlineVar { diff --git a/src/evaluate.rs b/src/evaluate.rs index 07b65877..7f0a66c0 100644 --- a/src/evaluate.rs +++ b/src/evaluate.rs @@ -13,7 +13,7 @@ use self_cell::self_cell; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use crate::{ - atom::{Atom, AtomOrView, AtomView, Symbol}, + atom::{Atom, AtomCore, AtomView, KeyLookup, Symbol}, coefficient::CoefficientView, combinatorics::unique_permutations, domains::{ @@ -28,52 +28,46 @@ use crate::{ LicenseManager, }; -type EvalFnType = Box< +type EvalFnType = Box< dyn Fn( &[T], - &HashMap, T>, - &HashMap>, + &HashMap, + &HashMap>, &mut HashMap, T>, ) -> T, >; -pub struct EvaluationFn(EvalFnType); +pub struct EvaluationFn(EvalFnType); -impl EvaluationFn { - pub fn new(f: EvalFnType) -> EvaluationFn { +impl EvaluationFn { + pub fn new(f: EvalFnType) -> EvaluationFn { EvaluationFn(f) } /// Get a reference to the function that can be called to evaluate it. - pub fn get(&self) -> &EvalFnType { + pub fn get(&self) -> &EvalFnType { &self.0 } } -#[derive(PartialEq, Eq, Hash)] -enum AtomOrTaggedFunction<'a> { - Atom(AtomOrView<'a>), - TaggedFunction(Symbol, Vec>), -} - -pub struct FunctionMap<'a, T = Rational> { - map: HashMap, ConstOrExpr<'a, T>>, +#[derive(Clone)] +pub struct FunctionMap { + map: HashMap>, + tagged_fn_map: HashMap<(Symbol, Vec), ConstOrExpr>, tag: HashMap, } -impl<'a, T> FunctionMap<'a, T> { +impl FunctionMap { pub fn new() -> Self { FunctionMap { map: HashMap::default(), + tagged_fn_map: HashMap::default(), tag: HashMap::default(), } } - pub fn add_constant>>(&mut self, key: A, value: T) { - self.map.insert( - AtomOrTaggedFunction::Atom(key.into()), - ConstOrExpr::Const(value), - ); + pub fn add_constant(&mut self, key: Atom, value: T) { + self.map.insert(key, ConstOrExpr::Const(value)); } pub fn add_function( @@ -81,7 +75,7 @@ impl<'a, T> FunctionMap<'a, T> { name: Symbol, rename: String, args: Vec, - body: AtomView<'a>, + body: Atom, ) -> Result<(), &str> { if let Some(t) = self.tag.insert(name, 0) { if t != 0 { @@ -89,10 +83,8 @@ impl<'a, T> FunctionMap<'a, T> { } } - self.map.insert( - AtomOrTaggedFunction::TaggedFunction(name, vec![]), - ConstOrExpr::Expr(rename, 0, args, body), - ); + self.tagged_fn_map + .insert((name, vec![]), ConstOrExpr::Expr(rename, 0, args, body)); Ok(()) } @@ -100,10 +92,10 @@ impl<'a, T> FunctionMap<'a, T> { pub fn add_tagged_function( &mut self, name: Symbol, - tags: Vec>, + tags: Vec, rename: String, args: Vec, - body: AtomView<'a>, + body: Atom, ) -> Result<(), &str> { if let Some(t) = self.tag.insert(name, tags.len()) { if t != tags.len() { @@ -112,10 +104,8 @@ impl<'a, T> FunctionMap<'a, T> { } let tag_len = tags.len(); - self.map.insert( - AtomOrTaggedFunction::TaggedFunction(name, tags), - ConstOrExpr::Expr(rename, tag_len, args, body), - ); + self.tagged_fn_map + .insert((name, tags), ConstOrExpr::Expr(rename, tag_len, args, body)); Ok(()) } @@ -124,15 +114,15 @@ impl<'a, T> FunctionMap<'a, T> { self.tag.get(symbol).cloned().unwrap_or(0) } - fn get_constant(&self, a: AtomView<'a>) -> Option<&T> { - match self.map.get(&AtomOrTaggedFunction::Atom(a.into())) { + fn get_constant(&self, a: AtomView) -> Option<&T> { + match self.map.get(a.get_data()) { Some(ConstOrExpr::Const(c)) => Some(c), _ => None, } } - fn get(&self, a: AtomView<'a>) -> Option<&ConstOrExpr<'a, T>> { - if let Some(c) = self.map.get(&AtomOrTaggedFunction::Atom(a.into())) { + fn get(&self, a: AtomView) -> Option<&ConstOrExpr> { + if let Some(c) = self.map.get(a.get_data()) { return Some(c); } @@ -141,8 +131,8 @@ impl<'a, T> FunctionMap<'a, T> { let tag_len = self.get_tag_len(&s); if aa.get_nargs() >= tag_len { - let tag = aa.iter().take(tag_len).map(|x| x.into()).collect(); - return self.map.get(&AtomOrTaggedFunction::TaggedFunction(s, tag)); + let tag = aa.iter().take(tag_len).map(|x| x.to_owned()).collect(); + return self.tagged_fn_map.get(&(s, tag)); } } @@ -150,9 +140,10 @@ impl<'a, T> FunctionMap<'a, T> { } } -enum ConstOrExpr<'a, T> { +#[derive(Clone)] +enum ConstOrExpr { Const(T), - Expr(String, usize, Vec, AtomView<'a>), + Expr(String, usize, Vec, Atom), } #[derive(Debug, Clone)] @@ -3795,22 +3786,25 @@ impl<'a> AtomView<'a> { /// Convert nested expressions to a tree. pub fn to_evaluation_tree( &self, - fn_map: &FunctionMap<'a, Rational>, + fn_map: &FunctionMap, params: &[Atom], ) -> Result, String> { Self::to_eval_tree_multiple(std::slice::from_ref(self), fn_map, params) } /// Convert nested expressions to a tree. - pub fn to_eval_tree_multiple( - exprs: &[Self], - fn_map: &FunctionMap<'a, Rational>, + pub fn to_eval_tree_multiple( + exprs: &[A], + fn_map: &FunctionMap, params: &[Atom], ) -> Result, String> { let mut funcs = vec![]; let tree = exprs .iter() - .map(|t| t.to_eval_tree_impl(fn_map, params, &[], &mut funcs)) + .map(|t| { + t.as_atom_view() + .to_eval_tree_impl(fn_map, params, &[], &mut funcs) + }) .collect::>()?; Ok(EvalTree { @@ -3825,7 +3819,7 @@ impl<'a> AtomView<'a> { fn to_eval_tree_impl( &self, - fn_map: &FunctionMap<'a, Rational>, + fn_map: &FunctionMap, params: &[Atom], args: &[Symbol], funcs: &mut Vec<(String, Vec, SplitExpression)>, @@ -3903,7 +3897,9 @@ impl<'a> AtomView<'a> { if let Some(pos) = funcs.iter().position(|f| f.0 == *name) { Ok(Expression::Eval(pos, eval_args)) } else { - let r = e.to_eval_tree_impl(fn_map, params, arg_spec, funcs)?; + let r = e + .as_view() + .to_eval_tree_impl(fn_map, params, arg_spec, funcs)?; funcs.push(( name.clone(), arg_spec.clone(), @@ -3975,14 +3971,24 @@ impl<'a> AtomView<'a> { /// a variable or a function with fixed arguments. /// /// All variables and all user functions in the expression must occur in the map. - pub(crate) fn evaluate T + Copy>( + pub(crate) fn evaluate T + Copy>( + &self, + coeff_map: F, + const_map: &HashMap, + function_map: &HashMap>, + ) -> Result { + let mut cache = HashMap::default(); + self.evaluate_impl(coeff_map, const_map, function_map, &mut cache) + } + + fn evaluate_impl T + Copy>( &self, coeff_map: F, - const_map: &HashMap, T>, - function_map: &HashMap>, + const_map: &HashMap, + function_map: &HashMap>, cache: &mut HashMap, T>, ) -> Result { - if let Some(c) = const_map.get(self) { + if let Some(c) = const_map.get(self.get_data()) { return Ok(c.clone()); } @@ -4017,7 +4023,7 @@ impl<'a> AtomView<'a> { if [Atom::EXP, Atom::LOG, Atom::SIN, Atom::COS, Atom::SQRT].contains(&name) { assert!(f.get_nargs() == 1); let arg = f.iter().next().unwrap(); - let arg_eval = arg.evaluate(coeff_map, const_map, function_map, cache)?; + let arg_eval = arg.evaluate_impl(coeff_map, const_map, function_map, cache)?; return Ok(match f.get_symbol() { Atom::EXP => arg_eval.exp(), @@ -4035,7 +4041,7 @@ impl<'a> AtomView<'a> { let mut args = Vec::with_capacity(f.get_nargs()); for arg in f { - args.push(arg.evaluate(coeff_map, const_map, function_map, cache)?); + args.push(arg.evaluate_impl(coeff_map, const_map, function_map, cache)?); } let Some(fun) = function_map.get(&f.get_symbol()) else { @@ -4051,7 +4057,7 @@ impl<'a> AtomView<'a> { } AtomView::Pow(p) => { let (b, e) = p.get_base_exp(); - let b_eval = b.evaluate(coeff_map, const_map, function_map, cache)?; + let b_eval = b.evaluate_impl(coeff_map, const_map, function_map, cache)?; if let AtomView::Num(n) = e { if let CoefficientView::Natural(num, den) = n.get_coeff_view() { @@ -4065,7 +4071,7 @@ impl<'a> AtomView<'a> { } } - let e_eval = e.evaluate(coeff_map, const_map, function_map, cache)?; + let e_eval = e.evaluate_impl(coeff_map, const_map, function_map, cache)?; Ok(b_eval.powf(&e_eval)) } AtomView::Mul(m) => { @@ -4073,9 +4079,9 @@ impl<'a> AtomView<'a> { let mut r = it.next() .unwrap() - .evaluate(coeff_map, const_map, function_map, cache)?; + .evaluate_impl(coeff_map, const_map, function_map, cache)?; for arg in it { - r *= arg.evaluate(coeff_map, const_map, function_map, cache)?; + r *= arg.evaluate_impl(coeff_map, const_map, function_map, cache)?; } Ok(r) } @@ -4084,9 +4090,9 @@ impl<'a> AtomView<'a> { let mut r = it.next() .unwrap() - .evaluate(coeff_map, const_map, function_map, cache)?; + .evaluate_impl(coeff_map, const_map, function_map, cache)?; for arg in it { - r += arg.evaluate(coeff_map, const_map, function_map, cache)?; + r += arg.evaluate_impl(coeff_map, const_map, function_map, cache)?; } Ok(r) } @@ -4151,11 +4157,7 @@ impl<'a> AtomView<'a> { }) .collect(); - let mut cache = HashMap::default(); - for _ in 0..iterations { - cache.clear(); - for x in vars.values_mut() { *x = x.sample_unit(&mut rng); } @@ -4176,7 +4178,6 @@ impl<'a> AtomView<'a> { }, &vars, &HashMap::default(), - &mut cache, ) .unwrap(); @@ -4212,11 +4213,7 @@ impl<'a> AtomView<'a> { }) .collect(); - let mut cache = HashMap::default(); - for _ in 0..iterations { - cache.clear(); - for x in vars.values_mut() { *x = x.sample_unit(&mut rng); } @@ -4231,7 +4228,6 @@ impl<'a> AtomView<'a> { }, &vars, &HashMap::default(), - &mut cache, ) .unwrap(); @@ -4272,13 +4268,14 @@ mod test { let p0 = Atom::parse("v2(0)").unwrap(); let a = Atom::parse("v1*cos(v1) + f1(v1, 1)^2 + f2(f2(v1)) + v2(0)").unwrap(); + let v = Atom::new_var(x); + let mut const_map = HashMap::default(); - let mut fn_map: HashMap<_, EvaluationFn<_>> = HashMap::default(); - let mut cache = HashMap::default(); + let mut fn_map: HashMap<_, EvaluationFn<_, _>> = HashMap::default(); // x = 6 and p(0) = 7 - let v = Atom::new_var(x); - const_map.insert(v.as_view(), 6.); + + const_map.insert(v.as_view(), 6.); // .as_view() const_map.insert(p0.as_view(), 7.); // f(x, y) = x^2 + y @@ -4297,9 +4294,7 @@ mod test { })), ); - let r = a - .evaluate(|x| x.into(), &const_map, &fn_map, &mut cache) - .unwrap(); + let r = a.evaluate(|x| x.into(), &const_map, &fn_map).unwrap(); assert_eq!(r, 2905.761021719902); } @@ -4318,7 +4313,6 @@ mod test { |r| r.to_multi_prec_float(200), &const_map, &HashMap::default(), - &mut HashMap::default(), ) .unwrap(); @@ -4350,7 +4344,7 @@ mod test { vec![Atom::new_num(1).into()], "p1".to_string(), vec![State::get_symbol("z")], - p1.as_view(), + p1, ) .unwrap(); fn_map @@ -4358,7 +4352,7 @@ mod test { State::get_symbol("f"), "f".to_string(), vec![State::get_symbol("y"), State::get_symbol("z")], - f.as_view(), + f, ) .unwrap(); fn_map @@ -4366,7 +4360,7 @@ mod test { State::get_symbol("g"), "g".to_string(), vec![State::get_symbol("y")], - g.as_view(), + g, ) .unwrap(); fn_map @@ -4374,7 +4368,7 @@ mod test { State::get_symbol("h"), "h".to_string(), vec![State::get_symbol("y")], - h.as_view(), + h, ) .unwrap(); fn_map @@ -4382,19 +4376,15 @@ mod test { State::get_symbol("i"), "i".to_string(), vec![State::get_symbol("y")], - i.as_view(), + i, ) .unwrap(); let params = vec![Atom::parse("x").unwrap()]; - let evaluator = Atom::evaluator_multiple( - &[e1.as_view(), e2.as_view()], - &fn_map, - ¶ms, - OptimizationSettings::default(), - ) - .unwrap(); + let evaluator = + Atom::evaluator_multiple(&[e1, e2], &fn_map, ¶ms, OptimizationSettings::default()) + .unwrap(); let mut e_f64 = evaluator.map_coeff(&|x| x.into()); let r = e_f64.evaluate_single(&[1.1]); diff --git a/src/poly/evaluate.rs b/src/poly/evaluate.rs index 55425fbb..7da7481c 100644 --- a/src/poly/evaluate.rs +++ b/src/poly/evaluate.rs @@ -7,7 +7,7 @@ use ahash::{AHasher, HashMap, HashSet, HashSetExt}; use rand::{thread_rng, Rng}; use crate::{ - atom::{Atom, AtomView}, + atom::{Atom, AtomView, KeyLookup}, domains::{float::Real, Ring}, evaluate::EvaluationFn, }; @@ -1467,25 +1467,23 @@ impl From<&'b Rational>> InstructionEvaluator { /// a variable or a function with fixed arguments. /// /// All variables and all user functions in the expression must occur in the map. - pub fn evaluate N + Copy>( + pub fn evaluate N + Copy>( &mut self, coeff_map: F, - const_map: &HashMap, N>, - function_map: &HashMap>, + const_map: &HashMap, + function_map: &HashMap>, ) -> &[N] { Workspace::get_local().with(|ws| { for (input, expr) in self.eval.iter_mut().zip(&self.input_map) { match expr { super::Variable::Symbol(s) => { *input = const_map - .get(&ws.new_var(*s).as_view()) + .get(ws.new_var(*s).as_view().get_data()) .expect("Variable not found") .clone(); } super::Variable::Function(_, o) | super::Variable::Other(o) => { - *input = o - .evaluate(coeff_map, const_map, function_map, &mut HashMap::default()) - .unwrap(); + *input = o.evaluate(coeff_map, const_map, function_map).unwrap(); } super::Variable::Temporary(_) => panic!("Temporary variable in input"), }