Skip to content

Commit

Permalink
optimize common partitioning patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
kaikalii committed Nov 25, 2024
1 parent 0023e2f commit de23572
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 12 deletions.
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ This version is not yet released. If you are reading this on the website, then t
- Improve pattern matching error messages
- Optimize the "root" pattern `ⁿ%:1`
- Optimize format strings applied to strings or boxed strings
- Optimize common [`partition ⊜`](https://uiua.org/docs/partition) patterns
- Add an `-e`/`--experimental` flag to the `uiua eval` command to enable experimental features
### Website
- Add a new pad setting to show line values to the right of the code
Expand Down
96 changes: 90 additions & 6 deletions src/algorithm/loops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,96 @@ pub fn do_(ops: Ops, env: &mut Uiua) -> UiuaResult {
Ok(())
}

pub fn partition(ops: Ops, env: &mut Uiua) -> UiuaResult {
pub fn split_by(scalar: bool, env: &mut Uiua) -> UiuaResult {
let delim = env.pop(1)?;
let haystack = env.pop(2)?;
if haystack.rank() > 1
|| delim.rank() > 1
|| scalar && !(delim.rank() == 0 || delim.rank() == 1 && delim.row_count() == 1)
{
let mask = if scalar {
delim.is_ne(haystack.clone(), 0, 0, env)?
} else {
delim.mask(&haystack, env)?.not(env)?
};
env.push(haystack);
env.push(mask);
return partition(
SigNode {
node: Node::Prim(Primitive::Box, 0),
sig: Signature::new(1, 1),
},
env,
);
}
let val = haystack.generic_bin_ref(
&delim,
|a, b| a.split_by(b, env),
|a, b| a.split_by(b, env),
|a, b| a.split_by(b, env),
|a, b| a.split_by(b, env),
|a, b| a.split_by(b, env),
|a, b| {
env.error(format!(
"Cannot split {} by {}",
a.type_name_plural(),
b.type_name_plural()
))
},
)?;
env.push(val);
Ok(())
}

impl<T: ArrayValue> Array<T>
where
Value: From<CowSlice<T>>,
{
fn split_by(&self, delim: &Self, _env: &Uiua) -> UiuaResult<Array<Boxed>> {
let haystack = self.data.as_slice();
let delim_slice = delim.data.as_slice();
Ok(if delim.rank() == 0 || delim.row_count() == 1 {
let mut curr = 0;
let mut data = EcoVec::new();
let delim = &delim_slice[0];
for slice in haystack.split(|elem| elem.array_eq(delim)) {
if slice.is_empty() {
curr += 1;
continue;
}
let start = curr;
let end = start + slice.len();
data.push(Boxed(self.data.slice(start..end).into()));
curr = end + 1;
}
data.into()
} else {
let mut curr = 0;
let mut data = EcoVec::new();
while curr < haystack.len() {
let prev_end = haystack[curr..]
.windows(delim_slice.len())
.position(|win| win.iter().zip(delim_slice).all(|(a, b)| a.array_eq(b)))
.map(|i| curr + i)
.unwrap_or(haystack.len());
let next_start = prev_end + delim_slice.len();
if curr == prev_end {
curr = next_start;
continue;
}
data.push(Boxed(self.data.slice(curr..prev_end).into()));
curr = next_start;
}
data.into()
})
}
}

pub fn partition(f: SigNode, env: &mut Uiua) -> UiuaResult {
crate::profile_function!();
collapse_groups(
Primitive::Partition,
ops,
f,
Value::partition_groups,
|val, markers, _| Ok(val.partition_firsts(markers)),
|val, markers, _| Ok(val.partition_lasts(markers)),
Expand Down Expand Up @@ -642,11 +727,11 @@ fn update_array_at<T: Clone>(arr: &mut Array<T>, start: usize, new: &[T]) {
arr.data.as_mut_slice()[start..end].clone_from_slice(new);
}

pub fn group(ops: Ops, env: &mut Uiua) -> UiuaResult {
pub fn group(f: SigNode, env: &mut Uiua) -> UiuaResult {
crate::profile_function!();
collapse_groups(
Primitive::Group,
ops,
f,
Value::group_groups,
Value::group_firsts,
Value::group_lasts,
Expand Down Expand Up @@ -840,7 +925,7 @@ pub fn undo_group_part2(env: &mut Uiua) -> UiuaResult {
#[allow(clippy::too_many_arguments)]
fn collapse_groups<I>(
prim: Primitive,
ops: Ops,
f: SigNode,
get_groups: impl Fn(Value, &Array<isize>) -> I,
firsts: impl Fn(Value, &[isize], &Uiua) -> UiuaResult<Value>,
lasts: impl Fn(Value, &[isize], &Uiua) -> UiuaResult<Value>,
Expand All @@ -852,7 +937,6 @@ where
I: IntoIterator<Item = Value>,
I::IntoIter: ExactSizeIterator,
{
let [f] = get_ops(ops, env)?;
let sig = f.sig;
let indices = env.pop(1)?.as_integer_array(env, indices_error)?;
let values: Vec<Value> = (0..sig.args.max(1))
Expand Down
52 changes: 52 additions & 0 deletions src/compile/optimize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ static OPTIMIZATIONS: &[&dyn Optimization] = &[
&ReduceDepthOpt,
&AdjacentOpt,
&AstarOpt,
&SplitByOpt,
&PopConst,
&TraceOpt,
&ValidateTypeOpt,
Expand Down Expand Up @@ -216,6 +217,57 @@ opt!(
),
);

#[derive(Debug)]
struct SplitByOpt;
impl Optimization for SplitByOpt {
fn match_and_replace(&self, nodes: &mut EcoVec<Node>) -> bool {
fn is_par_box(node: &Node) -> bool {
let Mod(Partition, args, _) = node else {
return false;
};
let [f] = args.as_slice() else {
return false;
};
matches!(f.node, Prim(Box, _))
}
for i in 0..nodes.len() {
match &nodes[i..] {
[Mod(By, args, span), last, ..]
if is_par_box(last)
&& matches!(args.as_slice(), [f]
if matches!(f.node, Prim(Ne, _))) =>
{
replace_nodes(nodes, i, 2, ImplPrim(SplitByScalar, *span));
break;
}
[Mod(By, args, span), Prim(Not, _), last, ..]
if is_par_box(last)
&& matches!(args.as_slice(), [f]
if matches!(f.node, Prim(Mask, _))) =>
{
replace_nodes(nodes, i, 3, ImplPrim(SplitBy, *span));
break;
}
[Prim(Dup, span), Push(delim), Prim(Ne, _), last, ..] if is_par_box(last) => {
let new =
Node::from_iter([Push(delim.clone()), ImplPrim(SplitByScalar, *span)]);
replace_nodes(nodes, i, 4, new);
break;
}
[Prim(Dup, span), Push(delim), Prim(Mask, _), Prim(Not, _), last, ..]
if is_par_box(last) =>
{
let new = Node::from_iter([Push(delim.clone()), ImplPrim(SplitBy, *span)]);
replace_nodes(nodes, i, 5, new);
break;
}
_ => {}
}
}
false
}
}

opt!(
TraceOpt,
(
Expand Down
2 changes: 2 additions & 0 deletions src/primitive/defs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3690,6 +3690,8 @@ impl_primitive!(
(1, CountUnique),
(1(2)[3], AstarFirst),
(1[3], AstarPop),
(2, SplitByScalar),
(2, SplitBy),
// Implementation details
(1[2], RepeatWithInverse),
(2(1), ValidateType),
Expand Down
15 changes: 13 additions & 2 deletions src/primitive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ impl fmt::Display for ImplPrimitive {
UndoChunks => write!(f, "{Un}{Chunks}"),
UndoWindows => write!(f, "{Un}{Windows}"),
UndoJoin => write!(f, "{Under}{Join}"),
// Optimizations
FirstMinIndex => write!(f, "{First}{Rise}"),
FirstMaxIndex => write!(f, "{First}{Fall}"),
LastMinIndex => write!(f, "{First}{Reverse}{Rise}"),
Expand All @@ -280,6 +281,8 @@ impl fmt::Display for ImplPrimitive {
MatchGe => write!(f, "match ≥"),
AstarFirst => write!(f, "{First}{Astar}"),
AstarPop => write!(f, "{Pop}{Astar}"),
SplitByScalar => write!(f, "{Partition}{Box}{By}{Ne}"),
SplitBy => write!(f, "{Partition}{Box}{Not}{By}{Mask}"),
&ReduceDepth(n) => {
for _ in 0..n {
write!(f, "{Rows}")?;
Expand Down Expand Up @@ -1028,8 +1031,14 @@ impl Primitive {
Primitive::Table => table::table(ops, env)?,
Primitive::Repeat => loops::repeat(ops, false, env)?,
Primitive::Do => loops::do_(ops, env)?,
Primitive::Group => loops::group(ops, env)?,
Primitive::Partition => loops::partition(ops, env)?,
Primitive::Group => {
let [f] = get_ops(ops, env)?;
loops::group(f, env)?
}
Primitive::Partition => {
let [f] = get_ops(ops, env)?;
loops::partition(f, env)?
}
Primitive::Tuples => permute::tuples(ops, env)?,

// Stack
Expand Down Expand Up @@ -1505,6 +1514,8 @@ impl ImplPrimitive {
env.push(random());
}
ImplPrimitive::CountUnique => env.monadic_ref(Value::count_unique)?,
ImplPrimitive::SplitByScalar => loops::split_by(true, env)?,
ImplPrimitive::SplitBy => loops::split_by(false, env)?,
ImplPrimitive::MatchPattern => {
let expected = env.pop(1)?;
let got = env.pop(2)?;
Expand Down
16 changes: 12 additions & 4 deletions tests/optimized.ua
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,18 @@
⍤⤙≍ ⊃≡(∘⊢∘)≡⊢ °△2_1_4
⍤⤙≍ ⊃≡≡(∘⊢∘)≡≡⊢ °△2_1_4

⍤⤙≍ ⊃≡(∘(⊢⇌)∘)≡(⊢⇌) °△2_3_4
⍤⤙≍ ⊃≡≡(∘(⊢⇌)∘)≡≡(⊢⇌) °△2_3_4
⍤⤙≍ ⊃≡(∘(⊢⇌)∘)≡(⊢⇌) °△2_1_4
⍤⤙≍ ⊃≡≡(∘(⊢⇌)∘)≡≡(⊢⇌) °△2_1_4
⍤⤙≍ ⊃≡(∘⊣∘)≡⊣ °△2_3_4
⍤⤙≍ ⊃≡≡(∘⊣∘)≡≡⊣ °△2_3_4
⍤⤙≍ ⊃≡(∘⊣∘)≡⊣ °△2_1_4
⍤⤙≍ ⊃≡≡(∘⊣∘)≡≡⊣ °△2_1_4

# Split by
⍤⤙≍ ⊃(⊜(□∘)⊸≠|⊜□⊸≠) @ " Hey there buddy "
⍤⤙≍ ⊃(⊜(□∘)≠|⊜□≠) @ . " Hey there buddy "
⍤⤙≍ ⊃(⊜(□∘)¬⊸⦷|⊜□¬⊸⦷) @ " Hey there buddy "
⍤⤙≍ ⊃(⊜(□∘)⊸≠|⊜□⊸≠) 5 ◿20 ⇡100
⍤⤙≍ ⊃(⊜(□∘)¬⊸⦷|⊜□¬⊸⦷) " - " " - Hey - there - buddy - "
⍤⤙≍ ⊃(⊜(□∘)¬⊸⦷|⊜□¬⊸⦷) +5⇡5 ◿20 ⇡100

# Reduce table
Test‼ ← (
Expand Down

0 comments on commit de23572

Please sign in to comment.