Compare commits

...

1 Commits

Author SHA1 Message Date
Leonardo de Moura
cb638353a4 feat: heterogeneous constructor injectivity theorems
This PR adds a heterogeneous version of the constructor injectivity theorems.
These theorems are useful for indexed families, and will be used in `grind`.
2025-12-02 17:33:20 -08:00
3 changed files with 158 additions and 24 deletions

View File

@@ -4,7 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Basic
import Lean.Meta.Tactic.Refl
@@ -12,9 +11,7 @@ import Lean.Meta.Tactic.Cases
import Lean.Meta.Tactic.Assumption
import Lean.Meta.Tactic.Simp.Main
import Lean.Meta.SameCtorUtils
public section
namespace Lean.Meta
private def mkAnd? (args : Array Expr) : Option Expr := Id.run do
@@ -33,20 +30,26 @@ def elimOptParam (type : Expr) : CoreM Expr := do
else
return .continue
private def mkEqs (args1 args2 : Array Expr) (skipIfPropOrEq : Bool := true) : MetaM (Array Expr) := do
let mut eqs := #[]
for arg1 in args1, arg2 in args2 do
let arg1Type inferType arg1
if !skipIfPropOrEq then
eqs := eqs.push ( mkEqHEq arg1 arg2)
else if !( isProp arg1Type) && arg1 != arg2 then
eqs := eqs.push ( mkEqHEq arg1 arg2)
return eqs
private partial def mkInjectiveTheoremTypeCore? (ctorVal : ConstructorVal) (useEq : Bool) : MetaM (Option Expr) := do
let us := ctorVal.levelParams.map mkLevelParam
let type elimOptParam ctorVal.type
forallBoundedTelescope type ctorVal.numParams fun params type =>
forallTelescope type fun args1 resultType => do
let jp (args2 args2New : Array Expr) : MetaM (Option Expr) := do
let k (args2 args2New : Array Expr) : MetaM (Option Expr) := do
let lhs := mkAppN (mkAppN (mkConst ctorVal.name us) params) args1
let rhs := mkAppN (mkAppN (mkConst ctorVal.name us) params) args2
let eq mkEq lhs rhs
let mut eqs := #[]
for arg1 in args1, arg2 in args2 do
let arg1Type inferType arg1
if !( isProp arg1Type) && arg1 != arg2 then
eqs := eqs.push ( mkEqHEq arg1 arg2)
let eqs mkEqs args1 args2
if let some andEqs := mkAnd? eqs then
let result if useEq then
mkEq eq andEqs
@@ -57,17 +60,15 @@ private partial def mkInjectiveTheoremTypeCore? (ctorVal : ConstructorVal) (useE
return none
let rec mkArgs2 (i : Nat) (type : Expr) (args2 args2New : Array Expr) : MetaM (Option Expr) := do
if h : i < args1.size then
match ( whnf type) with
| Expr.forallE n d b _ =>
let arg1 := args1[i]
if occursOrInType ( getLCtx) arg1 resultType then
mkArgs2 (i + 1) (b.instantiate1 arg1) (args2.push arg1) args2New
else
withLocalDecl n (if useEq then BinderInfo.default else BinderInfo.implicit) d fun arg2 =>
mkArgs2 (i + 1) (b.instantiate1 arg2) (args2.push arg2) (args2New.push arg2)
| _ => throwError "unexpected constructor type for `{ctorVal.name}`"
let .forallE n d b _ whnf type | throwError "unexpected constructor type for `{ctorVal.name}`"
let arg1 := args1[i]
if occursOrInType ( getLCtx) arg1 resultType then
mkArgs2 (i + 1) (b.instantiate1 arg1) (args2.push arg1) args2New
else
withLocalDecl n (if useEq then BinderInfo.default else BinderInfo.implicit) d fun arg2 =>
mkArgs2 (i + 1) (b.instantiate1 arg2) (args2.push arg2) (args2New.push arg2)
else
jp args2 args2New
k args2 args2New
if useEq then
mkArgs2 0 type #[] #[]
else
@@ -84,14 +85,16 @@ private def injTheoremFailureHeader (ctorName : Name) : MessageData :=
private def throwInjectiveTheoremFailure {α} (ctorName : Name) (mvarId : MVarId) : MetaM α :=
throwError "{injTheoremFailureHeader ctorName}{indentD <| MessageData.ofGoal mvarId}"
private def splitAndAssumption (mvarId : MVarId) (ctorName : Name) : MetaM Unit := do
( mvarId.splitAnd).forM fun mvarId =>
unless ( mvarId.assumptionCore) do
throwInjectiveTheoremFailure ctorName mvarId
private def solveEqOfCtorEq (ctorName : Name) (mvarId : MVarId) (h : FVarId) : MetaM Unit := do
trace[Meta.injective] "solving injectivity goal for {ctorName} with hypothesis {mkFVar h} at\n{mvarId}"
match ( injection mvarId h) with
| InjectionResult.solved => unreachable!
| InjectionResult.subgoal mvarId .. =>
( mvarId.splitAnd).forM fun mvarId =>
unless ( mvarId.assumptionCore) do
throwInjectiveTheoremFailure ctorName mvarId
| InjectionResult.subgoal mvarId .. => splitAndAssumption mvarId ctorName
private def mkInjectiveTheoremValue (ctorName : Name) (targetType : Expr) : MetaM Expr :=
forallTelescopeReducing targetType fun xs type => do
@@ -178,4 +181,106 @@ def mkInjectiveTheorems (declName : Name) : MetaM Unit := do
builtin_initialize
registerTraceClass `Meta.injective
private def getIndices? (ctorApp : Expr) : MetaM (Option (Array Expr)) := do
let type whnfD ( inferType ctorApp)
type.withApp fun typeFn typeArgs => do
let .const declName _ := typeFn | return none
let .inductInfo val getConstInfo declName | return none
if val.numIndices == 0 then return some #[]
return some typeArgs[val.numParams...*].toArray
private def mkArrows (hs : Array Expr) (type : Expr) : CoreM Expr := do
hs.foldrM (init := type) mkArrow
private structure MkHInjTypeResult where
thmType : Expr
us : List Level
numIndices : Nat
private partial def mkHInjType? (ctorVal : ConstructorVal) : MetaM (Option MkHInjTypeResult) := do
let us := ctorVal.levelParams.map mkLevelParam
let type elimOptParam ctorVal.type
forallBoundedTelescope type ctorVal.numParams fun params type =>
forallTelescope type fun args1 resultType => do
let k (args2 : Array Expr) : MetaM (Option MkHInjTypeResult) := do
let lhs := mkAppN (mkAppN (mkConst ctorVal.name us) params) args1
let rhs := mkAppN (mkAppN (mkConst ctorVal.name us) params) args2
let eq mkEqHEq lhs rhs
let eqs mkEqs args1 args2
if let some andEqs := mkAnd? eqs then
let result mkArrow eq andEqs
let some idxs1 getIndices? lhs | return none
let some idxs2 getIndices? rhs | return none
-- **Note**: We dot not skip here because the type of `noConfusion` does not.
let idxEqs mkEqs idxs1 idxs2 (skipIfPropOrEq := false)
let result mkArrows idxEqs result
let thmType mkForallFVars params ( mkForallFVars args1 ( mkForallFVars args2 result))
return some { thmType, us, numIndices := idxs1.size }
else
return none
let rec mkArgs2 (i : Nat) (type : Expr) (args2 : Array Expr) : MetaM (Option MkHInjTypeResult) := do
if h : i < args1.size then
let .forallE n d b _ whnf type | throwError "unexpected constructor type for `{ctorVal.name}`"
let arg1 := args1[i]
withLocalDecl n .implicit d fun arg2 =>
mkArgs2 (i + 1) (b.instantiate1 arg2) (args2.push arg2)
else
k args2
withNewBinderInfos (params.map fun param => (param.fvarId!, BinderInfo.implicit)) <|
withNewBinderInfos (args1.map fun arg1 => (arg1.fvarId!, BinderInfo.implicit)) <|
mkArgs2 0 type #[]
private def failedToGenHInj (ctorVal : ConstructorVal) : MetaM α :=
throwError "failed to generate heterogeneous injectivity theorem for `{ctorVal.name}`"
private partial def mkHInjectiveTheoremValue? (ctorVal : ConstructorVal) (typeInfo : MkHInjTypeResult) : MetaM (Option Expr) := do
forallTelescopeReducing typeInfo.thmType fun xs type => do
let noConfusionName := ctorVal.induct.str "noConfusion"
let params := xs[*...ctorVal.numParams]
let noConfusion := mkAppN (mkConst noConfusionName (0 :: typeInfo.us)) params
let noConfusion := mkApp noConfusion type
let n := xs.size - typeInfo.numIndices - 1
let eqs := xs[n...*].toArray
let eqExprs eqs.mapM fun x => do
match_expr ( inferType x) with
| Eq _ lhs rhs => return (lhs, rhs)
| HEq _ lhs _ rhs => return (lhs, rhs)
| _ => failedToGenHInj ctorVal
let (args₁, args₂) := eqExprs.unzip
let noConfusion := mkAppN (mkAppN (mkAppN noConfusion args₁) args₂) eqs
let .forallE _ d _ _ whnf ( inferType noConfusion) | failedToGenHInj ctorVal
let mvar mkFreshExprSyntheticOpaqueMVar d
let noConfusion := mkApp noConfusion mvar
let mvarId := mvar.mvarId!
let (_, mvarId) mvarId.intros
splitAndAssumption mvarId ctorVal.name
check noConfusion
let result instantiateMVars noConfusion
mkLambdaFVars xs result
private def hinjSuffix := "hinj"
def mkHInjectiveTheoremNameFor (ctorName : Name) : Name :=
ctorName.str hinjSuffix
private def mkHInjectiveTheorem? (thmName : Name) (ctorVal : ConstructorVal) : MetaM (Option TheoremVal) := do
let some typeInfo mkHInjType? ctorVal | return none
let some value mkHInjectiveTheoremValue? ctorVal typeInfo | return none
return some { name := thmName, value, levelParams := ctorVal.levelParams, type := typeInfo.thmType }
builtin_initialize registerReservedNamePredicate fun env n =>
match n with
| .str p "hinj" => (env.find? p matches some (.ctorInfo _))
| _ => false
builtin_initialize
registerReservedNameAction fun name => do
let .str p "hinj" := name | return false
let some (.ctorInfo ctorVal) := ( getEnv).find? p | return false
MetaM.run' do
let some thmVal mkHInjectiveTheorem? name ctorVal | return false
realizeConst p name do
addDecl ( mkThmOrUnsafeDef thmVal)
return true
end Lean.Meta

View File

@@ -0,0 +1,22 @@
opaque double : Nat Nat
def P (n : Nat) : Prop := n >= 0
theorem pax (n : Nat) : P n := by grind [P]
def T (n : Nat) : Type := Vector Nat n
inductive Foo' (α β : Type u) : (n : Nat) P n -> Type u
| even (a : α) (n : Nat) (v : T n) (h : P n) : Foo' α β (double n) (pax _)
| odd (b : β) (n : Nat) (v : T n) : Foo' α β (Nat.succ (double n)) (pax _)
/--
info: Foo'.even.hinj.{u} {α β : Type u} {a : α} {n : Nat} {v : T n} {h : P n} {a✝ : α} {n✝ : Nat} {v✝ : T n✝} {h✝ : P n✝} :
double n = double n✝ → ⋯ ≍ ⋯ → Foo'.even a n v h ≍ Foo'.even a✝ n✝ v✝ h✝ → a = a✝ ∧ n = n✝ ∧ v ≍ v✝
-/
#guard_msgs in
#check Foo'.even.hinj
/--
info: Foo'.odd.hinj.{u} {α β : Type u} {b : β} {n : Nat} {v : T n} {b✝ : β} {n✝ : Nat} {v✝ : T n✝} :
(double n).succ = (double n✝).succ → ⋯ ≍ ⋯ → Foo'.odd b n v ≍ Foo'.odd b✝ n✝ v✝ → b = b✝ ∧ n = n✝ ∧ v ≍ v✝
-/
#guard_msgs in
#check Foo'.odd.hinj

View File

@@ -56,8 +56,15 @@ info: Vec.cons.inj.{u} {α : Type u} {n : Nat} {x : α} {xs : Vec α n} {x✝ :
#guard_msgs in
#check Vec.cons.inj
theorem Vec.cons.hinj {α : Type u}
theorem Vec.cons.hinj' {α : Type u}
{x : α} {n : Nat} {xs : Vec α n} {x' : α} {n' : Nat} {xs' : Vec α n'} :
Vec.cons x xs Vec.cons x' xs' (n + 1 = n' + 1 (x = x' xs xs')) := by
intro h eq_1
apply Vec.cons.noConfusion eq_1 h (fun _ eq_x eq_xs => eq_x, eq_xs)
/--
info: Vec.cons.hinj.{u} {α : Type u} {n : Nat} {x : α} {xs : Vec α n} {n✝ : Nat} {x✝ : α} {xs✝ : Vec α n✝} :
n + 1 = n✝ + 1 → Vec.cons x xs ≍ Vec.cons x✝ xs✝ → n = n✝ ∧ x = x✝ ∧ xs ≍ xs✝
-/
#guard_msgs in
#check Vec.cons.hinj