Skip to content

Commit

Permalink
feat: try? to use fun_induction (#7082)
Browse files Browse the repository at this point in the history
This PR makes `try?` use `fun_induction` instead of `induction … using
foo.induct`. It uses the argument-free short-hand `fun_induction foo` if
that is unambiguous. Avoids `expose_names` if not necessary by simply
trying without first.
  • Loading branch information
nomeata authored Feb 18, 2025
1 parent 2d4c001 commit 2fed934
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 102 deletions.
39 changes: 24 additions & 15 deletions src/Lean/Elab/Tactic/Try.lean
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ private def isAccessible (fvarId : FVarId) : MetaM Bool := do
| return false
return localDecl'.fvarId == localDecl.fvarId

/-- Returns `true` if all free variables occurring in `e` are accessible. -/
/--
Returns `true` if all free variables occurring in `e` are accessible. Over-approximation, since
the free variable may be implicit.
-/
private def isExprAccessible (e : Expr) : MetaM Bool := do
let (_, s) ← e.collectFVars |>.run {}
s.fvarIds.allM isAccessible
Expand Down Expand Up @@ -598,22 +601,28 @@ private def mkSimpleTacStx : CoreM (TSyntax `tactic) :=
/-! Function induction generators -/

open Try.Collector in
private def mkFunIndStx (c : FunIndCandidate) (cont : TSyntax `tactic) : MetaM (TSyntax `tactic) := do
if (← c.majors.allM isAccessible) then
go
else withExposedNames do
`(tactic| (expose_names; $(← go):tactic))
where
go : MetaM (TSyntax `tactic) := do
let mut terms := #[]
for major in c.majors do
let localDecl ← major.getDecl
terms := terms.push (← `(Parser.Tactic.elimTarget| $(mkIdent localDecl.userName):term))
let indFn ← toIdent c.funIndDeclName
`(tactic| induction $terms,* using $indFn <;> $cont)
private def mkFunIndStx (uniques : NameSet) (expr : Expr) (cont : TSyntax `tactic) :
MetaM (TSyntax `tactic) := do
let fn := expr.getAppFn.constName!
if uniques.contains fn then
-- If it is unambigous, use `fun_induction foo` without arguments
`(tactic| fun_induction $(← toIdent fn):term <;> $cont)
else
let isAccessible ← isExprAccessible expr
withExposedNames do
let stx ← PrettyPrinter.delab expr
let tac₁ ← `(tactic| fun_induction $stx <;> $cont)
-- if expr has no inaccessible names, use as is
if isAccessible then
pure tac₁
else
-- if it has inaccessible names, still try without, in case they are all implicit
let tac₂ ← `(tactic| (expose_names; $tac₁))
mkFirstStx #[tac₁, tac₂]

private def mkAllFunIndStx (info : Try.Info) (cont : TSyntax `tactic) : MetaM (TSyntax `tactic) := do
let tacs ← info.funIndCandidates.elems.mapM (mkFunIndStx · cont)
let uniques := info.funIndCandidates.uniques
let tacs ← info.funIndCandidates.calls.mapM (mkFunIndStx uniques · cont)
mkFirstStx tacs

/-! Main code -/
Expand Down
27 changes: 23 additions & 4 deletions src/Lean/Meta/Tactic/FunIndCollect.lean
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@ structure Call where
structure SeenCalls where
/-- the full calls -/
calls : Array Expr
/-- only relevant arguments -/
seen : Std.HashSet (Array Expr)
/-- only function name and relevant arguments -/
seen : Std.HashSet (Name × Array Expr)

instance : EmptyCollection SeenCalls where
emptyCollection := ⟨#[], {}⟩

def SeenCalls.isEmpty (sc : SeenCalls) : Bool :=
sc.calls.isEmpty

def SeenCalls.push (e : Expr) (declName : Name) (args : Array Expr) (calls : SeenCalls) :
MetaM SeenCalls := do
let some funIndInfo ← getFunIndInfo? (cases := false) declName | return calls
Expand All @@ -41,8 +44,24 @@ def SeenCalls.push (e : Expr) (declName : Name) (args : Array Expr) (calls : See
if !arg.isFVar then return calls
unless kind matches .dropped do
keys := keys.push arg
if calls.seen.contains keys then return calls
return { calls := calls.calls.push e, seen := calls.seen.insert keys }
let key := (declName, keys)
if calls.seen.contains key then return calls
return { calls := calls.calls.push e, seen := calls.seen.insert key }

/--
Which functions have exactly one candidate application. Used by `try?` to determine whether
we can use `fun_induction foo` or need `fun_induction foo x y z`.
-/
def SeenCalls.uniques (calls : SeenCalls) : NameSet := Id.run do
let mut seen : NameSet := {}
let mut seenTwice : NameSet := {}
for (n, _) in calls.seen do
unless seenTwice.contains n do
if seen.contains n then
seenTwice := seenTwice.insert n
else
seen := seen.insert n
return seen.filter (! seenTwice.contains ·)

namespace Collector

Expand Down
69 changes: 9 additions & 60 deletions src/Lean/Meta/Tactic/Try/Collect.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,15 @@ import Lean.Meta.Tactic.LibrarySearch
import Lean.Meta.Tactic.Util
import Lean.Meta.Tactic.Grind.Cases
import Lean.Meta.Tactic.Grind.EMatchTheorem
import Lean.Meta.Tactic.FunIndInfo
import Lean.Meta.Tactic.FunIndCollect

namespace Lean.Meta.Try.Collector

structure InductionCandidate where
fvarId : FVarId
val : InductiveVal

structure FunIndCandidate where
funIndDeclName : Name
majors : Array FVarId
deriving Hashable, BEq

/-- `Set` with insertion order preserved. -/
structure OrdSet (α : Type) [Hashable α] [BEq α] where
elems : Array α := #[]
Expand All @@ -44,8 +41,8 @@ structure Result where
unfoldCandidates : OrdSet Name := {}
/-- Equation function candiates. -/
eqnCandidates : OrdSet Name := {}
/-- Function induction candidates. -/
funIndCandidates : OrdSet FunIndCandidate := {}
/-- Function induction candidates -/
funIndCandidates : FunInd.SeenCalls := {}
/-- Induction candidates. -/
indCandidates : Array InductionCandidate := #[]
/-- Relevant declarations by `libSearch` -/
Expand All @@ -66,17 +63,6 @@ def saveConst (declName : Name) : M Unit := do
def inCurrentModule (declName : Name) : CoreM Bool := do
return ((← getEnv).getModuleIdxFor? declName).isNone

def getFunInductName (declName : Name) : Name :=
declName ++ `induct

def getFunInduct? (declName : Name) : MetaM (Option Name) := do
let .defnInfo _ ← getConstInfo declName | return none
try
let result ← realizeGlobalConstNoOverloadCore (getFunInductName declName)
return some result
catch _ =>
return none

def isEligible (declName : Name) : M Bool := do
if declName.hasMacroScopes then
return false
Expand Down Expand Up @@ -112,49 +98,11 @@ def visitConst (declName : Name) : M Unit := do
saveConst declName
saveUnfoldCandidate declName

-- Horrible temporary hack: compute the mask assuming parameters appear before a variable named `motive`
-- It assumes major premises appear after variables with name `case?`
-- It assumes if something is not a parameter, then it is major :(
-- TODO: save the mask while generating the induction principle.
def getFunIndMask? (declName : Name) (indDeclName : Name) : MetaM (Option (Array Bool)) := do
let info ← getConstInfo declName
let indInfo ← getConstInfo indDeclName
let (numParams, numMajor) ← forallTelescope indInfo.type fun xs _ => do
let mut foundCase := false
let mut foundMotive := false
let mut numParams : Nat := 0
let mut numMajor : Nat := 0
for x in xs do
let localDecl ← x.fvarId!.getDecl
let n := localDecl.userName
if n == `motive then
foundMotive := true
else if !foundMotive then
numParams := numParams + 1
else if n.isStr && "case".isPrefixOf n.getString! then
foundCase := true
else if foundCase then
numMajor := numMajor + 1
return (numParams, numMajor)
if numMajor == 0 then return none
forallTelescope info.type fun xs _ => do
if xs.size != numParams + numMajor then
return none
return some (mkArray numParams false ++ mkArray numMajor true)

def saveFunInd (_e : Expr) (declName : Name) (args : Array Expr) : M Unit := do
def saveFunInd (e : Expr) (declName : Name) (args : Array Expr) : M Unit := do
if (← isEligible declName) then
let some funIndDeclName ← getFunInduct? declName
| saveUnfoldCandidate declName; return ()
let some mask ← getFunIndMask? declName funIndDeclName | return ()
if mask.size != args.size then return ()
let mut majors := #[]
for arg in args, isMajor in mask do
if isMajor then
if !arg.isFVar then return ()
majors := majors.push arg.fvarId!
trace[try.collect.funInd] "{funIndDeclName}, {majors.map mkFVar}"
modify fun s => { s with funIndCandidates := s.funIndCandidates.insert { majors, funIndDeclName }}
let sc := (← get).funIndCandidates
let sc' ← sc.push e declName args
modify fun s => { s with funIndCandidates := sc' }

open LibrarySearch in
def saveLibSearchCandidates (e : Expr) : M Unit := do
Expand All @@ -170,6 +118,7 @@ def saveLibSearchCandidates (e : Expr) : M Unit := do
def visitApp (e : Expr) (declName : Name) (args : Array Expr) : M Unit := do
saveEqnCandidate declName
saveFunInd e declName args
saveUnfoldCandidate declName
saveLibSearchCandidates e

def checkInductive (localDecl : LocalDecl) : M Unit := do
Expand Down
7 changes: 4 additions & 3 deletions tests/lean/run/grind_constProp.lean
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def evalExpr (e : Expr) : EvalM Val := do
@[grind] theorem UnaryOp.simplify_eval (op : UnaryOp) : (op.simplify a).eval σ = (Expr.una op a).eval σ := by
grind [UnaryOp.simplify.eq_def]

/-- info: Try this: (induction e using Expr.simplify.induct) <;> grind -/
/-- info: Try this: (fun_induction Expr.simplify) <;> grind -/
#guard_msgs (info) in
example (e : Expr) : e.simplify.eval σ = e.eval σ := by
try? (max := 1)
Expand Down Expand Up @@ -304,13 +304,14 @@ theorem State.cons_le_of_eq (h₁ : σ' ≼ σ) (h₂ : σ.find? x = some v) : (
@[grind] theorem State.join_le_left_of (h : σ₁ ≼ σ₂) (σ₃ : State) : σ₁.join σ₃ ≼ σ₂ := by
grind

/-- info: Try this: (induction σ₁, σ₂ using State.join.induct) <;> grind -/
/-- info: Try this: (fun_induction join) <;> grind -/
#guard_msgs (info) in
open State in
example (σ₁ σ₂ : State) : σ₁.join σ₂ ≼ σ₂ := by
try? (max := 1)

@[grind] theorem State.join_le_right (σ₁ σ₂ : State) : σ₁.join σ₂ ≼ σ₂ := by
induction σ₁, σ₂ using State.join.induct <;> grind
fun_induction join <;> grind

@[grind] theorem State.join_le_right_of (h : σ₁ ≼ σ₂) (σ₃ : State) : σ₃.join σ₁ ≼ σ₂ := by
grind
Expand Down
51 changes: 31 additions & 20 deletions tests/lean/run/grind_try_trace.lean
Original file line number Diff line number Diff line change
Expand Up @@ -80,24 +80,22 @@ example : app [a, b] [c] = [a, b, c] := by

/--
info: Try these:
• (induction as, bs using app.induct) <;> grind [= app]
• (induction as, bs using app.induct) <;> grind only [app]
• (fun_induction app as bs) <;> grind [= app]
• (fun_induction app as bs) <;> grind only [app]
-/
#guard_msgs (info) in
example : app (app as bs) cs = app as (app bs cs) := by
try?

/--
info: Try this: (induction as, bs using app.induct) <;> grind [= app]
-/
/-- info: Try this: (fun_induction app as bs) <;> grind [= app] -/
#guard_msgs (info) in
example : app (app as bs) cs = app as (app bs cs) := by
try? (max := 1)

/--
info: Try these:
• · expose_names; induction as, bs_1 using app.induct <;> grind [= app]
• · expose_names; induction as, bs_1 using app.induct <;> grind only [app]
• · expose_names; fun_induction app as bs_1 <;> grind [= app]
• · expose_names; fun_induction app as bs_1 <;> grind only [app]
-/
#guard_msgs (info) in
example : app (app as bs) cs = app as (app bs cs) := by
Expand All @@ -106,8 +104,8 @@ example : app (app as bs) cs = app as (app bs cs) := by

/--
info: Try these:
• · expose_names; induction as, bs using app.induct <;> grind [= app]
• · expose_names; induction as, bs using app.induct <;> grind only [app]
• · expose_names; fun_induction app as bs <;> grind [= app]
• · expose_names; fun_induction app as bs <;> grind only [app]
-/
#guard_msgs (info) in
example : app (app as bs) cs = app as (app bs cs) := by
Expand All @@ -124,34 +122,47 @@ attribute [simp] concat

/--
info: Try these:
• (induction as, a using concat.induct) <;> simp_all
• (induction as, a using concat.induct) <;> simp [*]
• (fun_induction concat) <;> simp_all
• (fun_induction concat) <;> simp [*]
-/
#guard_msgs (info) in
example (as : List α) (a : α) : concat as a = as ++ [a] := by
try? -only

/--
info: Try these:
• (induction as, a using concat.induct) <;> simp_all
• (fun_induction concat) <;> simp_all
• ·
induction as, a using concat.induct
fun_induction concat
· simp
· simp [*]
-/
#guard_msgs (info) in
example (as : List α) (a : α) : concat as a = as ++ [a] := by
try? -only -merge

def map (f : α → β) : List α → List β
| [] => []
| x::xs => f x :: map f xs

/--
info: Try these:
• (fun_induction map) <;> grind [= map]
• (fun_induction map) <;> grind only [map]
-/
#guard_msgs (info) in
theorem map_map (f : α → β) (g : β → γ) xs :
map g (map f xs) = map (fun x => g (f x)) xs := by
try? -- NB: Multiple calls to `xs.map`, but they differ only in ignore arguments


def foo : Nat → Nat
| 0 => 1
| x+1 => foo x - 1


/--
info: Try this: ·
induction x using foo.induct
fun_induction foo
· grind [= foo]
· sorry
-/
Expand All @@ -177,11 +188,11 @@ attribute [grind] List.length_reverse bla

/--
info: Try these:
• (induction xs, ys using bla.induct) <;> grind
• (induction xs, ys using bla.induct) <;> simp_all
• (induction xs, ys using bla.induct) <;> simp [*]
• (induction xs, ys using bla.induct) <;> simp only [bla, List.length_reverse, *]
• (induction xs, ys using bla.induct) <;> grind only [List.length_reverse, bla]
• (fun_induction bla) <;> grind
• (fun_induction bla) <;> simp_all
• (fun_induction bla) <;> simp [*]
• (fun_induction bla) <;> simp only [bla, List.length_reverse, *]
• (fun_induction bla) <;> grind only [List.length_reverse, bla]
-/
#guard_msgs (info) in
example : (bla xs ys).length = ys.length := by
Expand Down

0 comments on commit 2fed934

Please sign in to comment.