Skip to content

Commit

Permalink
feat: add LawfulMonadLift (#1125)
Browse files Browse the repository at this point in the history
Co-authored-by: Kim Morrison <[email protected]>
  • Loading branch information
quangvdao and kim-em authored Feb 10, 2025
1 parent 45fe933 commit 309e155
Show file tree
Hide file tree
Showing 3 changed files with 264 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Batteries.lean
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import Batteries.CodeAction.Misc
import Batteries.Control.ForInStep
import Batteries.Control.ForInStep.Basic
import Batteries.Control.ForInStep.Lemmas
import Batteries.Control.Lawful.MonadLift
import Batteries.Control.Lemmas
import Batteries.Control.Monad
import Batteries.Control.Nondet.Basic
Expand Down Expand Up @@ -50,6 +51,7 @@ import Batteries.Lean.HashSet
import Batteries.Lean.IO.Process
import Batteries.Lean.Json
import Batteries.Lean.LawfulMonad
import Batteries.Lean.LawfulMonadLift
import Batteries.Lean.Meta.Basic
import Batteries.Lean.Meta.DiscrTree
import Batteries.Lean.Meta.Expr
Expand Down
207 changes: 207 additions & 0 deletions Batteries/Control/Lawful/MonadLift.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
/-
Copyright (c) 2025 Quang Dao. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Quang Dao
-/

/-!
# Lawful version of `MonadLift`
This file defines the `LawfulMonadLift(T)` class, which adds laws to the `MonadLift(T)` class.
-/

universe u v w

/-- The `MonadLift` typeclass only contains the lifting operation. `LawfulMonadLift` further
asserts that lifting commutes with `pure` and `bind`:
```
monadLift (pure a) = pure a
monadLift ma >>= monadLift ∘ f = monadLift (ma >>= f)
```
-/
class LawfulMonadLift (m : semiOutParam (Type u → Type v)) (n : Type u → Type w)
[Monad m] [Monad n] [inst : MonadLift m n] : Prop where
/-- Lifting preserves `pure` -/
monadLift_pure {α : Type u} (a : α) : inst.monadLift (pure a) = pure a
/-- Lifting preserves `bind` -/
monadLift_bind {α β : Type u} (ma : m α) (f : α → m β) :
inst.monadLift ma >>= (fun x => inst.monadLift (f x)) = inst.monadLift (ma >>= f)

/-- The `MonadLiftT` typeclass only contains the transitive lifting operation.
`LawfulMonadLiftT` further asserts that lifting commutes with `pure` and `bind`:
```
monadLift (pure a) = pure a
monadLift ma >>= monadLift ∘ f = monadLift (ma >>= f)
```
-/
class LawfulMonadLiftT (m : Type u → Type v) (n : Type u → Type w) [Monad m] [Monad n]
[inst : MonadLiftT m n] : Prop where
/-- Lifting preserves `pure` -/
monadLift_pure {α : Type u} (a : α) : inst.monadLift (pure a) = pure a
/-- Lifting preserves `bind` -/
monadLift_bind {α β : Type u} (ma : m α) (f : α → m β) :
monadLift ma >>= (fun x => monadLift (f x)) = inst.monadLift (ma >>= f)

export LawfulMonadLiftT (monadLift_pure monadLift_bind)

section Lemmas

variable [Monad m] [Monad n] [MonadLiftT m n] [LawfulMonadLiftT m n]

@[simp]
theorem liftM_pure (a : α) : liftM (pure a : m α) = pure (f := n) a :=
monadLift_pure _

@[simp]
theorem liftM_bind (ma : m α) (f : α → m β) :
liftM ma >>= (fun a => liftM (f a)) = liftM (n := n) (ma >>= f) :=
monadLift_bind _ _

@[simp]
theorem liftM_map [LawfulMonad m] [LawfulMonad n] (f : α → β) (ma : m α) :
f <$> (liftM ma : n α) = liftM (f <$> ma) := by
rw [← bind_pure_comp, ← bind_pure_comp, ← liftM_bind]
simp only [bind_pure_comp, liftM_pure]

@[simp]
theorem liftM_seq [LawfulMonad m] [LawfulMonad n] (mf : m (α → β)) (ma : m α) :
liftM mf <*> (liftM ma : n α) = liftM (mf <*> ma) := by
simp only [seq_eq_bind, liftM_map, liftM_bind]

@[simp]
theorem liftM_seqLeft [LawfulMonad m] [LawfulMonad n] (x : m α) (y : m β) :
(liftM x : n α) <* (liftM y : n β) = liftM (x <* y) := by
simp only [seqLeft_eq, liftM_map, liftM_seq]

@[simp]
theorem liftM_seqRight [LawfulMonad m] [LawfulMonad n] (x : m α) (y : m β) :
(liftM x : n α) *> (liftM y : n β) = liftM (x *> y) := by
simp only [seqRight_eq, liftM_map, liftM_seq]

end Lemmas

instance (m n o) [Monad m] [Monad n] [Monad o] [MonadLift n o] [MonadLiftT m n]
[LawfulMonadLift n o] [LawfulMonadLiftT m n] : LawfulMonadLiftT m o where
monadLift_pure := fun a => by
simp only [monadLift, LawfulMonadLift.monadLift_pure, liftM_pure]
monadLift_bind := fun ma f => by
simp only [monadLift, LawfulMonadLift.monadLift_bind, liftM_bind]

instance (m) [Monad m] : LawfulMonadLiftT m m where
monadLift_pure _ := rfl
monadLift_bind _ _ := rfl

namespace StateT

variable [Monad m] [LawfulMonad m]

instance : LawfulMonadLift m (StateT σ m) where
monadLift_pure _ := by
simp only [MonadLift.monadLift]
unfold StateT.lift StateT.instMonad StateT.bind StateT.pure
simp only [bind_pure_comp, map_pure]
monadLift_bind _ _ := by
simp only [MonadLift.monadLift]
unfold StateT.lift StateT.instMonad StateT.bind StateT.pure
simp only [bind_pure_comp, Function.comp_apply, bind_map_left, map_bind]

end StateT

namespace ReaderT

variable [Monad m]

instance : LawfulMonadLift m (ReaderT ρ m) where
monadLift_pure _ := by
funext x
simp only [MonadLift.monadLift, pure, ReaderT.pure]
monadLift_bind _ _ := by
funext x
simp only [bind, ReaderT.bind, MonadLift.monadLift, Function.comp_apply]

end ReaderT

namespace OptionT

variable [Monad m] [LawfulMonad m]

@[simp]
theorem lift_pure (a : α) : OptionT.lift (pure a : m α) = pure a := by
simp only [OptionT.lift, OptionT.mk, bind_pure_comp, map_pure, pure, OptionT.pure]

@[simp]
theorem lift_bind (ma : m α) (f : α → m β) :
OptionT.lift ma >>= (fun a => OptionT.lift (f a)) = OptionT.lift (ma >>= f) := by
simp only [instMonad, OptionT.bind, OptionT.mk, OptionT.lift, bind_pure_comp, bind_map_left,
map_bind]

instance : LawfulMonadLift m (OptionT m) where
monadLift_pure := lift_pure
monadLift_bind := lift_bind

end OptionT

namespace ExceptT

variable [Monad m] [LawfulMonad m]

@[simp]
theorem lift_bind (ma : m α) (f : α → m β) :
ExceptT.lift ma >>= (fun a => ExceptT.lift (f a)) = ExceptT.lift (ε := ε) (ma >>= f) := by
simp only [instMonad, ExceptT.bind, mk, ExceptT.lift, bind_map_left, ExceptT.bindCont, map_bind]

instance : LawfulMonadLift m (ExceptT ε m) where
monadLift_pure := lift_pure
monadLift_bind := lift_bind

instance : LawfulMonadLift (Except ε) (ExceptT ε m) where
monadLift_pure _ := by
simp only [MonadLift.monadLift, mk, pure, Except.pure, ExceptT.pure]
monadLift_bind ma _ := by
simp only [instMonad, ExceptT.bind, mk, MonadLift.monadLift, pure_bind, ExceptT.bindCont,
Except.instMonad, Except.bind]
rcases ma with _ | _ <;> simp

end ExceptT

namespace StateRefT'

instance [Monad m] : LawfulMonadLift m (StateRefT' ω σ m) where
monadLift_pure _ := by
simp only [MonadLift.monadLift, pure]
unfold StateRefT'.lift ReaderT.pure
simp only
monadLift_bind _ _ := by
simp only [MonadLift.monadLift, bind]
unfold StateRefT'.lift ReaderT.bind
simp only

end StateRefT'

namespace StateCpsT

instance [Monad m] [LawfulMonad m] : LawfulMonadLift m (StateCpsT σ m) where
monadLift_pure _ := by
simp only [MonadLift.monadLift, pure]
unfold StateCpsT.lift
simp only [pure_bind]
monadLift_bind _ _ := by
simp only [MonadLift.monadLift, bind]
unfold StateCpsT.lift
simp only [bind_assoc]

end StateCpsT

namespace ExceptCpsT

instance [Monad m] [LawfulMonad m] : LawfulMonadLift m (ExceptCpsT ε m) where
monadLift_pure _ := by
simp [MonadLift.monadLift, pure]
unfold ExceptCpsT.lift
simp only [pure_bind]
monadLift_bind _ _ := by
simp only [MonadLift.monadLift, bind]
unfold ExceptCpsT.lift
simp only [bind_assoc]

end ExceptCpsT
55 changes: 55 additions & 0 deletions Batteries/Lean/LawfulMonadLift.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/-
Copyright (c) 2025 Quang Dao. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Quang Dao
-/
import Batteries.Control.Lawful.MonadLift
import Batteries.Lean.LawfulMonad

/-!
# Lawful instances of `MonadLift` for the Lean monad stack.
-/

open Lean Elab Term Tactic Command

instance : LawfulMonadLift BaseIO (EIO ε) where
monadLift_pure _ := rfl
monadLift_bind ma f := by
simp only [MonadLift.monadLift, bind]
unfold BaseIO.toEIO EStateM.bind IO.RealWorld
simp only
funext x
rcases ma x with a | a
· simp only
· contradiction

instance : LawfulMonadLift (ST σ) (EST ε σ) where
monadLift_pure a := rfl
monadLift_bind ma f := by
simp only [MonadLift.monadLift, bind]
unfold EStateM.bind
simp only
funext x
rcases ma x with a | a
· simp only
· contradiction

instance : LawfulMonadLift IO CoreM where
monadLift_pure _ := rfl
monadLift_bind ma f := by
simp only [MonadLift.monadLift, bind, pure, Core.liftIOCore, liftM, monadLift, getRef, read,
readThe, MonadReaderOf.read, IO.toEIO]
unfold StateRefT'.lift ReaderT.read ReaderT.bind ReaderT.pure
simp only [pure_bind, Function.comp_apply, bind_assoc, bind, pure]
unfold ReaderT.bind ReaderT.pure
simp only [bind_pure_comp, map_pure, pure_bind, bind, pure]
unfold EStateM.adaptExcept EStateM.bind EStateM.pure
simp only
funext _ _ s
rcases ma s with a | a <;> simp only

instance : LawfulMonadLiftT (EIO Exception) CommandElabM := inferInstance
instance : LawfulMonadLiftT (EIO Exception) CoreM := inferInstance
instance : LawfulMonadLiftT CoreM MetaM := inferInstance
instance : LawfulMonadLiftT MetaM TermElabM := inferInstance
instance : LawfulMonadLiftT TermElabM TacticM := inferInstance

0 comments on commit 309e155

Please sign in to comment.