feat: linear-size DecidableEq instance (#10152)

This PR introduces an alternative construction for `DecidableEq`
instances that avoids the quadratic overhead of the default
construction.

The usual construction uses a `match` statement that looks at each pair
of constructors, and thus is necessarily quadratic in size. For
inductive data type with dozens of constructors or more, this quickly
becomes slow to process.

The new construction first compares the constructor tags (using the
`.ctorIdx` introduced in #9951), and handles the case of a differing
constructor tag quickly. If the constructor tags match, it uses the
per-constructor-eliminators (#9952) to create a linear-size instance. It
does so by creating a custom “matcher” for a parallel match on the data
types and the `h : x1.ctorIdx = x2.ctorIdx` assumption; this behaves
(and delaborates) like a normal `match` statement, but is implemented in
a bespoke way. This same-constructor-matcher will be useful for
implementing other instances as well.

The new construction produces less efficient code at the moment, so we
use it only for inductive types with 10 or more constructors by default.
The option `deriving.decEq.linear_construction_threshold` can be used to
adjust the threshold; set it to 0 to always use the new construction.
This commit is contained in:
Joachim Breitner
2025-09-03 08:31:49 +02:00
committed by GitHub
parent a4f6f391fe
commit ccb8568756
11 changed files with 705 additions and 31 deletions

View File

@@ -6,21 +6,32 @@ Authors: Leonardo de Moura
module
prelude
public import Lean.Data.Options
import Lean.Meta.Transform
import Lean.Meta.Inductive
import Lean.Elab.Deriving.Basic
import Lean.Elab.Deriving.Util
import Lean.Meta.NatTable
import Lean.Meta.Constructions.CtorIdx
import Lean.Meta.Constructions.CtorElim
import Lean.Meta.Constructions.CasesOnSameCtor
namespace Lean.Elab.Deriving.DecEq
open Lean.Parser.Term
open Meta
register_builtin_option deriving.decEq.linear_construction_threshold : Nat := {
defValue := 10
descr := "If the inductive data type has this many or more constructors, use a different \
implementation for deciding equality that avoids the quadratic code size produced by the \
default implementation.\n\n\
The alternative construction compiles to less efficient code in some cases, so by default \
it is only used for inductive types with 10 or more constructors." }
def mkDecEqHeader (indVal : InductiveVal) : TermElabM Header := do
mkHeader `DecidableEq 2 indVal
def mkMatch (ctx : Context) (header : Header) (indVal : InductiveVal) : TermElabM Term := do
def mkMatchOld (ctx : Context) (header : Header) (indVal : InductiveVal) : TermElabM Term := do
let discrs mkDiscrs header indVal
let alts mkAlts
`(match $[$discrs],* with $alts:matchAlt*)
@@ -91,6 +102,76 @@ where
alts := alts.push ( `(matchAltExpr| | $[$patterns:term],* => $rhs:term))
return alts
def mkMatchNew (ctx : Context) (header : Header) (indVal : InductiveVal) : TermElabM Term := do
assert! header.targetNames.size == 2
let x1 := mkIdent header.targetNames[0]!
let x2 := mkIdent header.targetNames[1]!
let ctorIdxName := mkCtorIdxName indVal.name
-- NB: the getMatcherInfo? assumes all mathcers are called `match_`
let casesOnSameCtorName mkFreshUserName (indVal.name ++ `match_on_same_ctor)
mkCasesOnSameCtor casesOnSameCtorName indVal.name
let alts Array.ofFnM (n := indVal.numCtors) fun ctorIdx, _ => do
let ctorName := indVal.ctors[ctorIdx]!
let ctorInfo getConstInfoCtor ctorName
forallTelescopeReducing ctorInfo.type fun xs type => do
let type Core.betaReduce type -- we 'beta-reduce' to eliminate "artificial" dependencies
let mut ctorArgs1 : Array Term := #[]
let mut ctorArgs2 : Array Term := #[]
let mut todo := #[]
for i in *...ctorInfo.numFields do
let x := xs[indVal.numParams + i]!
if type.containsFVar x.fvarId! then
-- If resulting type depends on this field, we don't need to bring it into
-- scope nor compare it
ctorArgs1 := ctorArgs1.push ( `(_))
else
let a := mkIdent ( mkFreshUserName `a)
let b := mkIdent ( mkFreshUserName `b)
ctorArgs1 := ctorArgs1.push a
ctorArgs2 := ctorArgs2.push b
let xType inferType x
let indValNum :=
ctx.typeInfos.findIdx?
(xType.isAppOf ConstantVal.name InductiveVal.toConstantVal)
let recField := indValNum.map (ctx.auxFunNames[·]!)
let isProof isProp xType
todo := todo.push (a, b, recField, isProof)
let rhs mkSameCtorRhs todo.toList
`(@fun $ctorArgs1:term* $ctorArgs2:term* =>$rhs:term)
if indVal.numCtors == 1 then
`( $(mkCIdent casesOnSameCtorName) $x1:term $x2:term rfl $alts:term* )
else
`( if h : $(mkCIdent ctorIdxName) $x1:ident = $(mkCIdent ctorIdxName) $x2:ident then
$(mkCIdent casesOnSameCtorName) $x1:term $x2:term h $alts:term*
else
isFalse (fun h' => h (congrArg $(mkCIdent ctorIdxName) h')))
where
mkSameCtorRhs : List (Ident × Ident × Option Name × Bool) TermElabM Term
| [] => ``(isTrue rfl)
| (a, b, recField, isProof) :: todo => withFreshMacroScope do
let rhs if isProof then
`(have h : @$a = @$b := rfl; by subst h; exact $( mkSameCtorRhs todo):term)
else
let sameCtor mkSameCtorRhs todo
`(if h : @$a = @$b then
by subst h; exact $sameCtor:term
else
isFalse (by intro n; injection n; apply h _; assumption))
if let some auxFunName := recField then
-- add local instance for `a = b` using the function being defined `auxFunName`
`(let inst := $(mkIdent auxFunName) @$a @$b; $rhs)
else
return rhs
def mkMatch (ctx : Context) (header : Header) (indVal : InductiveVal) : TermElabM Term := do
if indVal.numCtors deriving.decEq.linear_construction_threshold.get ( getOptions) then
mkMatchNew ctx header indVal
else
mkMatchOld ctx header indVal
def mkAuxFunction (ctx : Context) (auxFunName : Name) (indVal : InductiveVal): TermElabM (TSyntax `command) := do
let header mkDecEqHeader indVal
let body mkMatch ctx header indVal

View File

@@ -197,6 +197,7 @@ private partial def replaceRecApps (recArgInfos : Array RecArgInfo) (positions :
mkLambdaFVars xs ( loop belowForAlt altBody)
pure { matcherApp with alts := altsNew }.toExpr
else
trace[Elab.definition.structural] "`matcherApp.addArg?` failed"
processApp e
| none => processApp e
| e =>

View File

@@ -10,5 +10,6 @@ public import Lean.Meta.Constructions.CasesOn
public import Lean.Meta.Constructions.NoConfusion
public import Lean.Meta.Constructions.RecOn
public import Lean.Meta.Constructions.BRecOn
public import Lean.Meta.Constructions.CasesOnSameCtor
public section

View File

@@ -0,0 +1,248 @@
/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Joachim Breitner
-/
module
prelude
public import Lean.Meta.Basic
import Lean.AddDecl
import Lean.Meta.AppBuilder
import Lean.Meta.CompletionName
import Lean.Meta.Constructions.CtorIdx
import Lean.Meta.Constructions.CtorElim
import Lean.Elab.App
/-!
See `mkCasesOnSameCtor` below.
-/
namespace Lean
open Meta
/--
Helper for `mkCasesOnSameCtor` that constructs a heterogenous matcher (indices may differ)
and does not include the equality proof in the motive (so it's not a the shape of a matcher) yet.
-/
public def mkCasesOnSameCtorHet (declName : Name) (indName : Name) : MetaM Unit := do
let ConstantInfo.inductInfo info getConstInfo indName | unreachable!
let casesOnName := mkCasesOnName indName
let casesOnInfo getConstVal casesOnName
let v::us := casesOnInfo.levelParams.map mkLevelParam | panic! "unexpected universe levels on `casesOn`"
let e forallBoundedTelescope casesOnInfo.type info.numParams fun params t =>
forallBoundedTelescope t (some 1) fun _ t => -- ignore motive
forallBoundedTelescope t (some (info.numIndices + 1)) fun ism1 _ =>
forallBoundedTelescope t (some (info.numIndices + 1)) fun ism2 _ => do
let motiveType mkForallFVars (ism1 ++ ism2) (mkSort v)
withLocalDecl `motive .implicit motiveType fun motive => do
let altTypes info.ctors.toArray.mapIdxM fun i ctorName => do
let ctor := mkAppN (mkConst ctorName us) params
let ctorType inferType ctor
forallTelescope ctorType fun zs1 ctorRet1 => do
let ctorApp1 := mkAppN ctor zs1
let ctorRet1 whnf ctorRet1
let is1 : Array Expr := ctorRet1.getAppArgs[info.numParams:]
let ism1 := is1.push ctorApp1
forallTelescope ctorType fun zs2 ctorRet2 => do
let ctorApp2 := mkAppN ctor zs2
let ctorRet2 whnf ctorRet2
let is2 : Array Expr := ctorRet2.getAppArgs[info.numParams:]
let ism2 := is2.push ctorApp2
let e := mkAppN motive (ism1 ++ ism2)
let e mkForallFVars (zs1 ++ zs2) e
let name := match ctorName with
| Name.str _ s => Name.mkSimple s
| _ => Name.mkSimple s!"alt{i+1}"
return (name, e)
withLocalDeclsDND altTypes fun alts => do
let ctorApp1 := mkAppN (mkConst (mkCtorIdxName indName) us) (params ++ ism1)
let ctorApp2 := mkAppN (mkConst (mkCtorIdxName indName) us) (params ++ ism2)
let heqType mkEq ctorApp1 ctorApp2
let heqType' mkEq ctorApp2 ctorApp1
withLocalDeclD `h heqType fun heq => do
let motive1 mkLambdaFVars ism1 ( mkArrow heqType' (mkAppN motive (ism1 ++ ism2)))
let e := mkConst casesOnInfo.name (v :: us)
let e := mkAppN e params
let e := mkApp e motive1
let e := mkAppN e ism1
let alts1 info.ctors.toArray.mapIdxM fun i ctorName => do
let ctor := mkAppN (mkConst ctorName us) params
let ctorType inferType ctor
forallTelescope ctorType fun zs1 ctorRet1 => do
let ctorApp1 := mkAppN ctor zs1
let ctorRet1 whnf ctorRet1
let is1 : Array Expr := ctorRet1.getAppArgs[info.numParams:]
let ism1 := is1.push ctorApp1
-- Here we let the typecheker reduce the `ctorIdx` application
let heq := mkApp3 (mkConst ``Eq [1]) (mkConst ``Nat) ctorApp2 (mkRawNatLit i)
withLocalDeclD `h heq fun h => do
let motive2 mkLambdaFVars ism2 (mkAppN motive (ism1 ++ ism2))
let alt forallTelescope ctorType fun zs2 _ => do
mkLambdaFVars zs2 <| mkAppN alts[i]! (zs1 ++ zs2)
let e := if info.numCtors = 1 then
let casesOn := mkConst (mkCasesOnName indName) (v :: us)
mkAppN casesOn (params ++ #[motive2] ++ ism2 ++ #[alt])
else
let casesOn := mkConst (mkConstructorElimName indName ctorName) (v :: us)
mkAppN casesOn (params ++ #[motive2] ++ ism2 ++ #[h, alt])
mkLambdaFVars (zs1.push h) e
let e := mkAppN e alts1
let e := mkApp e ( mkEqSymm heq)
mkLambdaFVars (params ++ #[motive] ++ ism1 ++ ism2 ++ #[heq] ++ alts) e
addAndCompile (.defnDecl ( mkDefinitionValInferringUnsafe
(name := declName)
(levelParams := casesOnInfo.levelParams)
(type := ( inferType e))
(value := e)
(hints := ReducibilityHints.abbrev)
))
modifyEnv fun env => markAuxRecursor env declName
modifyEnv fun env => addToCompletionBlackList env declName
modifyEnv fun env => addProtected env declName
Elab.Term.elabAsElim.setTag declName
setReducibleAttribute declName
def withSharedIndices (ctor : Expr) (k : Array Expr Expr Expr MetaM α) : MetaM α := do
let ctorType inferType ctor
forallTelescopeReducing ctorType fun zs ctorRet => do
let ctor1 := mkAppN ctor zs
let rec go ctor2 todo acc := do
match todo with
| [] => k acc ctor1 ctor2
| z::todo' =>
if ctorRet.containsFVar z.fvarId! then
go (mkApp ctor2 z) todo' acc
else
let t whnfForall ( inferType ctor2)
assert! t.isForall
withLocalDeclD t.bindingName! t.bindingDomain! fun z' => do
go (mkApp ctor2 z') todo' (acc.push z')
go ctor zs.toList zs
/--
This constructs a matcher for a match statement that matches on the constructors of
a data type in parallel. So if `h : x1.ctorIdx = x2.ctorIdx`, then it implements
```
match x1, x2, h with
| ctor1 .. , ctor1 .. , _ => ...
| ctor2 .. , ctor2 .. , _ => ...
```
The normal matcher supports such matches, but implements them using nested `casesOn`, which
leads to a quadratic blow-up. This function uses the per-constructor eliminators to implement this
more efficiently.
This is useful for implementing or deriving functionality like `BEq`, `DecidableEq`, `Ord` and
proving their lawfulness.
One could imagine a future where `match` compilation is smart enough to do that automatically; then
this module can be dropped.
Note that for some data types where the indices determine the constructor (e.g. `Vec`), this leads
to less efficient code than the normal matcher, as this needs to read the constructor tag on both
arguments, wheras the normal matcher produces code that reads just the first arguments tag, and
then boldly reads the second arguments fields.
-/
public def mkCasesOnSameCtor (declName : Name) (indName : Name) : MetaM Unit := do
let ConstantInfo.inductInfo info getConstInfo indName | unreachable!
let casesOnSameCtorHet := declName ++ `het
mkCasesOnSameCtorHet casesOnSameCtorHet indName
let casesOnName := mkCasesOnName indName
let casesOnInfo getConstVal casesOnName
let v::us := casesOnInfo.levelParams.map mkLevelParam | panic! "unexpected universe levels on `casesOn`"
forallBoundedTelescope casesOnInfo.type info.numParams fun params t =>
let t0 := t.bindingBody! -- ignore motive
forallBoundedTelescope t0 (some info.numIndices) fun is t =>
forallBoundedTelescope t (some 1) fun x1 _ =>
forallBoundedTelescope t (some 1) fun x2 _ => do
let x1 := x1[0]!
let x2 := x2[0]!
let ctorApp1 := mkAppN (mkConst (mkCtorIdxName indName) us) (params ++ is ++ #[x1])
let ctorApp2 := mkAppN (mkConst (mkCtorIdxName indName) us) (params ++ is ++ #[x2])
let heqType mkEq ctorApp1 ctorApp2
withLocalDeclD `h heqType fun heq => do
let motiveType mkForallFVars (is ++ #[x1,x2,heq]) (mkSort v)
withLocalDecl `motive .implicit motiveType fun motive => do
let altTypes info.ctors.toArray.mapIdxM fun i ctorName => do
let ctor := mkAppN (mkConst ctorName us) params
withSharedIndices ctor fun zs12 ctorApp1 ctorApp2 => do
let ctorRet1 whnf ( inferType ctorApp1)
let is : Array Expr := ctorRet1.getAppArgs[info.numParams:]
let e := mkAppN motive (is ++ #[ctorApp1, ctorApp2, ( mkEqRefl (mkNatLit i))])
let e mkForallFVars zs12 e
let name := match ctorName with
| Name.str _ s => Name.mkSimple s
| _ => Name.mkSimple s!"alt{i+1}"
return (name, e)
withLocalDeclsDND altTypes fun alts => do
forallBoundedTelescope t0 (some (info.numIndices + 1)) fun ism1' _ =>
forallBoundedTelescope t0 (some (info.numIndices + 1)) fun ism2' _ => do
let (motive', newRefls)
withNewEqs (is.push x1) ism1' fun newEqs1 newRefls1 => do
withNewEqs (is.push x2) ism2' fun newEqs2 newRefls2 => do
let motive' := mkAppN motive (is ++ #[x1, x2, heq])
let motive' mkForallFVars (newEqs1 ++ newEqs2) motive'
let motive' mkLambdaFVars (ism1' ++ ism2') motive'
return (motive', newRefls1 ++ newRefls2)
let casesOn2 := mkConst casesOnSameCtorHet (v :: us)
let casesOn2 := mkAppN casesOn2 params
let casesOn2 := mkApp casesOn2 motive'
let casesOn2 := mkAppN casesOn2 (is ++ #[x1] ++ is ++ #[x2])
let casesOn2 := mkApp casesOn2 heq
let altTypes' inferArgumentTypesN info.numCtors casesOn2
let alts' info.ctors.toArray.mapIdxM fun i ctorName => do
let ctor := mkAppN (mkConst ctorName us) params
let ctorType inferType ctor
forallTelescope ctorType fun zs1 _ctorRet1 => do
forallTelescope ctorType fun zs2 _ctorRet2 => do
let altType instantiateForall altTypes'[i]! (zs1 ++ zs2)
let alt mkFreshExprSyntheticOpaqueMVar altType
let goal := alt.mvarId!
let some (goal, _) Cases.unifyEqs? newRefls.size goal {}
| throwError "unifyEqns? unexpectedly closed goal"
let [] goal.apply alts[i]!
| throwError "could not apply {alts[i]!} to close\n{goal}"
mkLambdaFVars (zs1 ++ zs2) ( instantiateMVars alt)
let casesOn2 := mkAppN casesOn2 alts'
let casesOn2 := mkAppN casesOn2 newRefls
let e mkLambdaFVars (params ++ #[motive] ++ is ++ #[x1,x2] ++ #[heq] ++ alts) casesOn2
let decl := .defnDecl ( mkDefinitionValInferringUnsafe
(name := declName)
(levelParams := casesOnInfo.levelParams)
(type := ( inferType e))
(value := e)
(hints := ReducibilityHints.abbrev)
)
let matcherInfo : MatcherInfo := {
numParams := info.numParams
numDiscrs := info.numIndices + 3
altNumParams := altTypes.map (·.2.getNumHeadForalls)
uElimPos? := some 0
discrInfos := #[{}, {}, {}]}
-- Compare attributes with `mkMatcherAuxDefinition`
addDecl decl
Elab.Term.elabAsElim.setTag declName
Match.addMatcherInfo declName matcherInfo
setInlineAttribute declName
-- Pragmatic hack:
-- Normally a matcher is not marked as an aux recursor. We still do that here
-- because this makes the elaborator unfold it more eagerily, it seems,
-- and this works around issues with the structural recursion equation generator
-- (see #10195).
modifyEnv fun env => markAuxRecursor env declName
enableRealizationsForConst declName
compileDecl decl
end Lean

View File

@@ -3,25 +3,58 @@
-- Print the generated derivations
set_option trace.Elab.Deriving.decEq true
namespace A
set_option deriving.decEq.linear_construction_threshold 1000
mutual
inductive Tree : Type :=
inductive Tree : Type where
| node : ListTree Tree
inductive ListTree : Type :=
inductive ListTree : Type where
| nil : ListTree
| cons : Tree ListTree ListTree
deriving DecidableEq
end
mutual
inductive Foo₁ : Type :=
inductive Foo₁ : Type where
| foo₁₁ : Foo₁
| foo₁₂ : Foo₂ Foo₁
deriving DecidableEq
inductive Foo₂ : Type :=
inductive Foo₂ : Type where
| foo₂ : Foo₃ Foo₂
inductive Foo₃ : Type :=
inductive Foo₃ : Type where
| foo₃ : Foo₁ Foo₃
end
end A
namespace B
set_option deriving.decEq.linear_construction_threshold 0
mutual
inductive Tree : Type where
| node : ListTree Tree
inductive ListTree : Type where
| nil : ListTree
| cons : Tree ListTree ListTree
deriving DecidableEq
end
mutual
inductive Foo₁ : Type where
| foo₁₁ : Foo₁
| foo₁₂ : Foo₂ Foo₁
deriving DecidableEq
inductive Foo₂ : Type where
| foo₂ : Foo₃ Foo₂
inductive Foo₃ : Type where
| foo₃ : Foo₁ Foo₃
end
end B

View File

@@ -1,18 +1,18 @@
[Elab.Deriving.decEq]
[mutual
def decEqTree✝ (x✝ : @Tree✝) (x✝¹ : @Tree✝) : Decidable✝ (x✝ = x✝¹) :=
def decEqTree✝ (x✝ : @A.Tree✝) (x✝¹ : @A.Tree✝) : Decidable✝ (x✝ = x✝¹) :=
match x✝, x✝¹ with
| @Tree.node a✝, @Tree.node b✝ =>
| @A.Tree.node a✝, @A.Tree.node b✝ =>
let inst✝ := decEqListTree✝ @a✝ @b✝;
if h✝ : @a✝ = @b✝ then by subst h✝; exact isTrue✝ rfl✝
else isFalse✝ (by intro n✝; injection n✝; apply h✝ _; assumption)
termination_by structural x✝
def decEqListTree✝ (x✝² : @ListTree✝) (x✝³ : @ListTree✝) : Decidable✝ (x✝² = x✝³) :=
def decEqListTree✝ (x✝² : @A.ListTree✝) (x✝³ : @A.ListTree✝) : Decidable✝ (x✝² = x✝³) :=
match x✝², x✝³ with
| @ListTree.nil, @ListTree.nil => isTrue✝¹ rfl✝¹
| ListTree.nil .., ListTree.cons .. => isFalse✝¹ (by intro h✝¹; injection h✝¹)
| ListTree.cons .., ListTree.nil .. => isFalse✝¹ (by intro h✝¹; injection h✝¹)
| @ListTree.cons a✝¹ a✝², @ListTree.cons b✝¹ b✝² =>
| @A.ListTree.nil, @A.ListTree.nil => isTrue✝¹ rfl✝¹
| A.ListTree.nil .., A.ListTree.cons .. => isFalse✝¹ (by intro h✝¹; injection h✝¹)
| A.ListTree.cons .., A.ListTree.nil .. => isFalse✝¹ (by intro h✝¹; injection h✝¹)
| @A.ListTree.cons a✝¹ a✝², @A.ListTree.cons b✝¹ b✝² =>
let inst✝¹ := decEqTree✝ @a✝¹ @b✝¹;
if h✝² : @a✝¹ = @b✝¹ then by subst h✝²;
exact
@@ -22,34 +22,87 @@
else isFalse✝³ (by intro n✝²; injection n✝²; apply h✝² _; assumption)
termination_by structural x✝²
end,
instance : DecidableEq✝ (@ListTree✝) :=
instance : DecidableEq✝ (@A.ListTree✝) :=
decEqListTree✝]
[Elab.Deriving.decEq]
[mutual
def decEqFoo₁✝ (x✝ : @Foo₁✝) (x✝¹ : @Foo₁✝) : Decidable✝ (x✝ = x✝¹) :=
def decEqFoo₁✝ (x✝ : @A.Foo₁✝) (x✝¹ : @A.Foo₁✝) : Decidable✝ (x✝ = x✝¹) :=
match x✝, x✝¹ with
| @Foo₁.foo₁₁, @Foo₁.foo₁₁ => isTrue✝ rfl✝
| Foo₁.foo₁₁ .., Foo₁.foo₁₂ .. => isFalse✝ (by intro h✝; injection h✝)
| Foo₁.foo₁₂ .., Foo₁.foo₁₁ .. => isFalse✝ (by intro h✝; injection h✝)
| @Foo₁.foo₁₂ a✝, @Foo₁.foo₁₂ b✝ =>
| @A.Foo₁.foo₁₁, @A.Foo₁.foo₁₁ => isTrue✝ rfl✝
| A.Foo₁.foo₁₁ .., A.Foo₁.foo₁₂ .. => isFalse✝ (by intro h✝; injection h✝)
| A.Foo₁.foo₁₂ .., A.Foo₁.foo₁₁ .. => isFalse✝ (by intro h✝; injection h✝)
| @A.Foo₁.foo₁₂ a✝, @A.Foo₁.foo₁₂ b✝ =>
let inst✝ := decEqFoo₂✝ @a✝ @b✝;
if h✝¹ : @a✝ = @b✝ then by subst h✝¹; exact isTrue✝¹ rfl✝¹
else isFalse✝¹ (by intro n✝; injection n✝; apply h✝¹ _; assumption)
termination_by structural x✝
def decEqFoo₂✝ (x✝² : @Foo₂✝) (x✝³ : @Foo₂✝) : Decidable✝ (x✝² = x✝³) :=
def decEqFoo₂✝ (x✝² : @A.Foo₂✝) (x✝³ : @A.Foo₂✝) : Decidable✝ (x✝² = x✝³) :=
match x✝², x✝³ with
| @Foo₂.foo₂ a✝¹, @Foo₂.foo₂ b✝¹ =>
| @A.Foo₂.foo₂ a✝¹, @A.Foo₂.foo₂ b✝¹ =>
let inst✝¹ := decEqFoo₃✝ @a✝¹ @b✝¹;
if h✝² : @a✝¹ = @b✝¹ then by subst h✝²; exact isTrue✝² rfl✝²
else isFalse✝² (by intro n✝¹; injection n✝¹; apply h✝² _; assumption)
termination_by structural x✝²
def decEqFoo₃✝ (x✝⁴ : @Foo₃✝) (x✝⁵ : @Foo₃✝) : Decidable✝ (x✝⁴ = x✝⁵) :=
def decEqFoo₃✝ (x✝⁴ : @A.Foo₃✝) (x✝⁵ : @A.Foo₃✝) : Decidable✝ (x✝⁴ = x✝⁵) :=
match x✝⁴, x✝⁵ with
| @Foo₃.foo₃ a✝², @Foo₃.foo₃ b✝² =>
| @A.Foo₃.foo₃ a✝², @A.Foo₃.foo₃ b✝² =>
let inst✝² := decEqFoo₁✝ @a✝² @b✝²;
if h✝³ : @a✝² = @b✝² then by subst h✝³; exact isTrue✝³ rfl✝³
else isFalse✝³ (by intro n✝²; injection n✝²; apply h✝³ _; assumption)
termination_by structural x✝⁴
end,
instance : DecidableEq✝ (@Foo₁✝) :=
instance : DecidableEq✝ (@A.Foo₁✝) :=
decEqFoo₁✝]
[Elab.Deriving.decEq]
[mutual
def decEqTree✝ (x✝ : @B.Tree✝) (x✝¹ : @B.Tree✝) : Decidable✝ (x✝ = x✝¹) :=
B.Tree.match_on_same_ctor✝ x✝ x✝¹ rfl✝
@fun a✝ b✝ =>
let inst✝ := decEqListTree✝ @a✝ @b✝;
if h✝ : @a✝ = @b✝ then by subst h✝; exact isTrue✝ rfl✝¹
else isFalse✝ (by intro n✝; injection n✝; apply h✝ _; assumption)
termination_by structural x✝
def decEqListTree✝ (x✝² : @B.ListTree✝) (x✝³ : @B.ListTree✝) : Decidable✝ (x✝² = x✝³) :=
if h✝¹ : B.ListTree.ctorIdx✝ x✝² = B.ListTree.ctorIdx✝ x✝³ then
B.ListTree.match_on_same_ctor✝ x✝² x✝³ h✝¹ (@fun => isTrue✝¹ rfl✝)
@fun a✝¹ a✝² b✝¹ b✝² =>
let inst✝¹ := decEqTree✝ @a✝¹ @b✝¹;
if h✝² : @a✝¹ = @b✝¹ then by subst h✝²;
exact
let inst✝² := decEqListTree✝ @a✝² @b✝²;
if h✝³ : @a✝² = @b✝² then by subst h✝³; exact isTrue✝² rfl✝²
else isFalse✝¹ (by intro n✝¹; injection n✝¹; apply h✝³ _; assumption)
else isFalse✝² (by intro n✝²; injection n✝²; apply h✝² _; assumption)
else isFalse✝³ (fun h'✝ => h✝¹ (congrArg✝ B.ListTree.ctorIdx✝ h'✝))
termination_by structural x✝²
end,
instance : DecidableEq✝ (@B.ListTree✝) :=
decEqListTree✝]
[Elab.Deriving.decEq]
[mutual
def decEqFoo₁✝ (x✝ : @B.Foo₁✝) (x✝¹ : @B.Foo₁✝) : Decidable✝ (x✝ = x✝¹) :=
if h✝ : B.Foo₁.ctorIdx✝ x✝ = B.Foo₁.ctorIdx✝ x✝¹ then
B.Foo₁.match_on_same_ctor✝ x✝ x✝¹ h✝ (@fun => isTrue✝ rfl✝)
@fun a✝ b✝ =>
let inst✝ := decEqFoo₂✝ @a✝ @b✝;
if h✝¹ : @a✝ = @b✝ then by subst h✝¹; exact isTrue✝¹ rfl✝¹
else isFalse✝ (by intro n✝; injection n✝; apply h✝¹ _; assumption)
else isFalse✝¹ (fun h'✝ => h✝ (congrArg✝ B.Foo₁.ctorIdx✝ h'✝))
termination_by structural x✝
def decEqFoo₂✝ (x✝² : @B.Foo₂✝) (x✝³ : @B.Foo₂✝) : Decidable✝ (x✝² = x✝³) :=
B.Foo₂.match_on_same_ctor✝ x✝² x✝³ rfl✝
@fun a✝¹ b✝¹ =>
let inst✝¹ := decEqFoo₃✝ @a✝¹ @b✝¹;
if h✝² : @a✝¹ = @b✝¹ then by subst h✝²; exact isTrue✝² rfl✝²
else isFalse✝² (by intro n✝¹; injection n✝¹; apply h✝² _; assumption)
termination_by structural x✝²
def decEqFoo₃✝ (x✝⁴ : @B.Foo₃✝) (x✝⁵ : @B.Foo₃✝) : Decidable✝ (x✝⁴ = x✝⁵) :=
B.Foo₃.match_on_same_ctor✝ x✝⁴ x✝⁵ rfl✝
@fun a✝² b✝² =>
let inst✝² := decEqFoo₁✝ @a✝² @b✝²;
if h✝³ : @a✝² = @b✝² then by subst h✝³; exact isTrue✝³ rfl✝³
else isFalse✝³ (by intro n✝²; injection n✝²; apply h✝³ _; assumption)
termination_by structural x✝⁴
end,
instance : DecidableEq✝ (@B.Foo₁✝) :=
decEqFoo₁✝]

View File

@@ -0,0 +1,162 @@
import Lean
/-! This tests and documents the constructions in CasesOnSameCtor. -/
open Lean Meta
inductive Vec (α : Type u) : Nat Type u
| nil : Vec α 0
| cons : α {n : Nat} Vec α n Vec α (n+1)
namespace Vec
-- set_option debug.skipKernelTC true
run_meta mkCasesOnSameCtor `Vec.match_on_same_ctor ``Vec
/--
info: Vec.match_on_same_ctor.het.{u_1, u} {α : Type u} {motive : {a : Nat} → Vec α a → {a : Nat} → Vec α a → Sort u_1}
{a✝ : Nat} (t : Vec α a✝) {a✝¹ : Nat} (t✝ : Vec α a✝¹) (h : t.ctorIdx = t✝.ctorIdx) (nil : motive nil nil)
(cons :
(a : α) →
{n : Nat} → (a_1 : Vec α n) → (a_2 : α) → {n_1 : Nat} → (a_3 : Vec α n_1) → motive (cons a a_1) (cons a_2 a_3)) :
motive t t✝
-/
#guard_msgs in
#check Vec.match_on_same_ctor.het
/--
info: Vec.match_on_same_ctor.{u_1, u} {α : Type u}
{motive : {a : Nat} → (t t_1 : Vec α a) → t.ctorIdx = t_1.ctorIdx → Sort u_1} {a✝ : Nat} (t t✝ : Vec α a✝)
(h : t.ctorIdx = t✝.ctorIdx) (nil : motive nil nil ⋯)
(cons : (a : α) → {n : Nat} → (a_1 : Vec α n) → (a_2 : α) → (a_3 : Vec α n) → motive (cons a a_1) (cons a_2 a_3) ⋯) :
motive t t✝ h
-/
#guard_msgs in
#check Vec.match_on_same_ctor
-- Splitter and equations are generated
/--
info: Vec.match_on_same_ctor.splitter.{u_1, u} {α : Type u}
{motive : {a : Nat} → (t t_1 : Vec α a) → t.ctorIdx = t_1.ctorIdx → Sort u_1} {a✝ : Nat} (t t✝ : Vec α a✝)
(h : t.ctorIdx = t✝.ctorIdx) (h_1 : motive nil nil ⋯)
(h_2 : (a : α) → (n : Nat) → (a_1 : Vec α n) → (a_2 : α) → (a_3 : Vec α n) → motive (cons a a_1) (cons a_2 a_3) ⋯) :
motive t t✝ h
-/
#guard_msgs in
#check Vec.match_on_same_ctor.splitter
-- Since there is no overlap, the splitter is equal to the matcher
-- (I wonder if we should use this in general in MatchEq)
example : @Vec.match_on_same_ctor = @Vec.match_on_same_ctor.splitter := by rfl
/--
info: Vec.match_on_same_ctor.eq_2.{u_1, u} {α : Type u}
{motive : {a : Nat} → (t t_1 : Vec α a) → t.ctorIdx = t_1.ctorIdx → Sort u_1} (a✝ : α) (n : Nat) (a✝¹ : Vec α n)
(a✝² : α) (a✝³ : Vec α n) (nil : motive nil nil ⋯)
(cons : (a : α) → {n : Nat} → (a_1 : Vec α n) → (a_2 : α) → (a_3 : Vec α n) → motive (cons a a_1) (cons a_2 a_3) ⋯) :
(match n + 1, Vec.cons a✝ a✝¹, Vec.cons a✝² a✝³ with
| 0, Vec.nil, Vec.nil, ⋯ => nil
| n + 1, Vec.cons a a_1, Vec.cons a_2 a_3, ⋯ => cons a a_1 a_2 a_3) =
cons a✝ a✝¹ a✝² a✝³
-/
#guard_msgs in
#check Vec.match_on_same_ctor.eq_2
-- Recursion works
-- set_option trace.split.debug true
-- set_option trace.split.failure true
-- set_option trace.Elab.definition.structural.eqns true
def decEqVec {α} {a} [DecidableEq α] (x : @Vec α a) (x_1 : @Vec α a) : Decidable (x = x_1) :=
if h : Vec.ctorIdx x = Vec.ctorIdx x_1 then
Vec.match_on_same_ctor x x_1 h (isTrue rfl)
@fun a_1 _ a_2 b b_1 =>
if h_1 : @a_1 = @b then by
subst h_1
exact
let inst := decEqVec @a_2 @b_1;
if h_2 : @a_2 = @b_1 then by subst h_2; exact isTrue rfl
else isFalse (by intro n; injection n; apply h_2 _; assumption)
else isFalse (by intro n_1; injection n_1; apply h_1 _; assumption)
else isFalse (fun h' => h (congrArg Vec.ctorIdx h'))
termination_by structural x
-- Equation generation and pretty match syntax:
/--
info: theorem Vec.decEqVec.eq_def.{u_1} : ∀ {α : Type u_1} {a : Nat} [inst : DecidableEq α] (x x_1 : Vec α a),
x.decEqVec x_1 =
if h : x.ctorIdx = x_1.ctorIdx then
match a, x, x_1 with
| 0, Vec.nil, Vec.nil, ⋯ => isTrue ⋯
| x + 1, Vec.cons a_1 a_2, Vec.cons b b_1, ⋯ =>
if h_1 : a_1 = b then
h_1 ▸
have inst_1 := a_2.decEqVec b_1;
if h_2 : a_2 = b_1 then
h_2 ▸
have inst := a_2.decEqVec a_2;
isTrue ⋯
else isFalse ⋯
else isFalse ⋯
else isFalse ⋯
-/
#guard_msgs(pass trace, all) in
#print sig decEqVec.eq_def
-- Incidentially, normal match syntax is able to produce an equivalent matcher
-- (with different implementation):
-- (see #10195 for problems with equation generation)
def decEqVecPlain {α} {a} [DecidableEq α] (x : @Vec α a) (x_1 : @Vec α a) : Decidable (x = x_1) :=
if h : Vec.ctorIdx x = Vec.ctorIdx x_1 then
match x, x_1, h with
| Vec.nil, Vec.nil, _ => isTrue rfl
| Vec.cons a_1 a_2, Vec.cons b b_1, _ =>
if h_1 : @a_1 = @b then by
subst h_1
exact
let inst := decEqVecPlain @a_2 @b_1;
if h_2 : @a_2 = @b_1 then by subst h_2; exact isTrue rfl
else isFalse (by intro n; injection n; apply h_2 _; assumption)
else isFalse (by intro n_1; injection n_1; apply h_1 _; assumption)
else isFalse (fun h' => h (congrArg Vec.ctorIdx h'))
termination_by structural x
end Vec
namespace List
-- set_option debug.skipKernelTC true
-- set_option trace.compiler.ir.result true
run_meta mkCasesOnSameCtor `List.match_on_same_ctor ``List
/--
info: List.match_on_same_ctor.{u_1, u} {α : Type u} {motive : (t t_1 : List α) → t.ctorIdx = t_1.ctorIdx → Sort u_1}
(t t✝ : List α) (h : t.ctorIdx = t✝.ctorIdx) (nil : motive [] [] ⋯)
(cons :
(head : α) → (tail : List α) → (head_1 : α) → (tail_1 : List α) → motive (head :: tail) (head_1 :: tail_1) ⋯) :
motive t t✝ h
-/
#guard_msgs in
#check List.match_on_same_ctor
end List
namespace BadIdx
opaque f : Nat Nat
inductive T : (n : Nat) Type where
| mk1 : Fin n T (f n)
| mk2 : Fin (2*n) T (f n)
run_meta mkCasesOnSameCtorHet `BadIdx.casesOn2Het ``T
/--
error: Dependent elimination failed: Failed to solve equation
f n✝ = f n
-/
#guard_msgs in
run_meta mkCasesOnSameCtor `BadIdx.casesOn2 ``T
end BadIdx

View File

@@ -0,0 +1,58 @@
/-!
This test checks what deriving `DecidableEq` does when the inductive type has
non-injective indices, and just how bad the error messages are.
-/
opaque f : Nat Nat
set_option deriving.decEq.linear_construction_threshold 0
/--
error: Tactic `cases` failed with a nested error:
Dependent elimination failed: Failed to solve equation
f n✝¹ = f n✝
at case `T.mk1` after processing
_, (T.mk1 _ _), _
the dependent pattern matcher can solve the following kinds of equations
- <var> = <term> and <term> = <var>
- <term> = <term> where the terms are definitionally equal
- <constructor> = <constructor>, examples: List.cons x xs = List.cons y ys, and List.cons x xs = List.nil
---
error: Dependent elimination failed: Failed to solve equation
f n✝ = f n
-/
#guard_msgs(pass trace, all) in
inductive T : (n : Nat) Type where
| mk1 : Fin n T (f n)
| mk2 : Fin (2*n) T (f n)
deriving BEq, DecidableEq
set_option deriving.decEq.linear_construction_threshold 10000
/--
error: Tactic `cases` failed with a nested error:
Dependent elimination failed: Failed to solve equation
f n✝¹ = f n✝
at case `T'.mk1` after processing
_, (T'.mk1 _ _), _
the dependent pattern matcher can solve the following kinds of equations
- <var> = <term> and <term> = <var>
- <term> = <term> where the terms are definitionally equal
- <constructor> = <constructor>, examples: List.cons x xs = List.cons y ys, and List.cons x xs = List.nil
---
error: Tactic `cases` failed with a nested error:
Dependent elimination failed: Failed to solve equation
f n✝¹ = f n✝
at case `T'.mk1` after processing
_, (T'.mk1 _ _), _
the dependent pattern matcher can solve the following kinds of equations
- <var> = <term> and <term> = <var>
- <term> = <term> where the terms are definitionally equal
- <constructor> = <constructor>, examples: List.cons x xs = List.cons y ys, and List.cons x xs = List.nil
-/
#guard_msgs(pass trace, all) in
inductive T' : (n : Nat) Type where
| mk1 : Fin n T' (f n)
| mk2 : Fin (2*n) T' (f n)
deriving BEq, DecidableEq

View File

@@ -0,0 +1,39 @@
/-!
Tests for deriving decidable equality using the linear-size parallel match construction that takes
`x1.ctorIdx = x2.ctorIdx` as assumption.
-/
-- We always want to use the new construction in this test
set_option deriving.decEq.linear_construction_threshold 0
inductive EmptyType : Type
deriving DecidableEq
structure SimpleStruct where
field : Bool
deriving DecidableEq
inductive DependentStruct1 : Nat Type where
| mk (n : Nat) (x : Fin n): DependentStruct1 n
deriving DecidableEq
/--
error: Dependent elimination failed: Failed to solve equation
Decidable.rec (fun h => (fun x => 1) h) (fun h => (fun x => 0) h) (instDecidableEqBool b✝ true) =
Decidable.rec (fun h => (fun x => 1) h) (fun h => (fun x => 0) h) (instDecidableEqBool b true)
-/
#guard_msgs in
inductive DependentStruct2 : Nat Type where
| mk (b : Bool) : DependentStruct2 (if b then 0 else 1)
deriving DecidableEq
inductive Vec (α : Type u) : Nat Type u
| nil : Vec α 0
| cons : α {n : Nat} Vec α n Vec α (n+1)
deriving DecidableEq
inductive Test (α : Type)
| mk₀
| mk₁ : (n : Nat) (α × α) List α Vec α n Test α
| mk₂ : Test α α Test α
deriving DecidableEq

View File

@@ -0,0 +1,5 @@
def test1 (x1 x2 : List α) (h : x2.ctorIdx = x1.ctorIdx) : Bool :=
match x1, x2, h with
| .nil, .nil, _h => true
| .cons _h1 _t1, .cons _h2 _t2, _h => false
-- NB: This is a complete pattern match

View File

@@ -5,13 +5,6 @@ inductive Foo (α : Type u) where
| mk4 (val : String)
| mk5 (head : α) (tail : Foo α)
def Foo.ctorIdx : Foo α Nat
| .mk1 .. => 0
| .mk2 .. => 1
| .mk3 .. => 2
| .mk4 .. => 3
| .mk5 .. => 4
@[elab_as_elim]
def Foo.elimCtor1 {motive : Foo α Sort v} (a : Foo α) (hIdx : a.ctorIdx == 0) (h : (val : α) motive (Foo.mk1 val)) : motive a :=
match a with