Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: try? validation and cleanup #6980

Merged
merged 4 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 72 additions & 65 deletions src/Lean/Elab/Tactic/Try.lean
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ A **very** simple `try?` tactic implementation.

declare_config_elab elabTryConfig Try.Config

namespace Try

/-!
`evalSuggest` is a `evalTactic` variant that returns suggestions after executing a tactic built using
combinatiors such as `first`, `attempt_all`, `<;>`, `;`, and `try`.
Expand All @@ -36,14 +38,6 @@ private def getSuggestionOfTactic (tac : TSyntax `tactic) : Array (TSyntax `tact
private def appendSuggestion (suggestions : Array (TSyntax `tactic)) (tac : TSyntax `tactic) : Array (TSyntax `tactic) :=
suggestions ++ getSuggestionOfTactic tac

/--
Given the suggestion sequecences `suggestionsSeqs`, extends each sequence using `tac`.
-/
private def appendSeqResult (suggestionSeqs : Array (Array (TSyntax `tactic))) (tac : TSyntax `tactic) : Array (Array (TSyntax `tactic)) :=
match tac with
| `(tactic| try_suggestions $tacs:tactic*) => suggestionSeqs.foldl (init := #[]) fun result seq => result ++ tacs.map (seq.push ·)
| _ => suggestionSeqs.map (·.push tac)

/-- Returns a tactic representing all given suggestions `tacs`. -/
private def mkTrySuggestions (tacs : Array (TSyntax `tactic)) : TacticM (TSyntax `tactic) := do
if tacs.isEmpty then
Expand All @@ -59,19 +53,16 @@ private def filterSkipDone (tacs : Array (TSyntax `tactic)) : Array (TSyntax `ta
| `(tactic| done) | `(tactic| skip) => false
| _ => true

/--
Returns a tactic representing the given suggestions.
-/
private def mkSeqResult (suggestionSeqs : Array (Array (TSyntax `tactic))) : TacticM (TSyntax `tactic) := do
let tacs ← suggestionSeqs.mapM fun tacs => do
let tacs := filterSkipDone tacs
if tacs.size = 0 then
`(tactic| done)
else if tacs.size = 1 then
return tacs[0]!
else
`(tactic| · $tacs;*)
mkTrySuggestions tacs
private def mkSeq (tacs : Array (TSyntax `tactic)) (terminal : Bool) : CoreM (TSyntax `tactic) := do
let tacs := filterSkipDone tacs
if tacs.size = 0 then
if terminal then `(tactic| done) else `(tactic| skip)
else if tacs.size = 1 then
return tacs[0]!
else if terminal then
`(tactic| · $tacs;*)
else
`(tactic| ($tacs;*))

/-- Returns `true` if `tac` is `sorry` -/
private def isSorry (tac : TSyntax `tactic) : Bool :=
Expand Down Expand Up @@ -135,7 +126,7 @@ private def peekOne (tac1 : TSyntax `tactic) (tacss2 : Array (Array (TSyntax `ta
`(tactic| · $tac1:tactic
$tacs2*)

private def mkChainResultCore (tac1 : TSyntax `tactic) (tacss2 : Array (TSyntax `tactic)) : TacticM (Array (TSyntax `tactic)) := do
private def mkChainResult (tac1 : TSyntax `tactic) (tacss2 : Array (TSyntax `tactic)) : TacticM (TSyntax `tactic) := do
let tacss2 := tacss2.map getSuggestionsCore
if (← isTracingEnabledFor `try.debug) then
trace[try.debug] "mkChainResultCore tac1{indentD tac1}"
Expand All @@ -157,14 +148,7 @@ private def mkChainResultCore (tac1 : TSyntax `tactic) (tacss2 : Array (TSyntax
-- We only include partial solutions if there are no other solutions.
|| (acc.isEmpty && tacss2.any fun s => !s.isEmpty) then
acc := acc.push <| (← peekOne tac1 tacss2)
return acc

private def mkChainResult (tac1 : TSyntax `tactic) (tacs2 : Array (TSyntax `tactic)) : TacticM (TSyntax `tactic) := do
match tac1 with
| `(tactic| try_suggestions $tacs1:tactic*) =>
let tacs ← tacs1.foldlM (init := #[]) fun acc tac1 => return acc ++ (← mkChainResultCore tac1 tacs2)
mkTrySuggestions tacs
| _ => mkTrySuggestions (← mkChainResultCore tac1 tacs2)
mkTrySuggestions acc

private def evalSuggestAtomic (tac : TSyntax `tactic) : TacticM (TSyntax `tactic) := do
let goals ← getGoals
Expand Down Expand Up @@ -212,7 +196,24 @@ private def evalSuggestSimpTrace (tac : TSyntax `tactic) : TacticM (TSyntax `tac

abbrev TacticResult (α : Type) := EStateM.Result Exception SavedState α

def observing (x : TacticM α) : TacticM (TacticResult α) := do
structure Ctx where
root : TSyntax `tactic
terminal : Bool
config : Try.Config

abbrev M := ReaderT Ctx TacticM

instance : MonadBacktrack SavedState M where
saveState := fun _ => saveState
restoreState s := fun _ => restoreState s

abbrev withNonTerminal (x : M α) : M α :=
withReader (fun c => { c with terminal := false}) x

-- TODO: polymorphic `Tactic.focus`
abbrev focus (x : M α) : M α := fun ctx => Tactic.focus (x ctx)

def observing (x : M α) : M (TacticResult α) := do
let s ← saveState
try
let e ← x
Expand All @@ -226,11 +227,13 @@ def observing (x : TacticM α) : TacticM (TacticResult α) := do
return .error ex sNew

@[extern "lean_eval_suggest_tactic"] -- forward definition to avoid mutual block
opaque evalSuggest (tac : TSyntax `tactic) : TacticM (TSyntax `tactic)
opaque evalSuggest (tac : TSyntax `tactic) : M (TSyntax `tactic)

/-- `evalSuggest` for `tac1 <;> tac2` -/
private def evalSuggestChain (tac1 tac2 : TSyntax `tactic) : TacticM (TSyntax `tactic) := focus do
let tac1 ← evalSuggest tac1
private def evalSuggestChain (tac1 tac2 : TSyntax `tactic) : M (TSyntax `tactic) := focus do
unless (← read).terminal do
throwError "invalid `<;>` occurrence in non-terminal position for `try?` script{indentD (← read).root}"
let tac1 ← withNonTerminal do evalSuggest tac1
let goals ← getGoals
setGoals []
let mut tac2s := #[]
Expand All @@ -239,56 +242,56 @@ private def evalSuggestChain (tac1 tac2 : TSyntax `tactic) : TacticM (TSyntax `t
setGoals [goal]
let tac2' : TSyntax `tactic ← (evalSuggest tac2) <|> `(tactic| sorry)
i := i + 1
trace[try.debug] "`<;>` goal #{i}, tactic{indentD tac2'}"
unless (← getGoals).isEmpty do
throwError "unsolved goals, `<;>` in `try?` requires all goals to be solved{indentD tac2}\n{goalsToMessageData (← getGoals)}"
tac2s := tac2s.push tac2'
if tac2s.all isSorry then
throwError "`<;>` failed"
mkChainResult tac1 tac2s

/-- `evalSuggest` for a sequence of tactics. -/
private def evalSuggestSeq (tacs : Array (TSyntax `tactic)) : TacticM (TSyntax `tactic) := do
go 0 #[#[]]
where
go (i : Nat) (accs : Array (Array (TSyntax `tactic))) : TacticM (TSyntax `tactic) := do
if i < tacs.size then
let tac' ← evalSuggest tacs[i]!
go (i+1) (appendSeqResult accs tac')
else
mkSeqResult accs
private def evalSuggestSeq (tacs : Array (TSyntax `tactic)) : M (TSyntax `tactic) := do
if (← read).terminal then
let mut result := #[]
for i in [:tacs.size - 1] do
result := result.push (← withNonTerminal <| evalSuggest tacs[i]!)
let suggestions ← getSuggestionOfTactic (← evalSuggest tacs.back!) |>.mapM fun tac =>
mkSeq (result.push tac) (terminal := true)
mkTrySuggestions suggestions
else
mkSeq (← tacs.mapM evalSuggest) (terminal := false)

private def evalSuggestSeqCore (tacs : Array Syntax) : TacticM (TSyntax `tactic) := do
private def evalSuggestSeqCore (tacs : Array Syntax) : M (TSyntax `tactic) := do
evalSuggestSeq (tacs.map fun tac => ⟨tac⟩)

private def evalSuggestTacticSeq (s : TSyntax ``Parser.Tactic.tacticSeq) : TacticM (TSyntax `tactic) := do
private def evalSuggestTacticSeq (s : TSyntax ``Parser.Tactic.tacticSeq) : M (TSyntax `tactic) := do
let tacs ← match s with
| `(tacticSeq| { $t;* }) => pure t.getElems
| `(tacticSeq| $t;*) => pure t.getElems
| _ => throwError "unexpeted sequence"
evalSuggestSeq tacs

/-- `evalSuggest` for `first` tactic. -/
private partial def evalSuggestFirst (tacs : Array (TSyntax ``Parser.Tactic.tacticSeq)) : TacticM (TSyntax `tactic) := do
private partial def evalSuggestFirst (tacs : Array (TSyntax ``Parser.Tactic.tacticSeq)) : M (TSyntax `tactic) := do
go 0
where
go (i : Nat) : TacticM (TSyntax `tactic) := do
go (i : Nat) : M (TSyntax `tactic) := do
if i = tacs.size - 1 then
evalSuggestTacticSeq tacs[i]!
else
evalSuggestTacticSeq tacs[i]! <|> go (i+1)

/-- `evalSuggest` for `try` tactic. -/
private partial def evalSuggestTry (tac : TSyntax ``Parser.Tactic.tacticSeq) : TacticM (TSyntax `tactic) := do
private partial def evalSuggestTry (tac : TSyntax ``Parser.Tactic.tacticSeq) : M (TSyntax `tactic) := do
(do evalSuggestTacticSeq tac)
<|>
`(tactic| skip)

/-- `evalSuggest` for `attempt_all` tactic. -/
private partial def evalSuggestAttemptAll (tacs : Array (TSyntax ``Parser.Tactic.tacticSeq)) : TacticM (TSyntax `tactic) := do
private partial def evalSuggestAttemptAll (tacs : Array (TSyntax ``Parser.Tactic.tacticSeq)) : M (TSyntax `tactic) := do
unless (← read).terminal do
throwError "invalid occurrence of `attempt_all` in non-terminal position for `try?` script{indentD (← read).root}"
go 0 none #[]
where
go (i : Nat) (saved? : Option SavedState) (acc : Array (TSyntax `tactic)) : TacticM (TSyntax `tactic) := do
go (i : Nat) (saved? : Option SavedState) (acc : Array (TSyntax `tactic)) : M (TSyntax `tactic) := do
if i < tacs.size then
match (← observing (evalSuggestTacticSeq tacs[i]!)) with
| .ok tac s =>
Expand All @@ -305,8 +308,7 @@ where

-- `evalSuggest` implementation
@[export lean_eval_suggest_tactic]
private partial def evalSuggestImpl (tac : TSyntax `tactic) : TacticM (TSyntax `tactic) := do
trace[try.debug] "{tac}"
private partial def evalSuggestImpl (tac : TSyntax `tactic) : M (TSyntax `tactic) := do
match tac with
| `(tactic| $tac1 <;> $tac2) => evalSuggestChain tac1 tac2
| `(tactic| first $[| $tacs]*) => evalSuggestFirst tacs
Expand All @@ -317,12 +319,17 @@ private partial def evalSuggestImpl (tac : TSyntax `tactic) : TacticM (TSyntax `
let k := tac.raw.getKind
if k == ``Parser.Tactic.seq1 then
evalSuggestSeqCore tac.raw[0].getSepArgs
else if k == ``Parser.Tactic.grindTrace then
evalSuggestGrindTrace tac
else if k == ``Parser.Tactic.simpTrace then
evalSuggestSimpTrace tac
else
evalSuggestAtomic tac
let r ← if k == ``Parser.Tactic.grindTrace then
evalSuggestGrindTrace tac
else if k == ``Parser.Tactic.simpTrace then
evalSuggestSimpTrace tac
else
evalSuggestAtomic tac
if (← read).terminal then
unless (← getGoals).isEmpty do
throwError "unsolved goals"
return r

private def toSuggestion (t : TSyntax `tactic) : Tactic.TryThis.Suggestion :=
t
Expand All @@ -341,8 +348,8 @@ private def addSuggestions (tk : Syntax) (s : Array Tactic.TryThis.Suggestion) :
else
Tactic.TryThis.addSuggestions tk (s.map fun stx => stx) (origSpan? := (← getRef))

def evalAndSuggest (tk : Syntax) (tac : TSyntax `tactic) : TacticM Unit := do
let tac' ← evalSuggest tac
def evalAndSuggest (tk : Syntax) (tac : TSyntax `tactic) (config : Try.Config := {}) : TacticM Unit := do
let tac' ← evalSuggest tac |>.run { terminal := true, root := tac, config }
let s := getSuggestions tac'
if s.isEmpty then
throwEvalAndSuggestFailed
Expand Down Expand Up @@ -447,11 +454,11 @@ private def mkTryEvalSuggestStx (info : Try.Info) : MetaM (TSyntax `tactic) := d

@[builtin_tactic Lean.Parser.Tactic.tryTrace] def evalTryTrace : Tactic := fun stx => do
match stx with
| `(tactic| try?%$tk $config:optConfig) => focus do withMainContext do
| `(tactic| try?%$tk $config:optConfig) => Tactic.focus do withMainContext do
let config ← elabTryConfig config
let info ← Try.collect (← getMainGoal) config
let stx ← mkTryEvalSuggestStx info
evalAndSuggest tk stx
evalAndSuggest tk stx config
| _ => throwUnsupportedSyntax

end Lean.Elab.Tactic
end Lean.Elab.Tactic.Try
41 changes: 41 additions & 0 deletions tests/lean/run/eval_suggest1.lean
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,44 @@ info: Try these:
#guard_msgs (info) in
example (h : 0 + x = y) : f x = f y := by
try_simple?


macro "bad_tac" : tactic => `(tactic| eval_suggest (intros; (attempt_all | rfl | grind?); simp))

/--
error: invalid occurrence of `attempt_all` in non-terminal position for `try?` script
(intros;
(attempt_all
| rfl
| grind?);
simp)
-/
#guard_msgs (error) in
example : True := by
bad_tac

macro "simple_tac" : tactic => `(tactic| eval_suggest (intros; skip; first | skip | simp))

/--
info: Try this: simp
-/
#guard_msgs (info) in
example : True ∧ True := by
simple_tac -- terminal `skip` should not succeed

example : False := by
fail_if_success simple_tac -- should not succeed
sorry

set_option hygiene false in
macro "simple_tac2" : tactic => `(tactic| eval_suggest (intros; (simp only [Nat.zero_add]; simp only [Nat.one_mul]); simp [*]))

/--
info: Try this: · intros; (simp only [Nat.zero_add]; simp only [Nat.one_mul]); simp [*]
-/
#guard_msgs (info) in
example : x = 0 → 0 + 1*x = 0 := by
simple_tac2

example : x = 0 → 0 + 1*x = 0 := by
· intros; (simp only [Nat.zero_add]; simp only [Nat.one_mul]); simp [*]
16 changes: 16 additions & 0 deletions tests/lean/run/try_trace1.lean
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,19 @@ info: Try this: ·
#guard_msgs (info) in
example (as : List α) (a : α) : concat as a = as ++ [a] := by
try?

def foo : Nat → Nat
| 0 => 1
| x+1 => foo x - 1


/--
info: Try this: ·
induction x using foo.induct
· grind [= foo]
· sorry
-/
#guard_msgs (info) in
example : foo x > 0 := by
try? -- `try?` does not solve all subgoals.
sorry
Loading