Compare commits

...

1 Commits

Author SHA1 Message Date
Kim Morrison
97671dddb9 feat: grind annotations for basic monad transformers 2025-09-03 07:47:49 +02:00

View File

@@ -22,23 +22,24 @@ open Function
namespace ExceptT
@[ext] theorem ext {x y : ExceptT ε m α} (h : x.run = y.run) : x = y := by
@[ext, grind ext] theorem ext {x y : ExceptT ε m α} (h : x.run = y.run) : x = y := by
simp [run] at h
assumption
@[simp] theorem run_pure [Monad m] (x : α) : run (pure x : ExceptT ε m α) = pure (Except.ok x) := rfl
@[simp, grind =] theorem run_pure [Monad m] (x : α) : run (pure x : ExceptT ε m α) = pure (Except.ok x) := rfl
@[simp] theorem run_lift [Monad.{u, v} m] (x : m α) : run (ExceptT.lift x : ExceptT ε m α) = (Except.ok <$> x : m (Except ε α)) := rfl
@[simp, grind =] theorem run_lift [Monad.{u, v} m] (x : m α) : run (ExceptT.lift x : ExceptT ε m α) = (Except.ok <$> x : m (Except ε α)) := rfl
@[simp] theorem run_throw [Monad m] : run (throw e : ExceptT ε m β) = pure (Except.error e) := rfl
@[simp, grind =] theorem run_throw [Monad m] : run (throw e : ExceptT ε m β) = pure (Except.error e) := rfl
@[simp] theorem run_bind_lift [Monad m] [LawfulMonad m] (x : m α) (f : α ExceptT ε m β) : run (ExceptT.lift x >>= f : ExceptT ε m β) = x >>= fun a => run (f a) := by
@[simp, grind =] theorem run_bind_lift [Monad m] [LawfulMonad m] (x : m α) (f : α ExceptT ε m β) : run (ExceptT.lift x >>= f : ExceptT ε m β) = x >>= fun a => run (f a) := by
simp [ExceptT.run, ExceptT.lift, bind, ExceptT.bind, ExceptT.mk, ExceptT.bindCont]
@[simp] theorem bind_throw [Monad m] [LawfulMonad m] (f : α ExceptT ε m β) : (throw e >>= f) = throw e := by
@[simp, grind =] theorem bind_throw [Monad m] [LawfulMonad m] (f : α ExceptT ε m β) : (throw e >>= f) = throw e := by
simp [throw, throwThe, MonadExceptOf.throw, bind, ExceptT.bind, ExceptT.bindCont, ExceptT.mk]
theorem run_bind [Monad m] (x : ExceptT ε m α)
@[grind =]
theorem run_bind [Monad m] (x : ExceptT ε m α) (f : α ExceptT ε m β)
: run (x >>= f : ExceptT ε m β)
=
run x >>= fun
@@ -46,10 +47,10 @@ theorem run_bind [Monad m] (x : ExceptT ε m α)
| Except.error e => pure (Except.error e) :=
rfl
@[simp] theorem lift_pure [Monad m] [LawfulMonad m] (a : α) : ExceptT.lift (pure a) = (pure a : ExceptT ε m α) := by
@[simp, grind =] theorem lift_pure [Monad m] [LawfulMonad m] (a : α) : ExceptT.lift (pure a) = (pure a : ExceptT ε m α) := by
simp [ExceptT.lift, pure, ExceptT.pure]
@[simp] theorem run_map [Monad m] [LawfulMonad m] (f : α β) (x : ExceptT ε m α)
@[simp, grind =] theorem run_map [Monad m] [LawfulMonad m] (f : α β) (x : ExceptT ε m α)
: (f <$> x).run = Except.map f <$> x.run := by
simp [Functor.map, ExceptT.map, bind_pure_comp]
apply bind_congr
@@ -113,28 +114,28 @@ instance : LawfulFunctor (Except ε) := inferInstance
namespace ReaderT
@[ext] theorem ext {x y : ReaderT ρ m α} (h : ctx, x.run ctx = y.run ctx) : x = y := by
@[ext, grind ext] theorem ext {x y : ReaderT ρ m α} (h : ctx, x.run ctx = y.run ctx) : x = y := by
simp [run] at h
exact funext h
@[simp] theorem run_pure [Monad m] (a : α) (ctx : ρ) : (pure a : ReaderT ρ m α).run ctx = pure a := rfl
@[simp, grind =] theorem run_pure [Monad m] (a : α) (ctx : ρ) : (pure a : ReaderT ρ m α).run ctx = pure a := rfl
@[simp] theorem run_bind [Monad m] (x : ReaderT ρ m α) (f : α ReaderT ρ m β) (ctx : ρ)
@[simp, grind =] theorem run_bind [Monad m] (x : ReaderT ρ m α) (f : α ReaderT ρ m β) (ctx : ρ)
: (x >>= f).run ctx = x.run ctx >>= λ a => (f a).run ctx := rfl
@[simp] theorem run_mapConst [Monad m] (a : α) (x : ReaderT ρ m β) (ctx : ρ)
@[simp, grind =] theorem run_mapConst [Monad m] (a : α) (x : ReaderT ρ m β) (ctx : ρ)
: (Functor.mapConst a x).run ctx = Functor.mapConst a (x.run ctx) := rfl
@[simp] theorem run_map [Monad m] (f : α β) (x : ReaderT ρ m α) (ctx : ρ)
@[simp, grind =] theorem run_map [Monad m] (f : α β) (x : ReaderT ρ m α) (ctx : ρ)
: (f <$> x).run ctx = f <$> x.run ctx := rfl
@[simp] theorem run_monadLift [MonadLiftT n m] (x : n α) (ctx : ρ)
@[simp, grind =] theorem run_monadLift [MonadLiftT n m] (x : n α) (ctx : ρ)
: (monadLift x : ReaderT ρ m α).run ctx = (monadLift x : m α) := rfl
@[simp] theorem run_monadMap [MonadFunctorT n m] (f : {β : Type u} n β n β) (x : ReaderT ρ m α) (ctx : ρ)
@[simp, grind =] theorem run_monadMap [MonadFunctorT n m] (f : {β : Type u} n β n β) (x : ReaderT ρ m α) (ctx : ρ)
: (monadMap @f x : ReaderT ρ m α).run ctx = monadMap @f (x.run ctx) := rfl
@[simp] theorem run_read [Monad m] (ctx : ρ) : (ReaderT.read : ReaderT ρ m ρ).run ctx = pure ctx := rfl
@[simp, grind =] theorem run_read [Monad m] (ctx : ρ) : (ReaderT.read : ReaderT ρ m ρ).run ctx = pure ctx := rfl
@[simp] theorem run_seq {α β : Type u} [Monad m] (f : ReaderT ρ m (α β)) (x : ReaderT ρ m α) (ctx : ρ)
: (f <*> x).run ctx = (f.run ctx <*> x.run ctx) := rfl
@@ -175,38 +176,39 @@ instance [Monad m] [LawfulMonad m] : LawfulMonad (StateRefT' ω σ m) :=
namespace StateT
@[ext] theorem ext {x y : StateT σ m α} (h : s, x.run s = y.run s) : x = y :=
@[ext, grind ext] theorem ext {x y : StateT σ m α} (h : s, x.run s = y.run s) : x = y :=
funext h
@[simp] theorem run'_eq [Monad m] (x : StateT σ m α) (s : σ) : run' x s = (·.1) <$> run x s :=
@[simp, grind =] theorem run'_eq [Monad m] (x : StateT σ m α) (s : σ) : run' x s = (·.1) <$> run x s :=
rfl
@[simp] theorem run_pure [Monad m] (a : α) (s : σ) : (pure a : StateT σ m α).run s = pure (a, s) := rfl
@[simp, grind =] theorem run_pure [Monad m] (a : α) (s : σ) : (pure a : StateT σ m α).run s = pure (a, s) := rfl
@[simp] theorem run_bind [Monad m] (x : StateT σ m α) (f : α StateT σ m β) (s : σ)
@[simp, grind =] theorem run_bind [Monad m] (x : StateT σ m α) (f : α StateT σ m β) (s : σ)
: (x >>= f).run s = x.run s >>= λ p => (f p.1).run p.2 := by
simp [bind, StateT.bind, run]
@[simp] theorem run_map {α β σ : Type u} [Monad m] [LawfulMonad m] (f : α β) (x : StateT σ m α) (s : σ) : (f <$> x).run s = (fun (p : α × σ) => (f p.1, p.2)) <$> x.run s := by
@[simp, grind =] theorem run_map {α β σ : Type u} [Monad m] [LawfulMonad m] (f : α β) (x : StateT σ m α) (s : σ) : (f <$> x).run s = (fun (p : α × σ) => (f p.1, p.2)) <$> x.run s := by
simp [Functor.map, StateT.map, run, bind_pure_comp]
@[simp] theorem run_get [Monad m] (s : σ) : (get : StateT σ m σ).run s = pure (s, s) := rfl
@[simp, grind =] theorem run_get [Monad m] (s : σ) : (get : StateT σ m σ).run s = pure (s, s) := rfl
@[simp] theorem run_set [Monad m] (s s' : σ) : (set s' : StateT σ m PUnit).run s = pure (, s') := rfl
@[simp, grind =] theorem run_set [Monad m] (s s' : σ) : (set s' : StateT σ m PUnit).run s = pure (, s') := rfl
@[simp] theorem run_modify [Monad m] (f : σ σ) (s : σ) : (modify f : StateT σ m PUnit).run s = pure (, f s) := rfl
@[simp, grind =] theorem run_modify [Monad m] (f : σ σ) (s : σ) : (modify f : StateT σ m PUnit).run s = pure (, f s) := rfl
@[simp] theorem run_modifyGet [Monad m] (f : σ α × σ) (s : σ) : (modifyGet f : StateT σ m α).run s = pure ((f s).1, (f s).2) := by
@[simp, grind =] theorem run_modifyGet [Monad m] (f : σ α × σ) (s : σ) : (modifyGet f : StateT σ m α).run s = pure ((f s).1, (f s).2) := by
simp [modifyGet, MonadStateOf.modifyGet, StateT.modifyGet, run]
@[simp] theorem run_lift {α σ : Type u} [Monad m] (x : m α) (s : σ) : (StateT.lift x : StateT σ m α).run s = x >>= fun a => pure (a, s) := rfl
@[simp, grind =] theorem run_lift {α σ : Type u} [Monad m] (x : m α) (s : σ) : (StateT.lift x : StateT σ m α).run s = x >>= fun a => pure (a, s) := rfl
@[grind =]
theorem run_bind_lift {α σ : Type u} [Monad m] [LawfulMonad m] (x : m α) (f : α StateT σ m β) (s : σ) : (StateT.lift x >>= f).run s = x >>= fun a => (f a).run s := by
simp [StateT.lift, StateT.run, bind, StateT.bind]
@[simp] theorem run_monadLift {α σ : Type u} [Monad m] [MonadLiftT n m] (x : n α) (s : σ) : (monadLift x : StateT σ m α).run s = (monadLift x : m α) >>= fun a => pure (a, s) := rfl
@[simp, grind =] theorem run_monadLift {α σ : Type u} [Monad m] [MonadLiftT n m] (x : n α) (s : σ) : (monadLift x : StateT σ m α).run s = (monadLift x : m α) >>= fun a => pure (a, s) := rfl
@[simp] theorem run_monadMap [MonadFunctorT n m] (f : {β : Type u} n β n β) (x : StateT σ m α) (s : σ) :
@[simp, grind =] theorem run_monadMap [MonadFunctorT n m] (f : {β : Type u} n β n β) (x : StateT σ m α) (s : σ) :
(monadMap @f x : StateT σ m α).run s = monadMap @f (x.run s) := rfl
@[simp] theorem run_seq {α β σ : Type u} [Monad m] [LawfulMonad m] (f : StateT σ m (α β)) (x : StateT σ m α) (s : σ) : (f <*> x).run s = (f.run s >>= fun fs => (fun (p : α × σ) => (fs.1 p.1, p.2)) <$> x.run fs.2) := by