Skip to content

Commit

Permalink
feat: try? composite suggestions (#6979)
Browse files Browse the repository at this point in the history
This PR adds support for more complex suggestions in `try?`. Example:
```lean
example (as : List α) (a : α) : concat as a = as ++ [a] := by
  try?
```
suggestion
```
Try this: · induction as, a using concat.induct
  · rfl
  · simp_all
```
  • Loading branch information
leodemoura authored Feb 6, 2025
1 parent 45d3942 commit eab0908
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 29 deletions.
5 changes: 4 additions & 1 deletion src/Init/Try.lean
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ namespace Lean.Parser.Tactic

syntax (name := tryTrace) "try?" optConfig : tactic

/-- Helper tactic for implementing the tactic `try?`. -/
/-- Helper internal tactic for implementing the tactic `try?`. -/
syntax (name := attemptAll) "attempt_all " withPosition((ppDedent(ppLine) colGe "| " tacticSeq)+) : tactic

/-- Helper internal tactic used to implement `evalSuggest` in `try?` -/
syntax (name := tryResult) "try_suggestions " tactic* : tactic

end Lean.Parser.Tactic
68 changes: 47 additions & 21 deletions src/Lean/Elab/Tactic/Try.lean
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,6 @@ import Lean.Elab.Tactic.Config
import Lean.Elab.Tactic.SimpTrace
import Lean.Elab.Tactic.Grind

namespace Lean.Parser.Tactic
/-- Internal tactic used to implement `evalSuggest` -/
syntax (name := tryResult) "try_suggestions " tactic* : tactic
end Lean.Parser.Tactic

namespace Lean.Elab.Tactic
open Meta
/-!
Expand Down Expand Up @@ -52,7 +47,7 @@ private def appendSeqResult (suggestionSeqs : Array (Array (TSyntax `tactic))) (
/-- Returns a tactic representing all given suggestions `tacs`. -/
private def mkTrySuggestions (tacs : Array (TSyntax `tactic)) : TacticM (TSyntax `tactic) := do
if tacs.isEmpty then
throwError "`mkSuggestions` failed"
throwError "`mkTrySuggestions` failed"
else if tacs.size == 1 then
return tacs[0]!
else
Expand Down Expand Up @@ -130,16 +125,38 @@ private def getKindsSolvedAll (tacss : Array (Array (TSyntax `tactic))) : Array
r := r.push k
return r

