From 258d3725e73688ef08bc410e4cc8ee91a5641f13 Mon Sep 17 00:00:00 2001 From: Kim Morrison Date: Mon, 11 Nov 2024 18:53:24 +1100 Subject: [PATCH] feat: change Array.set to take a Nat and a tactic provided bound (#5988) This PR changes the signature of `Array.set` to take a `Nat`, and a tactic-provided bound, rather than a `Fin`. Corresponding changes (but without the auto-param) for `Array.get` will arrive shortly, after which I'll go more pervasively through the Array API. --- src/Init.lean | 1 + src/Init/Data/Array.lean | 1 + src/Init/Data/Array/Basic.lean | 12 +-- src/Init/Data/Array/BasicAux.lean | 2 +- src/Init/Data/Array/Lemmas.lean | 90 ++++++++----------- src/Init/Data/Array/Set.lean | 39 ++++++++ src/Init/Data/ByteArray/Basic.lean | 2 +- src/Init/Data/FloatArray/Basic.lean | 2 +- src/Init/Meta.lean | 3 +- src/Init/Prelude.lean | 61 ++----------- src/Init/Syntax.lean | 36 ++++++++ src/Lean/Data/PersistentArray.lean | 2 +- src/Lean/Elab/Inductive.lean | 5 +- src/Lean/Environment.lean | 2 +- src/Lean/Meta/Closure.lean | 2 +- src/Lean/Meta/DiscrTree.lean | 2 +- src/Lean/Meta/ExprDefEq.lean | 2 +- src/Lean/Meta/GeneralizeTelescope.lean | 2 +- src/Lean/Meta/Match/MatcherApp/Transform.lean | 2 +- src/Lean/Meta/SynthInstance.lean | 2 +- src/Std/Data/DHashMap/Internal/WF.lean | 2 +- src/Std/Sat/AIG/CNF.lean | 12 +-- src/lake/Lake/Toml/Data/Dict.lean | 2 +- tests/lean/arrayGetU.lean | 2 +- tests/lean/run/heapSort.lean | 8 +- tests/lean/run/inlineWithNestedRecIssue.lean | 2 +- tests/lean/run/issue3204.lean | 2 +- 27 files changed, 156 insertions(+), 144 deletions(-) create mode 100644 src/Init/Data/Array/Set.lean create mode 100644 src/Init/Syntax.lean diff --git a/src/Init.lean b/src/Init.lean index b4112020cacc..568452373ed2 100644 --- a/src/Init.lean +++ b/src/Init.lean @@ -36,3 +36,4 @@ import Init.Omega import Init.MacroTrace import Init.Grind import Init.While +import Init.Syntax diff --git a/src/Init/Data/Array.lean b/src/Init/Data/Array.lean index 423dae6b75bc..ab2e05a270c6 100644 --- a/src/Init/Data/Array.lean +++ b/src/Init/Data/Array.lean @@ -17,3 +17,4 @@ import Init.Data.Array.TakeDrop import Init.Data.Array.Bootstrap import Init.Data.Array.GetLit import Init.Data.Array.MapIdx +import Init.Data.Array.Set diff --git a/src/Init/Data/Array/Basic.lean b/src/Init/Data/Array/Basic.lean index 3e924e1b3c1c..f351e81fe508 100644 --- a/src/Init/Data/Array/Basic.lean +++ b/src/Init/Data/Array/Basic.lean @@ -12,6 +12,7 @@ import Init.Data.Repr import Init.Data.ToString.Basic import Init.GetElem import Init.Data.List.ToArray +import Init.Data.Array.Set universe u v w /-! ### Array literal syntax -/ @@ -29,7 +30,8 @@ namespace Array /-! ### Preliminary theorems -/ -@[simp] theorem size_set (a : Array α) (i : Fin a.size) (v : α) : (set a i v).size = a.size := +@[simp] theorem size_set (a : Array α) (i : Nat) (v : α) (h : i < a.size) : + (set a i v h).size = a.size := List.length_set .. @[simp] theorem size_push (a : Array α) (v : α) : (push a v).size = a.size + 1 := @@ -141,7 +143,7 @@ def uget (a : @& Array α) (i : USize) (h : i.toNat < a.size) : α := `fset` may be slightly slower than `uset`. -/ @[extern "lean_array_uset"] def uset (a : Array α) (i : USize) (v : α) (h : i.toNat < a.size) : Array α := - a.set ⟨i.toNat, h⟩ v + a.set i.toNat v h @[extern "lean_array_pop"] def pop (a : Array α) : Array α where @@ -167,10 +169,10 @@ def swap (a : Array α) (i j : @& Fin a.size) : Array α := let v₁ := a.get i let v₂ := a.get j let a' := a.set i v₂ - a'.set (size_set a i v₂ ▸ j) v₁ + a'.set j v₁ (Nat.lt_of_lt_of_eq j.isLt (size_set a i v₂ _).symm) @[simp] theorem size_swap (a : Array α) (i j : Fin a.size) : (a.swap i j).size = a.size := by - show ((a.set i (a.get j)).set (size_set a i _ ▸ j) (a.get i)).size = a.size + show ((a.set i (a.get j)).set j (a.get i) (Nat.lt_of_lt_of_eq j.isLt (size_set a i (a.get j) _).symm)).size = a.size rw [size_set, size_set] /-- @@ -278,7 +280,7 @@ unsafe def modifyMUnsafe [Monad m] (a : Array α) (i : Nat) (f : α → m α) : -- of the element type, and that it is valid to store `box(0)` in any array. let a' := a.set idx (unsafeCast ()) let v ← f v - pure <| a'.set (size_set a .. ▸ idx) v + pure <| a'.set idx v (Nat.lt_of_lt_of_eq h (size_set a ..).symm) else pure a diff --git a/src/Init/Data/Array/BasicAux.lean b/src/Init/Data/Array/BasicAux.lean index 21f9cc3a2f4b..846183b5ab96 100644 --- a/src/Init/Data/Array/BasicAux.lean +++ b/src/Init/Data/Array/BasicAux.lean @@ -60,7 +60,7 @@ where if ptrEq a b then go (i+1) as else - go (i+1) (as.set ⟨i, h⟩ b) + go (i+1) (as.set i b h) else return as diff --git a/src/Init/Data/Array/Lemmas.lean b/src/Init/Data/Array/Lemmas.lean index b4c1ef87d4d4..c16f8dbcbfae 100644 --- a/src/Init/Data/Array/Lemmas.lean +++ b/src/Init/Data/Array/Lemmas.lean @@ -483,25 +483,26 @@ theorem get!_eq_getD [Inhabited α] (a : Array α) : a.get! n = a.getD n default /-! # set -/ -@[simp] theorem getElem_set_eq (a : Array α) (i : Fin a.size) (v : α) {j : Nat} - (eq : i.val = j) (p : j < (a.set i v).size) : +@[simp] theorem getElem_set_eq (a : Array α) (i : Nat) (h : i < a.size) (v : α) {j : Nat} + (eq : i = j) (p : j < (a.set i v).size) : (a.set i v)[j]'p = v := by simp [set, getElem_eq_getElem_toList, ←eq] -@[simp] theorem getElem_set_ne (a : Array α) (i : Fin a.size) (v : α) {j : Nat} (pj : j < (a.set i v).size) - (h : i.val ≠ j) : (a.set i v)[j]'pj = a[j]'(size_set a i v ▸ pj) := by +@[simp] theorem getElem_set_ne (a : Array α) (i : Nat) (h' : i < a.size) (v : α) {j : Nat} + (pj : j < (a.set i v).size) (h : i ≠ j) : + (a.set i v)[j]'pj = a[j]'(size_set a i v _ ▸ pj) := by simp only [set, getElem_eq_getElem_toList, List.getElem_set_ne h] -theorem getElem_set (a : Array α) (i : Fin a.size) (v : α) (j : Nat) +theorem getElem_set (a : Array α) (i : Nat) (h' : i < a.size) (v : α) (j : Nat) (h : j < (a.set i v).size) : - (a.set i v)[j]'h = if i = j then v else a[j]'(size_set a i v ▸ h) := by - by_cases p : i.1 = j <;> simp [p] + (a.set i v)[j]'h = if i = j then v else a[j]'(size_set a i v _ ▸ h) := by + by_cases p : i = j <;> simp [p] -@[simp] theorem getElem?_set_eq (a : Array α) (i : Fin a.size) (v : α) : - (a.set i v)[i.1]? = v := by simp [getElem?_lt, i.2] +@[simp] theorem getElem?_set_eq (a : Array α) (i : Nat) (h : i < a.size) (v : α) : + (a.set i v)[i]? = v := by simp [getElem?_lt, h] -@[simp] theorem getElem?_set_ne (a : Array α) (i : Fin a.size) {j : Nat} (v : α) - (ne : i.val ≠ j) : (a.set i v)[j]? = a[j]? := by +@[simp] theorem getElem?_set_ne (a : Array α) (i : Nat) (h : i < a.size) {j : Nat} (v : α) + (ne : i ≠ j) : (a.set i v)[j]? = a[j]? := by by_cases h : j < a.size <;> simp [getElem?_lt, getElem?_ge, Nat.ge_of_not_lt, ne, h] /-! # setD -/ @@ -518,7 +519,7 @@ theorem getElem_set (a : Array α) (i : Fin a.size) (v : α) (j : Nat) @[simp] theorem getElem_setD_eq (a : Array α) {i : Nat} (v : α) (h : _) : (setD a i v)[i]'h = v := by simp at h - simp only [setD, h, dite_true, getElem_set, ite_true] + simp only [setD, h, ↓reduceDIte, getElem_set_eq] @[simp] theorem getElem?_setD_eq (a : Array α) {i : Nat} (p : i < a.size) (v : α) : (a.setD i v)[i]? = some v := by @@ -693,43 +694,43 @@ theorem getElem?_push {a : Array α} : (a.push x)[i]? = if i = a.size then some @[deprecated getElem?_size (since := "2024-10-21")] abbrev get?_size := @getElem?_size -@[simp] theorem toList_set (a : Array α) (i v) : (a.set i v).toList = a.toList.set i.1 v := rfl +@[simp] theorem toList_set (a : Array α) (i v h) : (a.set i v).toList = a.toList.set i v := rfl -theorem get_set_eq (a : Array α) (i : Fin a.size) (v : α) : - (a.set i v)[i.1] = v := by +theorem get_set_eq (a : Array α) (i : Nat) (v : α) (h : i < a.size) : + (a.set i v h)[i]'(by simp [h]) = v := by simp only [set, getElem_eq_getElem_toList, List.getElem_set_self] -theorem get?_set_eq (a : Array α) (i : Fin a.size) (v : α) : - (a.set i v)[i.1]? = v := by simp [getElem?_pos, i.2] +theorem get?_set_eq (a : Array α) (i : Nat) (v : α) (h : i < a.size) : + (a.set i v)[i]? = v := by simp [getElem?_pos, h] -@[simp] theorem get?_set_ne (a : Array α) (i : Fin a.size) {j : Nat} (v : α) - (h : i.1 ≠ j) : (a.set i v)[j]? = a[j]? := by +@[simp] theorem get?_set_ne (a : Array α) (i : Nat) (h' : i < a.size) {j : Nat} (v : α) + (h : i ≠ j) : (a.set i v)[j]? = a[j]? := by by_cases j < a.size <;> simp [getElem?_pos, getElem?_neg, *] -theorem get?_set (a : Array α) (i : Fin a.size) (j : Nat) (v : α) : - (a.set i v)[j]? = if i.1 = j then some v else a[j]? := by - if h : i.1 = j then subst j; simp [*] else simp [*] +theorem get?_set (a : Array α) (i : Nat) (h : i < a.size) (j : Nat) (v : α) : + (a.set i v)[j]? = if i = j then some v else a[j]? := by + if h : i = j then subst j; simp [*] else simp [*] -theorem get_set (a : Array α) (i : Fin a.size) (j : Nat) (hj : j < a.size) (v : α) : +theorem get_set (a : Array α) (i : Nat) (hi : i < a.size) (j : Nat) (hj : j < a.size) (v : α) : (a.set i v)[j]'(by simp [*]) = if i = j then v else a[j] := by - if h : i.1 = j then subst j; simp [*] else simp [*] + if h : i = j then subst j; simp [*] else simp [*] -@[simp] theorem get_set_ne (a : Array α) (i : Fin a.size) {j : Nat} (v : α) (hj : j < a.size) - (h : i.1 ≠ j) : (a.set i v)[j]'(by simp [*]) = a[j] := by +@[simp] theorem get_set_ne (a : Array α) (i : Nat) (hi : i < a.size) {j : Nat} (v : α) (hj : j < a.size) + (h : i ≠ j) : (a.set i v)[j]'(by simp [*]) = a[j] := by simp only [set, getElem_eq_getElem_toList, List.getElem_set_ne h] theorem getElem_setD (a : Array α) (i : Nat) (v : α) (h : i < (setD a i v).size) : (setD a i v)[i] = v := by simp at h - simp only [setD, h, dite_true, get_set, ite_true] + simp only [setD, h, ↓reduceDIte, getElem_set_eq] -theorem set_set (a : Array α) (i : Fin a.size) (v v' : α) : - (a.set i v).set ⟨i, by simp [i.2]⟩ v' = a.set i v' := by simp [set, List.set_set] +theorem set_set (a : Array α) (i : Nat) (h) (v v' : α) : + (a.set i v h).set i v' (by simp [h]) = a.set i v' := by simp [set, List.set_set] private theorem fin_cast_val (e : n = n') (i : Fin n) : e ▸ i = ⟨i.1, e ▸ i.2⟩ := by cases e; rfl theorem swap_def (a : Array α) (i j : Fin a.size) : - a.swap i j = (a.set i (a.get j)).set ⟨j.1, by simp [j.2]⟩ (a.get i) := by + a.swap i j = (a.set i (a.get j)).set j (a.get i) := by simp [swap, fin_cast_val] @[simp] theorem toList_swap (a : Array α) (i j : Fin a.size) : @@ -747,7 +748,7 @@ theorem getElem?_swap (a : Array α) (i j : Fin a.size) (k : Nat) : (a.swap i j) @[simp] theorem swapAt!_def (a : Array α) (i : Nat) (v : α) (h : i < a.size) : - a.swapAt! i v = (a[i], a.set ⟨i, h⟩ v) := by simp [swapAt!, h] + a.swapAt! i v = (a[i], a.set i v) := by simp [swapAt!, h] @[simp] theorem size_swapAt! (a : Array α) (i : Nat) (v : α) : (a.swapAt! i v).2.size = a.size := by @@ -1112,7 +1113,7 @@ theorem getElem_modify {as : Array α} {x i} (h : i < (as.modify x f).size) : (as.modify x f)[i] = if x = i then f (as[i]'(by simpa using h)) else as[i]'(by simpa using h) := by simp only [modify, modifyM, get_eq_getElem, Id.run, Id.pure_eq] split - · simp only [Id.bind_eq, get_set _ _ _ (by simpa using h)]; split <;> simp [*] + · simp only [Id.bind_eq, get_set _ _ _ _ (by simpa using h)]; split <;> simp [*] · rw [if_neg (mt (by rintro rfl; exact h) (by simp_all))] @[simp] theorem toList_modify (as : Array α) (f : α → α) : @@ -1541,30 +1542,15 @@ instance [DecidableEq α] (a : α) (as : Array α) : Decidable (a ∈ as) := open Fin -@[simp] theorem getElem_swap_right (a : Array α) {i j : Fin a.size} : (a.swap i j)[j.val] = a[i] := - by simp only [swap, fin_cast_val, get_eq_getElem, getElem_set_eq, getElem_fin] +@[simp] theorem getElem_swap_right (a : Array α) {i j : Fin a.size} : (a.swap i j)[j.1] = a[i] := by + simp [swap_def, getElem_set] -@[simp] theorem getElem_swap_left (a : Array α) {i j : Fin a.size} : (a.swap i j)[i.val] = a[j] := - if he : ((Array.size_set _ _ _).symm ▸ j).val = i.val then by - simp only [←he, fin_cast_val, getElem_swap_right, getElem_fin] - else by - apply Eq.trans - · apply Array.get_set_ne - · simp only [size_set, Fin.isLt] - · assumption - · simp [get_set_ne] +@[simp] theorem getElem_swap_left (a : Array α) {i j : Fin a.size} : (a.swap i j)[i.1] = a[j] := by + simp +contextual [swap_def, getElem_set] @[simp] theorem getElem_swap_of_ne (a : Array α) {i j : Fin a.size} (hp : p < a.size) (hi : p ≠ i) (hj : p ≠ j) : (a.swap i j)[p]'(a.size_swap .. |>.symm ▸ hp) = a[p] := by - apply Eq.trans - · have : ((a.size_set i (a.get j)).symm ▸ j).val = j.val := by simp only [fin_cast_val] - apply Array.get_set_ne - · simp only [this] - apply Ne.symm - · assumption - · apply Array.get_set_ne - · apply Ne.symm - · assumption + simp [swap_def, getElem_set, hi.symm, hj.symm] theorem getElem_swap' (a : Array α) (i j : Fin a.size) (k : Nat) (hk : k < a.size) : (a.swap i j)[k]'(by simp_all) = if k = i then a[j] else if k = j then a[i] else a[k] := by diff --git a/src/Init/Data/Array/Set.lean b/src/Init/Data/Array/Set.lean new file mode 100644 index 000000000000..9fdec052abef --- /dev/null +++ b/src/Init/Data/Array/Set.lean @@ -0,0 +1,39 @@ +/- +Copyright (c) 2020 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura, Mario Carneiro +-/ +prelude +import Init.Tactics + + +/-- +Set an element in an array, using a proof that the index is in bounds. +(This proof can usually be omitted, and will be synthesized automatically.) + +This will perform the update destructively provided that `a` has a reference +count of 1 when called. +-/ +@[extern "lean_array_fset"] +def Array.set (a : Array α) (i : @& Nat) (v : α) (h : i < a.size := by get_elem_tactic) : + Array α where + toList := a.toList.set i v + +/-- +Set an element in an array, or do nothing if the index is out of bounds. + +This will perform the update destructively provided that `a` has a reference +count of 1 when called. +-/ +@[inline] def Array.setD (a : Array α) (i : Nat) (v : α) : Array α := + dite (LT.lt i a.size) (fun h => a.set i v h) (fun _ => a) + +/-- +Set an element in an array, or panic if the index is out of bounds. + +This will perform the update destructively provided that `a` has a reference +count of 1 when called. +-/ +@[extern "lean_array_set"] +def Array.set! (a : Array α) (i : @& Nat) (v : α) : Array α := + Array.setD a i v diff --git a/src/Init/Data/ByteArray/Basic.lean b/src/Init/Data/ByteArray/Basic.lean index 4d34ae368f96..2e3efca53bbd 100644 --- a/src/Init/Data/ByteArray/Basic.lean +++ b/src/Init/Data/ByteArray/Basic.lean @@ -65,7 +65,7 @@ def set! : ByteArray → (@& Nat) → UInt8 → ByteArray @[extern "lean_byte_array_fset"] def set : (a : ByteArray) → (@& Fin a.size) → UInt8 → ByteArray - | ⟨bs⟩, i, b => ⟨bs.set i b⟩ + | ⟨bs⟩, i, b => ⟨bs.set i.1 b i.2⟩ @[extern "lean_byte_array_uset"] def uset : (a : ByteArray) → (i : USize) → UInt8 → i.toNat < a.size → ByteArray diff --git a/src/Init/Data/FloatArray/Basic.lean b/src/Init/Data/FloatArray/Basic.lean index ebe5e17b8e08..44ca479f18c7 100644 --- a/src/Init/Data/FloatArray/Basic.lean +++ b/src/Init/Data/FloatArray/Basic.lean @@ -71,7 +71,7 @@ def uset : (a : FloatArray) → (i : USize) → Float → i.toNat < a.size → F @[extern "lean_float_array_fset"] def set : (ds : FloatArray) → (@& Fin ds.size) → Float → FloatArray - | ⟨ds⟩, i, d => ⟨ds.set i d⟩ + | ⟨ds⟩, i, d => ⟨ds.set i.1 d i.2⟩ @[extern "lean_float_array_set"] def set! : FloatArray → (@& Nat) → Float → FloatArray diff --git a/src/Init/Meta.lean b/src/Init/Meta.lean index 41121ab15fa5..146b068820c3 100644 --- a/src/Init/Meta.lean +++ b/src/Init/Meta.lean @@ -7,6 +7,7 @@ Additional goodies for writing macros -/ prelude import Init.MetaTypes +import Init.Syntax import Init.Data.Array.GetLit import Init.Data.Option.BasicAux @@ -442,7 +443,7 @@ def unsetTrailing (stx : Syntax) : Syntax := if h : i < a.size then let v := a[i] match f v with - | some v => some <| a.set ⟨i, h⟩ v + | some v => some <| a.set i v h | none => updateFirst a f (i+1) else none diff --git a/src/Init/Prelude.lean b/src/Init/Prelude.lean index 2f34e230a517..8942a4a142de 100644 --- a/src/Init/Prelude.lean +++ b/src/Init/Prelude.lean @@ -2688,35 +2688,6 @@ def Array.mkArray7 {α : Type u} (a₁ a₂ a₃ a₄ a₅ a₆ a₇ : α) : Arr def Array.mkArray8 {α : Type u} (a₁ a₂ a₃ a₄ a₅ a₆ a₇ a₈ : α) : Array α := ((((((((mkEmpty 8).push a₁).push a₂).push a₃).push a₄).push a₅).push a₆).push a₇).push a₈ -/-- -Set an element in an array without bounds checks, using a `Fin` index. - -This will perform the update destructively provided that `a` has a reference -count of 1 when called. --/ -@[extern "lean_array_fset"] -def Array.set (a : Array α) (i : @& Fin a.size) (v : α) : Array α where - toList := a.toList.set i.val v - -/-- -Set an element in an array, or do nothing if the index is out of bounds. - -This will perform the update destructively provided that `a` has a reference -count of 1 when called. --/ -@[inline] def Array.setD (a : Array α) (i : Nat) (v : α) : Array α := - dite (LT.lt i a.size) (fun h => a.set ⟨i, h⟩ v) (fun _ => a) - -/-- -Set an element in an array, or panic if the index is out of bounds. - -This will perform the update destructively provided that `a` has a reference -count of 1 when called. --/ -@[extern "lean_array_set"] -def Array.set! (a : Array α) (i : @& Nat) (v : α) : Array α := - Array.setD a i v - /-- Slower `Array.append` used in quotations. -/ protected def Array.appendCore {α : Type u} (as : Array α) (bs : Array α) : Array α := let rec loop (i : Nat) (j : Nat) (as : Array α) : Array α := @@ -3637,6 +3608,13 @@ def appendCore : Name → Name → Name end Name +/-- The default maximum recursion depth. This is adjustable using the `maxRecDepth` option. -/ +def defaultMaxRecDepth := 512 + +/-- The message to display on stack overflow. -/ +def maxRecDepthErrorMessage : String := + "maximum recursion depth has been reached\nuse `set_option maxRecDepth ` to increase limit\nuse `set_option diagnostics true` to get diagnostic information" + /-! # Syntax -/ /-- Source information of tokens. -/ @@ -3969,24 +3947,6 @@ def getId : Syntax → Name | ident _ _ val _ => val | _ => Name.anonymous -/-- -Updates the argument list without changing the node kind. -Does nothing for non-`node` nodes. --/ -def setArgs (stx : Syntax) (args : Array Syntax) : Syntax := - match stx with - | node info k _ => node info k args - | stx => stx - -/-- -Updates the `i`'th argument of the syntax. -Does nothing for non-`node` nodes, or if `i` is out of bounds of the node list. --/ -def setArg (stx : Syntax) (i : Nat) (arg : Syntax) : Syntax := - match stx with - | node info k args => node info k (args.setD i arg) - | stx => stx - /-- Retrieve the left-most node or leaf's info in the Syntax tree. -/ partial def getHeadInfo? : Syntax → Option SourceInfo | atom info _ => some info @@ -4423,13 +4383,6 @@ main module and current macro scope. bind getCurrMacroScope fun scp => pure (Lean.addMacroScope mainModule n scp) -/-- The default maximum recursion depth. This is adjustable using the `maxRecDepth` option. -/ -def defaultMaxRecDepth := 512 - -/-- The message to display on stack overflow. -/ -def maxRecDepthErrorMessage : String := - "maximum recursion depth has been reached\nuse `set_option maxRecDepth ` to increase limit\nuse `set_option diagnostics true` to get diagnostic information" - namespace Syntax /-- Is this syntax a null `node`? -/ diff --git a/src/Init/Syntax.lean b/src/Init/Syntax.lean new file mode 100644 index 000000000000..c9e589af16e3 --- /dev/null +++ b/src/Init/Syntax.lean @@ -0,0 +1,36 @@ +/- +Copyright (c) 2020 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura, Mario Carneiro +-/ + +prelude +import Init.Data.Array.Set + +/-! +# Helper functions for `Syntax`. + +These are delayed here to allow some time to bootstrap `Array`. +-/ + +namespace Lean.Syntax + +/-- +Updates the argument list without changing the node kind. +Does nothing for non-`node` nodes. +-/ +def setArgs (stx : Syntax) (args : Array Syntax) : Syntax := + match stx with + | node info k _ => node info k args + | stx => stx + +/-- +Updates the `i`'th argument of the syntax. +Does nothing for non-`node` nodes, or if `i` is out of bounds of the node list. +-/ +def setArg (stx : Syntax) (i : Nat) (arg : Syntax) : Syntax := + match stx with + | node info k args => node info k (args.setD i arg) + | stx => stx + +end Lean.Syntax diff --git a/src/Lean/Data/PersistentArray.lean b/src/Lean/Data/PersistentArray.lean index 427d17fc38af..095c3f088f8b 100644 --- a/src/Lean/Data/PersistentArray.lean +++ b/src/Lean/Data/PersistentArray.lean @@ -159,7 +159,7 @@ partial def popLeaf : PersistentArrayNode α → Option (Array α) × Array (Per let cs' := cs'.pop if cs'.isEmpty then (some l, emptyArray) else (some l, cs') else - (some l, cs'.set (Array.size_set cs idx _ ▸ idx) (node newLast)) + (some l, cs'.set idx (node newLast) (by simp only [cs', Array.size_set]; omega)) else (none, emptyArray) | leaf vs => (some vs, emptyArray) diff --git a/src/Lean/Elab/Inductive.lean b/src/Lean/Elab/Inductive.lean index cb59576744d6..671fb6bf8e7d 100644 --- a/src/Lean/Elab/Inductive.lean +++ b/src/Lean/Elab/Inductive.lean @@ -740,10 +740,7 @@ private def getArity (indType : InductiveType) : MetaM Nat := forallTelescopeReducing indType.type fun xs _ => return xs.size private def resetMaskAt (mask : Array Bool) (i : Nat) : Array Bool := - if h : i < mask.size then - mask.set ⟨i, h⟩ false - else - mask + mask.setD i false /-- Compute a bit-mask that for `indType`. The size of the resulting array `result` is the arity of `indType`. diff --git a/src/Lean/Environment.lean b/src/Lean/Environment.lean index ee5ff34b53e2..f69679b62729 100644 --- a/src/Lean/Environment.lean +++ b/src/Lean/Environment.lean @@ -328,7 +328,7 @@ private def invalidExtMsg := "invalid environment extension has been accessed" unsafe def setState {σ} (ext : Ext σ) (exts : Array EnvExtensionState) (s : σ) : Array EnvExtensionState := if h : ext.idx < exts.size then - exts.set ⟨ext.idx, h⟩ (unsafeCast s) + exts.set ext.idx (unsafeCast s) else have : Inhabited (Array EnvExtensionState) := ⟨exts⟩ panic! invalidExtMsg diff --git a/src/Lean/Meta/Closure.lean b/src/Lean/Meta/Closure.lean index ba9b9456a638..be0b42d85620 100644 --- a/src/Lean/Meta/Closure.lean +++ b/src/Lean/Meta/Closure.lean @@ -226,7 +226,7 @@ partial def pickNextToProcessAux (lctx : LocalContext) (i : Nat) (toProcess : Ar if h : i < toProcess.size then let elem' := toProcess.get ⟨i, h⟩ if (lctx.get! elem.fvarId).index < (lctx.get! elem'.fvarId).index then - pickNextToProcessAux lctx (i+1) (toProcess.set ⟨i, h⟩ elem) elem' + pickNextToProcessAux lctx (i+1) (toProcess.set i elem) elem' else pickNextToProcessAux lctx (i+1) toProcess elem else diff --git a/src/Lean/Meta/DiscrTree.lean b/src/Lean/Meta/DiscrTree.lean index a3114086f72a..4b2bb27f00cf 100644 --- a/src/Lean/Meta/DiscrTree.lean +++ b/src/Lean/Meta/DiscrTree.lean @@ -460,7 +460,7 @@ where loop (i : Nat) : Array α := if h : i < vs.size then if v == vs[i] then - vs.set ⟨i,h⟩ v + vs.set i v else loop (i+1) else diff --git a/src/Lean/Meta/ExprDefEq.lean b/src/Lean/Meta/ExprDefEq.lean index 9a6e1b807782..10c65bf862a3 100644 --- a/src/Lean/Meta/ExprDefEq.lean +++ b/src/Lean/Meta/ExprDefEq.lean @@ -1205,7 +1205,7 @@ private partial def processAssignment (mvarApp : Expr) (v : Expr) : MetaM Bool : if h : i < args.size then let arg := args.get ⟨i, h⟩ let arg ← simpAssignmentArg arg - let args := args.set ⟨i, h⟩ arg + let args := args.set i arg match arg with | Expr.fvar fvarId => if args[0:i].any fun prevArg => prevArg == arg then diff --git a/src/Lean/Meta/GeneralizeTelescope.lean b/src/Lean/Meta/GeneralizeTelescope.lean index 7fadafbd9e4c..9d6becf61ab5 100644 --- a/src/Lean/Meta/GeneralizeTelescope.lean +++ b/src/Lean/Meta/GeneralizeTelescope.lean @@ -23,7 +23,7 @@ partial def updateTypes (e eNew : Expr) (entries : Array Entry) (i : Nat) : Meta let typeAbst ← kabstract type e if typeAbst.hasLooseBVars then do let typeNew := typeAbst.instantiate1 eNew - let entries := entries.set ⟨i, h⟩ { entry with type := typeNew, modified := true } + let entries := entries.set i { entry with type := typeNew, modified := true } updateTypes e eNew entries (i+1) else updateTypes e eNew entries (i+1) diff --git a/src/Lean/Meta/Match/MatcherApp/Transform.lean b/src/Lean/Meta/Match/MatcherApp/Transform.lean index 556f33d6d200..3de48b61f910 100644 --- a/src/Lean/Meta/Match/MatcherApp/Transform.lean +++ b/src/Lean/Meta/Match/MatcherApp/Transform.lean @@ -29,7 +29,7 @@ private partial def updateAlts (unrefinedArgType : Expr) (typeNew : Expr) (altNu else pure <| !(← isDefEq unrefinedArgType (← inferType x[0]!)) return (← mkLambdaFVars xs alt, refined) - updateAlts unrefinedArgType (b.instantiate1 alt) (altNumParams.set! i (numParams+1)) (alts.set ⟨i, h⟩ alt) refined (i+1) + updateAlts unrefinedArgType (b.instantiate1 alt) (altNumParams.set! i (numParams+1)) (alts.set i alt) refined (i+1) | _ => throwError "unexpected type at MatcherApp.addArg" else if refined then diff --git a/src/Lean/Meta/SynthInstance.lean b/src/Lean/Meta/SynthInstance.lean index 2e16dbd2c27f..d18e8d4be0db 100644 --- a/src/Lean/Meta/SynthInstance.lean +++ b/src/Lean/Meta/SynthInstance.lean @@ -671,7 +671,7 @@ private partial def preprocessArgs (type : Expr) (i : Nat) (args : Array Expr) ( If an instance implicit argument depends on an `outParam`, it is treated as an `outParam` too. -/ let arg ← if outParamsPos.contains i then mkFreshExprMVar d else pure arg - let args := args.set ⟨i, h⟩ arg + let args := args.set i arg preprocessArgs (b.instantiate1 arg) (i+1) args outParamsPos | _ => throwError "type class resolution failed, insufficient number of arguments" -- TODO improve error message diff --git a/src/Std/Data/DHashMap/Internal/WF.lean b/src/Std/Data/DHashMap/Internal/WF.lean index f72c379904b3..61ecde77f2be 100644 --- a/src/Std/Data/DHashMap/Internal/WF.lean +++ b/src/Std/Data/DHashMap/Internal/WF.lean @@ -167,7 +167,7 @@ theorem toListModel_foldl_reinsertAux [BEq α] [Hashable α] [PartialEquivBEq α theorem expand.go_pos [Hashable α] {i : Nat} {source : Array (AssocList α β)} {target : { d : Array (AssocList α β) // 0 < d.size }} (h : i < source.size) : expand.go i source target = go (i + 1) - (source.set ⟨i, h⟩ .nil) ((source.get ⟨i, h⟩).foldl (reinsertAux hash) target) := by + (source.set i .nil) ((source.get ⟨i, h⟩).foldl (reinsertAux hash) target) := by rw [expand.go] simp only [h, dite_true] diff --git a/src/Std/Sat/AIG/CNF.lean b/src/Std/Sat/AIG/CNF.lean index c81f3dee528e..95ecaa3ae8d5 100644 --- a/src/Std/Sat/AIG/CNF.lean +++ b/src/Std/Sat/AIG/CNF.lean @@ -243,7 +243,7 @@ theorem Cache.IsExtensionBy_rfl (cache : Cache aig cnf) {h} (hmarked : cache.mar · exact hmarked theorem Cache.IsExtensionBy_set (cache1 : Cache aig cnf1) (cache2 : Cache aig cnf2) (idx : Nat) - (hbound : idx < cache1.marks.size) (h : cache2.marks = cache1.marks.set ⟨idx, hbound⟩ true) : + (hbound : idx < cache1.marks.size) (h : cache2.marks = cache1.marks.set idx true) : IsExtensionBy cache1 cache2 idx (by have := cache1.hmarks; omega) := by apply IsExtensionBy.mk · intro idx hidx hmark @@ -271,7 +271,7 @@ def Cache.addConst (cache : Cache aig cnf) (idx : Nat) (h : idx < aig.decls.size have hmarkbound : idx < cache.marks.size := by have := cache.hmarks; omega let out := { cache with - marks := cache.marks.set ⟨idx, hmarkbound⟩ true + marks := cache.marks.set idx true hmarks := by simp [cache.hmarks] inv := by constructor @@ -285,7 +285,6 @@ def Cache.addConst (cache : Cache aig cnf) (idx : Nat) (h : idx < aig.decls.size rw [Array.getElem_set] at hmarked split at hmarked · next heq => - dsimp only at heq simp only [heq, CNF.eval_append, Decl.constToCNF_eval, Bool.and_eq_true, beq_iff_eq] at htip heval simp only [denote_idx_const htip, projectRightAssign_property, heval] @@ -309,7 +308,7 @@ def Cache.addAtom (cache : Cache aig cnf) (idx : Nat) (h : idx < aig.decls.size) have hmarkbound : idx < cache.marks.size := by have := cache.hmarks; omega let out := { cache with - marks := cache.marks.set ⟨idx, hmarkbound⟩ true + marks := cache.marks.set idx true hmarks := by simp [cache.hmarks] inv := by constructor @@ -323,7 +322,6 @@ def Cache.addAtom (cache : Cache aig cnf) (idx : Nat) (h : idx < aig.decls.size) rw [Array.getElem_set] at hmarked split at hmarked · next heq => - dsimp only at heq simp only [heq, CNF.eval_append, Decl.atomToCNF_eval, Bool.and_eq_true, beq_iff_eq] at htip heval simp [heval, denote_idx_atom htip] · next heq => @@ -356,7 +354,7 @@ def Cache.addGate (cache : Cache aig cnf) {hlb} {hrb} (idx : Nat) (h : idx < aig have hmarkbound : idx < cache.marks.size := by have := cache.hmarks; omega let out := { cache with - marks := cache.marks.set ⟨idx, hmarkbound⟩ true + marks := cache.marks.set idx true hmarks := by simp [cache.hmarks] inv := by constructor @@ -364,7 +362,6 @@ def Cache.addGate (cache : Cache aig cnf) {hlb} {hrb} (idx : Nat) (h : idx < aig rw [Array.getElem_set] at hmarked split at hmarked · next heq2 => - simp only at heq2 simp only [heq2] at htip rw [htip] at heq cases heq @@ -375,7 +372,6 @@ def Cache.addGate (cache : Cache aig cnf) {hlb} {hrb} (idx : Nat) (h : idx < aig rw [Array.getElem_set] at hmarked split at hmarked · next heq => - dsimp only at heq simp only [heq, CNF.eval_append, Decl.gateToCNF_eval, Bool.and_eq_true, beq_iff_eq] at htip heval have hleval := cache.inv.heval assign heval.right lhs (by omega) hl diff --git a/src/lake/Lake/Toml/Data/Dict.lean b/src/lake/Lake/Toml/Data/Dict.lean index e6db14668baf..ee7685c84543 100644 --- a/src/lake/Lake/Toml/Data/Dict.lean +++ b/src/lake/Lake/Toml/Data/Dict.lean @@ -80,7 +80,7 @@ def push (k : α) (v : β) (t : RBDict α β cmp) : RBDict α β cmp := def insert (k : α) (v : β) (t : RBDict α β cmp) : RBDict α β cmp := if let some i := t.findIdx? k then if h : i < t.items.size then - {t with items := t.items.set ⟨i,h⟩ (k,v)} + {t with items := t.items.set i (k,v)} else t.push k v else diff --git a/tests/lean/arrayGetU.lean b/tests/lean/arrayGetU.lean index 8b926f3ecc1e..3bbdaa5e30cb 100644 --- a/tests/lean/arrayGetU.lean +++ b/tests/lean/arrayGetU.lean @@ -1,5 +1,5 @@ def f (a : Array Nat) (i : Nat) (v : Nat) (h : i < a.size) : Array Nat := - a.set ⟨i, h⟩ (a.get ⟨i, h⟩ + v) + a.set i (a.get ⟨i, h⟩ + v) set_option pp.proofs true diff --git a/tests/lean/run/heapSort.lean b/tests/lean/run/heapSort.lean index fa9e9575ca48..7fff1008b5f7 100644 --- a/tests/lean/run/heapSort.lean +++ b/tests/lean/run/heapSort.lean @@ -132,7 +132,7 @@ def insertExtractMax {lt} (self : BinaryHeap α lt) (x : α) : α × BinaryHeap | none => (x, self) | some m => if lt x m then - let a := self.1.set ⟨0, size_pos_of_max e⟩ x + let a := self.1.set 0 x (size_pos_of_max e) (m, ⟨heapifyDown lt a ⟨0, by simp only [Array.size_set, a]; exact size_pos_of_max e⟩⟩) else (x, self) @@ -141,16 +141,16 @@ def replaceMax {lt} (self : BinaryHeap α lt) (x : α) : Option α × BinaryHeap match e: self.max with | none => (none, ⟨self.1.push x⟩) | some m => - let a := self.1.set ⟨0, size_pos_of_max e⟩ x + let a := self.1.set 0 x (size_pos_of_max e) (some m, ⟨heapifyDown lt a ⟨0, by simp only [Array.size_set, a]; exact size_pos_of_max e⟩⟩) /-- `O(log n)`. Replace the value at index `i` by `x`. Assumes that `x ≤ self.get i`. -/ def decreaseKey {lt} (self : BinaryHeap α lt) (i : Fin self.size) (x : α) : BinaryHeap α lt where - arr := heapifyDown lt (self.1.set i x) ⟨i, by rw [self.1.size_set]; exact i.2⟩ + arr := heapifyDown lt (self.1.set i x i.2) ⟨i, by rw [self.1.size_set]; exact i.2⟩ /-- `O(log n)`. Replace the value at index `i` by `x`. Assumes that `self.get i ≤ x`. -/ def increaseKey {lt} (self : BinaryHeap α lt) (i : Fin self.size) (x : α) : BinaryHeap α lt where - arr := heapifyUp lt (self.1.set i x) ⟨i, by rw [self.1.size_set]; exact i.2⟩ + arr := heapifyUp lt (self.1.set i x i.2) ⟨i, by rw [self.1.size_set]; exact i.2⟩ end BinaryHeap diff --git a/tests/lean/run/inlineWithNestedRecIssue.lean b/tests/lean/run/inlineWithNestedRecIssue.lean index 8fa2e03bcd69..1b865d7a7562 100644 --- a/tests/lean/run/inlineWithNestedRecIssue.lean +++ b/tests/lean/run/inlineWithNestedRecIssue.lean @@ -10,7 +10,7 @@ where if ptrEq a b then go (i+1) as else - go (i+1) (as.set ⟨i, h⟩ b) + go (i+1) (as.set i b) else return as diff --git a/tests/lean/run/issue3204.lean b/tests/lean/run/issue3204.lean index e15c8c5033a8..835923bcb09c 100644 --- a/tests/lean/run/issue3204.lean +++ b/tests/lean/run/issue3204.lean @@ -1,6 +1,6 @@ def zero_out (arr : Array Nat) (i : Nat) : Array Nat := if h : i < arr.size then - zero_out (arr.set ⟨i, h⟩ 0) (i + 1) + zero_out (arr.set i 0) (i + 1) else arr termination_by arr.size - i