private def mkChainResultCore (tac1 : TSyntax `tactic) (tacs2 : Array (TSyntax `tactic)) : TacticM (Array (TSyntax `tactic)) := do
let tacs2 := tacs2.map getSuggestionsCore
private def peekOne (tac1 : TSyntax `tactic) (tacss2 : Array (Array (TSyntax `tactic))) : TacticM (TSyntax `tactic) := do
let mut tacs2 := #[]
for s in tacss2 do
if s.isEmpty then
tacs2 := tacs2.push (← `(tactic| · sorry))
else
tacs2 := tacs2.push (← `(tactic| · $(s[0]!):tactic))
`(tactic| · $tac1:tactic
$tacs2*)

private def mkChainResultCore (tac1 : TSyntax `tactic) (tacss2 : Array (TSyntax `tactic)) : TacticM (Array (TSyntax `tactic)) := do
let tacss2 := tacss2.map getSuggestionsCore
if (← isTracingEnabledFor `try.debug) then
trace[try.debug] "mkChainResultCore tac1{indentD tac1}"
let mut i : Nat := 0
for tacs2 in tacss2 do
i := i + 1
trace[try.debug] "goal #{i} tactics"
for tac2 in tacs2 do
trace[try.debug] " {tac2}"
trace[try.debug] "mkChainResult -----"
let mut acc := #[]
let solvedAll := getTacsSolvedAll tacs2
let solvedAll := getTacsSolvedAll tacss2
for tac2 in solvedAll do
acc := acc.push (← `(tactic| $tac1 <;> $tac2))
let tacs2 := eraseTacs tacs2 solvedAll
let tacss2 := eraseTacs tacss2 solvedAll
-- TODO: mixed cases
trace[Meta.debug] "CHAIN tacs2: {tacs2}"
trace[Meta.debug] "CHAIN kinds: {getKindsSolvedAll tacs2}"
trace[try.debug] "kinds: {getKindsSolvedAll tacss2}"
if (!acc.isEmpty && tacss2.all fun s => !s.isEmpty)
-- We only include partial solutions if there are no other solutions.
|| (acc.isEmpty && tacss2.any fun s => !s.isEmpty) then
acc := acc.push <| (← peekOne tac1 tacss2)
return acc

private def mkChainResult (tac1 : TSyntax `tactic) (tacs2 : Array (TSyntax `tactic)) : TacticM (TSyntax `tactic) := do
Expand Down Expand Up @@ -178,6 +195,7 @@ private def evalSuggestGrindTrace (tac : TSyntax `tactic) : TacticM (TSyntax `ta
let trace ← evalGrindCore tac config only params fallback?
let tac ← grindTraceToGrind tac
let tac' ← mkGrindOnly configStx fallback? trace
trace[try.debug] "`grind` succeeded"
mkTrySuggestions #[tac, tac']
| _ => throwUnsupportedSyntax

Expand All @@ -188,6 +206,7 @@ private def evalSuggestSimpTrace (tac : TSyntax `tactic) : TacticM (TSyntax `tac
let { ctx, simprocs, .. } ← mkSimpContext tac (eraseLocal := false)
let stats ← simpLocation ctx (simprocs := simprocs) none <| (loc.map expandLocation).getD (.targets #[] true)
let tac' ← mkSimpCallStx tac stats.usedTheorems
trace[try.debug] "`simp` succeeded"
mkTrySuggestions #[tac, tac']
| _ => throwUnsupportedSyntax

Expand Down Expand Up @@ -215,11 +234,14 @@ private def evalSuggestChain (tac1 tac2 : TSyntax `tactic) : TacticM (TSyntax `t
let goals ← getGoals
setGoals []
let mut tac2s := #[]
let mut i : Nat := 0
for goal in goals do
setGoals [goal]
let tac2' ← (evalSuggest tac2) <|> `(tactic| sorry)
let tac2' : TSyntax `tactic ← (evalSuggest tac2) <|> `(tactic| sorry)
i := i + 1
trace[try.debug] "`<;>` goal #{i}, tactic{indentD tac2'}"
unless (← getGoals).isEmpty do
throwError "unsolved goals, `<;>` in `try?` requires all goals to be solved"
throwError "unsolved goals, `<;>` in `try?` requires all goals to be solved{indentD tac2}\n{goalsToMessageData (← getGoals)}"
tac2s := tac2s.push tac2'
if tac2s.all isSorry then
throwError "`<;>` failed"
Expand Down Expand Up @@ -269,8 +291,11 @@ where
go (i : Nat) (saved? : Option SavedState) (acc : Array (TSyntax `tactic)) : TacticM (TSyntax `tactic) := do
if i < tacs.size then
match (← observing (evalSuggestTacticSeq tacs[i]!)) with
| .ok tac s => go (i+1) (saved? <|> some s) (appendSuggestion acc tac)
| _ => go (i+1) saved? acc
| .ok tac s =>
trace[try.debug] "`attempt_all` argument succeeded{indentD tac}"
go (i+1) (saved? <|> some s) (appendSuggestion acc tac)
| _ =>
go (i+1) saved? acc
else
if let some saved := saved? then
saved.restore
Expand All @@ -281,6 +306,7 @@ where
-- `evalSuggest` implementation
@[export lean_eval_suggest_tactic]
private partial def evalSuggestImpl (tac : TSyntax `tactic) : TacticM (TSyntax `tactic) := do
trace[try.debug] "{tac}"
match tac with
| `(tactic| $tac1 <;> $tac2) => evalSuggestChain tac1 tac2
| `(tactic| first $[| $tacs]*) => evalSuggestFirst tacs
Expand Down Expand Up @@ -343,17 +369,17 @@ private def setGrindParams (tac : TSyntax `tactic) (params : Array (TSyntax ``Pa
⟨tac.raw.setArg 3 (mkNullNode paramsStx)⟩

/-- Given a set of declaration names, returns `grind` parameters of the form `= <declName>` -/
private def mkGrindEqnParams (declNames : Std.HashSet Name) : MetaM (Array (TSyntax ``Parser.Tactic.grindParam)) := do
declNames.toArray.mapM fun declName => do
private def mkGrindEqnParams (declNames : Array Name) : MetaM (Array (TSyntax ``Parser.Tactic.grindParam)) := do
declNames.mapM fun declName => do
`(Parser.Tactic.grindParam| = $(← toIdent declName))

private def mkGrindStx (info : Try.Info) : MetaM (TSyntax `tactic) := do
let grind ← `(tactic| grind?)
let mut tacs := #[grind]
unless info.eqnCandidates.isEmpty do
tacs := tacs.push (setGrindParams grind (← mkGrindEqnParams info.eqnCandidates))
tacs := tacs.push (setGrindParams grind (← mkGrindEqnParams info.eqnCandidates.elems))
unless info.unfoldCandidates.isEmpty do
tacs := tacs.push (setGrindParams grind (← mkGrindEqnParams info.unfoldCandidates))
tacs := tacs.push (setGrindParams grind (← mkGrindEqnParams info.unfoldCandidates.elems))
mkFirstStx tacs

/-! Other generators -/
Expand Down Expand Up @@ -400,7 +426,7 @@ where
`(tactic| induction $terms,* using $indFn <;> $cont)

private def mkAllFunIndStx (info : Try.Info) (cont : TSyntax `tactic) : MetaM (TSyntax `tactic) := do
let tacs ← info.funIndCandidates.toArray.mapM (mkFunIndStx · cont)
let tacs ← info.funIndCandidates.elems.mapM (mkFunIndStx · cont)
mkFirstStx tacs

/-! Main code -/
Expand Down
1 change: 1 addition & 0 deletions src/Lean/Meta/Tactic/Try.lean
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ builtin_initialize registerTraceClass `try
builtin_initialize registerTraceClass `try.collect
builtin_initialize registerTraceClass `try.collect.funInd

builtin_initialize registerTraceClass `try.debug
builtin_initialize registerTraceClass `try.debug.funInd

end Lean
26 changes: 21 additions & 5 deletions src/Lean/Meta/Tactic/Try/Collect.lean
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,35 @@ structure FunIndCandidate where
majors : Array FVarId
deriving Hashable, BEq

/-- `Set` with insertion order preserved. -/
structure OrdSet (α : Type) [Hashable α] [BEq α] where
elems : Array α := #[]
set : Std.HashSet α := {}
deriving Inhabited

def OrdSet.insert {_ : Hashable α} {_ : BEq α} (s : OrdSet α) (a : α) : OrdSet α :=
if s.set.contains a then
s
else
let { elems, set } := s
{ elems := elems.push a, set := set.insert a }

def OrdSet.isEmpty {_ : Hashable α} {_ : BEq α} (s : OrdSet α) : Bool :=
s.elems.isEmpty

structure Result where
/-- All constant symbols occurring in the gal. -/
allConsts : Std.HashSet Name := {}
allConsts : OrdSet Name := {}
/-- Unfolding candiates. -/
unfoldCandidates : Std.HashSet Name := {}
unfoldCandidates : OrdSet Name := {}
/-- Equation function candiates. -/
eqnCandidates : Std.HashSet Name := {}
eqnCandidates : OrdSet Name := {}
/-- Function induction candidates. -/
funIndCandidates : Std.HashSet FunIndCandidate := {}
funIndCandidates : OrdSet FunIndCandidate := {}
/-- Induction candidates. -/
indCandidates : Array InductionCandidate := #[]
/-- Relevant declarations by `libSearch` -/
libSearchResults : Std.HashSet (Name × Grind.EMatchTheoremKind) := {}
libSearchResults : OrdSet (Name × Grind.EMatchTheoremKind) := {}

structure Context where
config : Try.Config
Expand Down
25 changes: 23 additions & 2 deletions tests/lean/run/grind_constProp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,15 @@ def evalExpr (e : Expr) : EvalM Val := do
@[grind] theorem UnaryOp.simplify_eval (op : UnaryOp) : (op.simplify a).eval σ = (Expr.una op a).eval σ := by
grind [UnaryOp.simplify.eq_def]

/-- info: Try this: (induction e using Expr.simplify.induct) <;> grind -/
/--
info: Try these:
• (induction e using Expr.simplify.induct) <;> grind
• ·
induction e using Expr.simplify.induct
· grind only [Expr.simplify, BinOp.simplify, Expr.eval, BinaryOp.simplify_eval]
· grind only [UnaryOp.simplify_eval, UnaryOp.simplify, Expr.simplify, Expr.eval]
· simp
-/
#guard_msgs (info) in
example (e : Expr) : e.simplify.eval σ = e.eval σ := by
try?
Expand Down Expand Up @@ -304,7 +312,20 @@ theorem State.cons_le_of_eq (h₁ : σ' ≼ σ) (h₂ : σ.find? x = some v) : (
@[grind] theorem State.join_le_left_of (h : σ₁ ≼ σ₂) (σ₃ : State) : σ₁.join σ₃ ≼ σ₂ := by
grind

/-- info: Try this: (induction σ₁, σ₂ using State.join.induct) <;> grind -/
/--
info: Try these:
• (induction σ₁, σ₂ using State.join.induct) <;> grind
• ·
induction σ₁, σ₂ using State.join.induct
·
grind only [State.join_le_left, State.find?, State.join, State.join_le_left_of, State.le, = State.find?_nil,
State.bot_le, State.le_refl]
·
grind only [State.join, State.join_le_left, State.length_erase_le, State.find?, State.join_le_left_of, State.le, =
State.find?_erase_eq, State.erase_le, State.le_refl, cases Or]
· grind only [State.join, State.join_le_left, State.length_erase_le, State.join_le_left_of, State.le, State.erase_le]
· grind only [State.join, State.join_le_left, State.length_erase_le, State.join_le_left_of, State.le, State.erase_le]
-/
#guard_msgs (info) in
example (σ₁ σ₂ : State) : σ₁.join σ₂ ≼ σ₂ := by
try?
Expand Down
17 changes: 17 additions & 0 deletions tests/lean/run/try_trace1.lean
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
set_option grind.warning false
%reset_grind_attrs

/--
info: Try these:
Expand Down Expand Up @@ -97,3 +98,19 @@ example : app (app as bs) cs = app as (app bs cs) := by
intro _ _ _
-- `as`, `bs`, and `cs` now have inaccessible names.
try?

def concat : List α → α → List α
| .nil, b => .cons b .nil
| .cons a as, b => .cons a (concat as b)

attribute [simp] concat

/--
info: Try this: ·
induction as, a using concat.induct
· rfl
· simp_all
-/
#guard_msgs (info) in
example (as : List α) (a : α) : concat as a = as ++ [a] := by
try?

0 comments on commit eab0908

Please sign in to comment.