Compare commits

...

7 Commits

Author SHA1 Message Date
Leonardo de Moura
e080afbe38 test: for simp [x] where x is a let-variable 2023-10-23 19:06:23 -07:00
Leonardo de Moura
1c1436c569 feat: add support for expanding let-declarations to simp
Given a local context containing `x : t := e`,
`simp (config := { zeta := false }) [x]` will expand `x` even
if `zeta := false`.
2023-10-23 19:03:31 -07:00
Leonardo de Moura
54577bc30d chore: fix configuration for UnificationHints 2023-10-23 06:10:08 -07:00
Leonardo de Moura
ab66adaba6 chore: typos and PR feedback
Co-authored-by: Scott Morrison <scott.morrison@gmail.com>

Co-authored-by: Scott Morrison <scott.morrison@gmail.com>

Co-authored-by: Scott Morrison <scott.morrison@gmail.com>
2023-10-23 06:06:55 -07:00
Leonardo de Moura
bc03c21c21 fix: fixes #2669 #2281 2023-10-22 14:53:55 -07:00
Leonardo de Moura
6cd22c45b5 refactor: add configuration options to control WHNF
This commit also removes parameter `simpleReduce` from discrimination
trees, and take WHNF configuration options.
Reason: it is more dynamic now. For example, the simplifier
will be able to use different configurations for discrimination tree insertion
and retrieval. We need this feature to address issues #2669 and #2281

This commit also removes the dead Meta.Config field `zetaNonDep`.
2023-10-22 14:29:55 -07:00
Leonardo de Moura
1fb77d0855 chore: add some doc strings and cleanup 2023-10-22 09:17:17 -07:00
18 changed files with 392 additions and 268 deletions

View File

@@ -101,6 +101,17 @@ private def addDeclToUnfoldOrTheorem (thms : Meta.SimpTheorems) (id : Origin) (e
return thms.addDeclToUnfoldCore declName
else
thms.addDeclToUnfold declName
else if e.isFVar then
let fvarId := e.fvarId!
let decl fvarId.getDecl
if ( isProp decl.type) then
thms.add id #[] e (post := post) (inv := inv)
else if !decl.isLet then
throwError "invalid argument, variable is not a proposition or let-declaration"
else if inv then
throwError "invalid '←' modifier, '{e}' is a let-declaration name to be unfolded"
else
return thms.addLetDeclToUnfold fvarId
else
thms.add id #[] e (post := post) (inv := inv)
@@ -237,6 +248,10 @@ def mkSimpContext (stx : Syntax) (eraseLocal : Bool) (kind := SimpKind.simp) (ig
else
let ctx := r.ctx
let mut simpTheorems := ctx.simpTheorems
/-
When using `zeta := false`, we do not expand let-declarations when using `[*]`.
Users must explicitly include it in the list.
-/
let hs getPropHyps
for h in hs do
unless simpTheorems.isErased (.fvar h) do

View File

@@ -60,8 +60,8 @@ where
-- Drawback: cost.
return e
else match mode with
| .reduce => DiscrTree.reduce e (simpleReduce := false)
| .reduceSimpleOnly => DiscrTree.reduce e (simpleReduce := true)
| .reduce => DiscrTree.reduce e {}
| .reduceSimpleOnly => DiscrTree.reduce e { iota := false, proj := .no }
| .none => return e
lt (a b : Expr) : MetaM Bool := do

View File

@@ -78,14 +78,6 @@ structure Config where
we may want to notify the caller that the TC problem may be solvable
later after it assigns `?m`. -/
isDefEqStuckEx : Bool := false
/--
Controls which definitions and theorems can be unfolded by `isDefEq` and `whnf`.
-/
transparency : TransparencyMode := TransparencyMode.default
/-- If zetaNonDep == false, then non dependent let-decls are not zeta expanded. -/
zetaNonDep : Bool := true
/-- When `trackZeta == true`, we store zetaFVarIds all free variables that have been zeta-expanded. -/
trackZeta : Bool := false
/-- Enable/disable the unification hints feature. -/
unificationHints : Bool := true
/-- Enables proof irrelevance at `isDefEq` -/
@@ -99,8 +91,24 @@ structure Config where
assignSyntheticOpaque : Bool := false
/-- Enable/Disable support for offset constraints such as `?x + 1 =?= e` -/
offsetCnstrs : Bool := true
/--
Controls which definitions and theorems can be unfolded by `isDefEq` and `whnf`.
-/
transparency : TransparencyMode := TransparencyMode.default
/--
When `trackZeta = true`, we track all free variables that have been zeta-expanded.
That is, suppose the local context contains
the declaration `x : t := v`, and we reduce `x` to `v`, then we insert `x` into `State.zetaFVarIds`.
We use `trackZeta` to discover which let-declarations `let x := v; e` can be represented as `(fun x => e) v`.
When we find these declarations we set their `nonDep` flag with `true`.
To find these let-declarations in a given term `s`, we
1- Reset `State.zetaFVarIds`
2- Set `trackZeta := true`
3- Type-check `s`.
-/
trackZeta : Bool := false
/-- Eta for structures configuration mode. -/
etaStruct : EtaStructMode := .all
etaStruct : EtaStructMode := .all
/--
Function parameter information cache.
@@ -366,7 +374,7 @@ section Methods
variable [MonadControlT MetaM n] [Monad n]
@[inline] def modifyCache (f : Cache Cache) : MetaM Unit :=
modify fun mctx, cache, zetaFVarIds, postponed => mctx, f cache, zetaFVarIds, postponed
modify fun { mctx, cache, zetaFVarIds, postponed } => { mctx, cache := f cache, zetaFVarIds, postponed }
@[inline] def modifyInferTypeCache (f : InferTypeCache InferTypeCache) : MetaM Unit :=
modifyCache fun ic, c1, c2, c3, c4, c5, c6 => f ic, c1, c2, c3, c4, c5, c6
@@ -781,6 +789,9 @@ def elimMVarDeps (xs : Array Expr) (e : Expr) (preserveOrder : Bool := false) :
@[inline] def withConfig (f : Config Config) : n α n α :=
mapMetaM <| withReader (fun ctx => { ctx with config := f ctx.config })
/--
Executes `x` tracking zeta reductions `Config.trackZeta := true`
-/
@[inline] def withTrackingZeta (x : n α) : n α :=
withConfig (fun cfg => { cfg with trackZeta := true }) x

View File

@@ -48,7 +48,7 @@ namespace Lean.Meta.DiscrTree
2- Distinguish partial applications `f a`, `f a b`, and `f a b c`.
-/
def Key.ctorIdx : Key s Nat
def Key.ctorIdx : Key Nat
| .star => 0
| .other => 1
| .lit .. => 2
@@ -57,17 +57,17 @@ def Key.ctorIdx : Key s → Nat
| .arrow => 5
| .proj .. => 6
def Key.lt : Key s Key s Bool
def Key.lt : Key Key Bool
| .lit v₁, .lit v₂ => v₁ < v₂
| .fvar n₁ a₁, .fvar n₂ a₂ => Name.quickLt n₁.name n₂.name || (n₁ == n₂ && a₁ < a₂)
| .const n₁ a₁, .const n₂ a₂ => Name.quickLt n₁ n₂ || (n₁ == n₂ && a₁ < a₂)
| .proj s₁ i₁ a₁, .proj s₂ i₂ a₂ => Name.quickLt s₁ s₂ || (s₁ == s₂ && i₁ < i₂) || (s₁ == s₂ && i₁ == i₂ && a₁ < a₂)
| k₁, k₂ => k₁.ctorIdx < k₂.ctorIdx
instance : LT (Key s) := fun a b => Key.lt a b
instance (a b : Key s) : Decidable (a < b) := inferInstanceAs (Decidable (Key.lt a b))
instance : LT Key := fun a b => Key.lt a b
instance (a b : Key) : Decidable (a < b) := inferInstanceAs (Decidable (Key.lt a b))
def Key.format : Key s Format
def Key.format : Key Format
| .star => "*"
| .other => ""
| .lit (Literal.natVal v) => Std.format v
@@ -77,41 +77,41 @@ def Key.format : Key s → Format
| .fvar k _ => Std.format k.name
| .arrow => ""
instance : ToFormat (Key s) := Key.format
instance : ToFormat Key := Key.format
def Key.arity : (Key s) Nat
def Key.arity : Key Nat
| .const _ a => a
| .fvar _ a => a
| .arrow => 2
| .proj _ _ a => 1 + a
| _ => 0
instance : Inhabited (Trie α s) := .node #[] #[]
instance : Inhabited (Trie α) := .node #[] #[]
def empty : DiscrTree α s := { root := {} }
def empty : DiscrTree α := { root := {} }
partial def Trie.format [ToFormat α] : Trie α s Format
partial def Trie.format [ToFormat α] : Trie α Format
| .node vs cs => Format.group $ Format.paren $
"node" ++ (if vs.isEmpty then Format.nil else " " ++ Std.format vs)
++ Format.join (cs.toList.map fun k, c => Format.line ++ Format.paren (Std.format k ++ " => " ++ format c))
instance [ToFormat α] : ToFormat (Trie α s) := Trie.format
instance [ToFormat α] : ToFormat (Trie α) := Trie.format
partial def format [ToFormat α] (d : DiscrTree α s) : Format :=
partial def format [ToFormat α] (d : DiscrTree α) : Format :=
let (_, r) := d.root.foldl
(fun (p : Bool × Format) k c =>
(false, p.2 ++ (if p.1 then Format.nil else Format.line) ++ Format.paren (Std.format k ++ " => " ++ Std.format c)))
(true, Format.nil)
Format.group r
instance [ToFormat α] : ToFormat (DiscrTree α s) := format
instance [ToFormat α] : ToFormat (DiscrTree α) := format
/-- The discrimination tree ignores implicit arguments and proofs.
We use the following auxiliary id as a "mark". -/
private def tmpMVarId : MVarId := { name := `_discr_tree_tmp }
private def tmpStar := mkMVar tmpMVarId
instance : Inhabited (DiscrTree α s) where
instance : Inhabited (DiscrTree α) where
default := {}
/--
@@ -249,16 +249,16 @@ def hasNoindexAnnotation (e : Expr) : Bool :=
/--
Reduction procedure for the discrimination tree indexing.
The parameter `simpleReduce` controls how aggressive the term is reduced.
The parameter `config` controls how aggressively the term is reduced.
The parameter at type `DiscrTree` controls this value.
See comment at `DiscrTree`.
-/
partial def reduce (e : Expr) (simpleReduce : Bool) : MetaM Expr := do
let e whnfCore e (simpleReduceOnly := simpleReduce)
partial def reduce (e : Expr) (config : WhnfCoreConfig) : MetaM Expr := do
let e whnfCore e config
match ( unfoldDefinition? e) with
| some e => reduce e simpleReduce
| some e => reduce e config
| none => match e.etaExpandedStrict? with
| some e => reduce e simpleReduce
| some e => reduce e config
| none => return e
/--
@@ -307,31 +307,31 @@ private def elimLooseBVarsByBeta (e : Expr) : CoreM Expr :=
Reduce `e` until we get an irreducible term (modulo current reducibility setting) or the resulting term
is a bad key (see comment at `isBadKey`).
We use this method instead of `reduce` for root terms at `pushArgs`. -/
private partial def reduceUntilBadKey (e : Expr) (simpleReduce : Bool) : MetaM Expr := do
private partial def reduceUntilBadKey (e : Expr) (config : WhnfCoreConfig) : MetaM Expr := do
let e step e
match e.etaExpandedStrict? with
| some e => reduceUntilBadKey e simpleReduce
| some e => reduceUntilBadKey e config
| none => return e
where
step (e : Expr) := do
let e whnfCore e (simpleReduceOnly := simpleReduce)
let e whnfCore e config
match ( unfoldDefinition? e) with
| some e' => if isBadKey e'.getAppFn then return e else step e'
| none => return e
/-- whnf for the discrimination tree module -/
def reduceDT (e : Expr) (root : Bool) (simpleReduce : Bool) : MetaM Expr :=
if root then reduceUntilBadKey e simpleReduce else reduce e simpleReduce
def reduceDT (e : Expr) (root : Bool) (config : WhnfCoreConfig) : MetaM Expr :=
if root then reduceUntilBadKey e config else reduce e config
/- Remark: we use `shouldAddAsStar` only for nested terms, and `root == false` for nested terms -/
private def pushArgs (root : Bool) (todo : Array Expr) (e : Expr) : MetaM (Key s × Array Expr) := do
private def pushArgs (root : Bool) (todo : Array Expr) (e : Expr) (config : WhnfCoreConfig) : MetaM (Key × Array Expr) := do
if hasNoindexAnnotation e then
return (.star, todo)
else
let e reduceDT e root (simpleReduce := s)
let e reduceDT e root config
let fn := e.getAppFn
let push (k : Key s) (nargs : Nat) (todo : Array Expr): MetaM (Key s × Array Expr) := do
let push (k : Key) (nargs : Nat) (todo : Array Expr): MetaM (Key × Array Expr) := do
let info getFunInfoNArgs fn nargs
let todo pushArgsAux info.paramInfo (nargs-1) e todo
return (k, todo)
@@ -377,24 +377,24 @@ private def pushArgs (root : Bool) (todo : Array Expr) (e : Expr) : MetaM (Key s
| _ =>
return (.other, todo)
partial def mkPathAux (root : Bool) (todo : Array Expr) (keys : Array (Key s)) : MetaM (Array (Key s)) := do
partial def mkPathAux (root : Bool) (todo : Array Expr) (keys : Array Key) (config : WhnfCoreConfig) : MetaM (Array Key) := do
if todo.isEmpty then
return keys
else
let e := todo.back
let todo := todo.pop
let (k, todo) pushArgs root todo e
mkPathAux false todo (keys.push k)
let (k, todo) pushArgs root todo e config
mkPathAux false todo (keys.push k) config
private def initCapacity := 8
def mkPath (e : Expr) : MetaM (Array (Key s)) := do
def mkPath (e : Expr) (config : WhnfCoreConfig) : MetaM (Array Key) := do
withReducible do
let todo : Array Expr := .mkEmpty initCapacity
let keys : Array (Key s) := .mkEmpty initCapacity
mkPathAux (root := true) (todo.push e) keys
let keys : Array Key := .mkEmpty initCapacity
mkPathAux (root := true) (todo.push e) keys config
private partial def createNodes (keys : Array (Key s)) (v : α) (i : Nat) : Trie α s :=
private partial def createNodes (keys : Array Key) (v : α) (i : Nat) : Trie α :=
if h : i < keys.size then
let k := keys.get i, h
let c := createNodes keys v (i+1)
@@ -421,20 +421,20 @@ where
vs.push v
termination_by loop i => vs.size - i
private partial def insertAux [BEq α] (keys : Array (Key s)) (v : α) : Nat Trie α s Trie α s
private partial def insertAux [BEq α] (keys : Array Key) (v : α) (config : WhnfCoreConfig) : Nat Trie α Trie α
| i, .node vs cs =>
if h : i < keys.size then
let k := keys.get i, h
let c := Id.run $ cs.binInsertM
(fun a b => a.1 < b.1)
(fun _, s => let c := insertAux keys v (i+1) s; (k, c)) -- merge with existing
(fun _, s => let c := insertAux keys v config (i+1) s; (k, c)) -- merge with existing
(fun _ => let c := createNodes keys v (i+1); (k, c))
(k, default)
.node vs c
else
.node (insertVal vs v) cs
def insertCore [BEq α] (d : DiscrTree α s) (keys : Array (Key s)) (v : α) : DiscrTree α s :=
def insertCore [BEq α] (d : DiscrTree α) (keys : Array Key) (v : α) (config : WhnfCoreConfig) : DiscrTree α :=
if keys.isEmpty then panic! "invalid key sequence"
else
let k := keys[0]!
@@ -443,15 +443,15 @@ def insertCore [BEq α] (d : DiscrTree α s) (keys : Array (Key s)) (v : α) : D
let c := createNodes keys v 1
{ root := d.root.insert k c }
| some c =>
let c := insertAux keys v 1 c
let c := insertAux keys v config 1 c
{ root := d.root.insert k c }
def insert [BEq α] (d : DiscrTree α s) (e : Expr) (v : α) : MetaM (DiscrTree α s) := do
let keys mkPath e
return d.insertCore keys v
def insert [BEq α] (d : DiscrTree α) (e : Expr) (v : α) (config : WhnfCoreConfig) : MetaM (DiscrTree α) := do
let keys mkPath e config
return d.insertCore keys v config
private def getKeyArgs (e : Expr) (isMatch root : Bool) : MetaM (Key s × Array Expr) := do
let e reduceDT e root (simpleReduce := s)
private def getKeyArgs (e : Expr) (isMatch root : Bool) (config : WhnfCoreConfig) : MetaM (Key × Array Expr) := do
let e reduceDT e root config
unless root do
-- See pushArgs
if let some v := toNatLit? e then
@@ -530,22 +530,22 @@ private def getKeyArgs (e : Expr) (isMatch root : Bool) : MetaM (Key s × Array
| _ =>
return (.other, #[])
private abbrev getMatchKeyArgs (e : Expr) (root : Bool) : MetaM (Key s × Array Expr) :=
getKeyArgs e (isMatch := true) (root := root)
private abbrev getMatchKeyArgs (e : Expr) (root : Bool) (config : WhnfCoreConfig) : MetaM (Key × Array Expr) :=
getKeyArgs e (isMatch := true) (root := root) (config := config)
private abbrev getUnifyKeyArgs (e : Expr) (root : Bool) : MetaM (Key s × Array Expr) :=
getKeyArgs e (isMatch := false) (root := root)
private abbrev getUnifyKeyArgs (e : Expr) (root : Bool) (config : WhnfCoreConfig) : MetaM (Key × Array Expr) :=
getKeyArgs e (isMatch := false) (root := root) (config := config)
private def getStarResult (d : DiscrTree α s) : Array α :=
private def getStarResult (d : DiscrTree α) : Array α :=
let result : Array α := .mkEmpty initCapacity
match d.root.find? .star with
| none => result
| some (.node vs _) => result ++ vs
private abbrev findKey (cs : Array (Key s × Trie α s)) (k : Key s) : Option (Key s × Trie α s) :=
private abbrev findKey (cs : Array (Key × Trie α)) (k : Key) : Option (Key × Trie α) :=
cs.binSearch (k, default) (fun a b => a.1 < b.1)
private partial def getMatchLoop (todo : Array Expr) (c : Trie α s) (result : Array α) : MetaM (Array α) := do
private partial def getMatchLoop (todo : Array Expr) (c : Trie α) (result : Array α) (config : WhnfCoreConfig) : MetaM (Array α) := do
match c with
| .node vs cs =>
if todo.isEmpty then
@@ -556,19 +556,19 @@ private partial def getMatchLoop (todo : Array Expr) (c : Trie α s) (result : A
let e := todo.back
let todo := todo.pop
let first := cs[0]! /- Recall that `Key.star` is the minimal key -/
let (k, args) getMatchKeyArgs e (root := false)
let (k, args) getMatchKeyArgs e (root := false) config
/- We must always visit `Key.star` edges since they are wildcards.
Thus, `todo` is not used linearly when there is `Key.star` edge
and there is an edge for `k` and `k != Key.star`. -/
let visitStar (result : Array α) : MetaM (Array α) :=
if first.1 == .star then
getMatchLoop todo first.2 result
getMatchLoop todo first.2 result config
else
return result
let visitNonStar (k : Key s) (args : Array Expr) (result : Array α) : MetaM (Array α) :=
let visitNonStar (k : Key) (args : Array Expr) (result : Array α) : MetaM (Array α) :=
match findKey cs k with
| none => return result
| some c => getMatchLoop (todo ++ args) c.2 result
| some c => getMatchLoop (todo ++ args) c.2 result config
let result visitStar result
match k with
| .star => return result
@@ -580,32 +580,32 @@ private partial def getMatchLoop (todo : Array Expr) (c : Trie α s) (result : A
| .arrow => visitNonStar .other #[] ( visitNonStar k args result)
| _ => visitNonStar k args result
private def getMatchRoot (d : DiscrTree α s) (k : Key s) (args : Array Expr) (result : Array α) : MetaM (Array α) :=
private def getMatchRoot (d : DiscrTree α) (k : Key) (args : Array Expr) (result : Array α) (config : WhnfCoreConfig) : MetaM (Array α) :=
match d.root.find? k with
| none => return result
| some c => getMatchLoop args c result
| some c => getMatchLoop args c result config
private def getMatchCore (d : DiscrTree α s) (e : Expr) : MetaM (Key s × Array α) :=
private def getMatchCore (d : DiscrTree α) (e : Expr) (config : WhnfCoreConfig) : MetaM (Key × Array α) :=
withReducible do
let result := getStarResult d
let (k, args) getMatchKeyArgs e (root := true)
let (k, args) getMatchKeyArgs e (root := true) config
match k with
| .star => return (k, result)
/- See note about "dep-arrow vs arrow" at `getMatchLoop` -/
| .arrow => return (k, ( getMatchRoot d k args ( getMatchRoot d .other #[] result)))
| _ => return (k, ( getMatchRoot d k args result))
| .arrow => return (k, ( getMatchRoot d k args ( getMatchRoot d .other #[] result config) config))
| _ => return (k, ( getMatchRoot d k args result config))
/--
Find values that match `e` in `d`.
-/
def getMatch (d : DiscrTree α s) (e : Expr) : MetaM (Array α) :=
return ( getMatchCore d e).2
def getMatch (d : DiscrTree α) (e : Expr) (config : WhnfCoreConfig) : MetaM (Array α) :=
return ( getMatchCore d e config).2
/--
Similar to `getMatch`, but returns solutions that are prefixes of `e`.
We store the number of ignored arguments in the result.-/
partial def getMatchWithExtra (d : DiscrTree α s) (e : Expr) : MetaM (Array (α × Nat)) := do
let (k, result) getMatchCore d e
partial def getMatchWithExtra (d : DiscrTree α) (e : Expr) (config : WhnfCoreConfig) : MetaM (Array (α × Nat)) := do
let (k, result) getMatchCore d e config
let result := result.map (·, 0)
if !e.isApp then
return result
@@ -614,8 +614,8 @@ partial def getMatchWithExtra (d : DiscrTree α s) (e : Expr) : MetaM (Array (α
else
go e.appFn! 1 result
where
mayMatchPrefix (k : Key s) : MetaM Bool :=
let cont (k : Key s) : MetaM Bool :=
mayMatchPrefix (k : Key) : MetaM Bool :=
let cont (k : Key) : MetaM Bool :=
if d.root.find? k |>.isSome then
return true
else
@@ -627,15 +627,15 @@ where
| _ => return false
go (e : Expr) (numExtra : Nat) (result : Array (α × Nat)) : MetaM (Array (α × Nat)) := do
let result := result ++ ( getMatch d e).map (., numExtra)
let result := result ++ ( getMatchCore d e config).2.map (., numExtra)
if e.isApp then
go e.appFn! (numExtra + 1) result
else
return result
partial def getUnify (d : DiscrTree α s) (e : Expr) : MetaM (Array α) :=
partial def getUnify (d : DiscrTree α) (e : Expr) (config : WhnfCoreConfig) : MetaM (Array α) :=
withReducible do
let (k, args) getUnifyKeyArgs e (root := true)
let (k, args) getUnifyKeyArgs e (root := true) config
match k with
| .star => d.root.foldlM (init := #[]) fun result k c => process k.arity #[] c result
| _ =>
@@ -644,7 +644,7 @@ partial def getUnify (d : DiscrTree α s) (e : Expr) : MetaM (Array α) :=
| none => return result
| some c => process 0 args c result
where
process (skip : Nat) (todo : Array Expr) (c : Trie α s) (result : Array α) : MetaM (Array α) := do
process (skip : Nat) (todo : Array Expr) (c : Trie α) (result : Array α) : MetaM (Array α) := do
match skip, c with
| skip+1, .node _ cs =>
if cs.isEmpty then
@@ -659,14 +659,14 @@ where
else
let e := todo.back
let todo := todo.pop
let (k, args) getUnifyKeyArgs e (root := false)
let (k, args) getUnifyKeyArgs e (root := false) config
let visitStar (result : Array α) : MetaM (Array α) :=
let first := cs[0]!
if first.1 == .star then
process 0 todo first.2 result
else
return result
let visitNonStar (k : Key s) (args : Array Expr) (result : Array α) : MetaM (Array α) :=
let visitNonStar (k : Key) (args : Array Expr) (result : Array α) : MetaM (Array α) :=
match findKey cs k with
| none => return result
| some c => process 0 (todo ++ args) c.2 result

View File

@@ -13,55 +13,50 @@ namespace DiscrTree
/--
Discrimination tree key. See `DiscrTree`
-/
inductive Key (simpleReduce : Bool) where
| const : Name Nat Key simpleReduce
| fvar : FVarId Nat Key simpleReduce
| lit : Literal Key simpleReduce
| star : Key simpleReduce
| other : Key simpleReduce
| arrow : Key simpleReduce
| proj : Name Nat Nat Key simpleReduce
inductive Key where
| const : Name Nat Key
| fvar : FVarId Nat Key
| lit : Literal Key
| star : Key
| other : Key
| arrow : Key
| proj : Name Nat Nat Key
deriving Inhabited, BEq, Repr
protected def Key.hash : Key s UInt64
| Key.const n a => mixHash 5237 $ mixHash (hash n) (hash a)
| Key.fvar n a => mixHash 3541 $ mixHash (hash n) (hash a)
| Key.lit v => mixHash 1879 $ hash v
| Key.star => 7883
| Key.other => 2411
| Key.arrow => 17
| Key.proj s i a => mixHash (hash a) $ mixHash (hash s) (hash i)
protected def Key.hash : Key UInt64
| .const n a => mixHash 5237 $ mixHash (hash n) (hash a)
| .fvar n a => mixHash 3541 $ mixHash (hash n) (hash a)
| .lit v => mixHash 1879 $ hash v
| .star => 7883
| .other => 2411
| .arrow => 17
| .proj s i a => mixHash (hash a) $ mixHash (hash s) (hash i)
instance : Hashable (Key s) := Key.hash
instance : Hashable Key := Key.hash
/--
Discrimination tree trie. See `DiscrTree`.
-/
inductive Trie (α : Type) (simpleReduce : Bool) where
| node (vs : Array α) (children : Array (Key simpleReduce × Trie α simpleReduce)) : Trie α simpleReduce
inductive Trie (α : Type) where
| node (vs : Array α) (children : Array (Key × Trie α)) : Trie α
end DiscrTree
open DiscrTree
/--
Discrimination trees. It is an index from terms to values of type `α`.
/-!
Notes regarding term reduction at the `DiscrTree` module.
If `simpleReduce := true`, then only simple reduction are performed while
indexing/retrieving terms. For example, `iota` reduction is not performed.
We use `simpleReduce := false` in the type class resolution module,
and `simpleReduce := true` in `simp`.
Motivations:
- In `simp`, we want to have `simp` theorem such as
```
@[simp] theorem liftOn_mk (a : α) (f : αγ) (h : ∀ a₁ a₂, r a₁ a₂ → f a₁ = f a₂) :
Quot.liftOn (Quot.mk r a) f h = f a := rfl
```
If we enable `iota`, then the lhs is reduced to `f a`.
Note that when retrieving terms, we may also disable `beta` and `zeta` reduction.
See issue https://github.com/leanprover/lean4/issues/2669
- During type class resolution, we often want to reduce types using even `iota`.
- During type class resolution, we often want to reduce types using even `iota` and projection reductionn.
Example:
```
inductive Ty where
@@ -80,7 +75,11 @@ def f (a b : Ty.bool.interp) : Ty.bool.interp :=
test (.==.) a b
```
-/
structure DiscrTree (α : Type) (simpleReduce : Bool) where
root : PersistentHashMap (Key simpleReduce) (Trie α simpleReduce) := {}
/--
Discrimination trees. It is an index from terms to values of type `α`.
-/
structure DiscrTree (α : Type) where
root : PersistentHashMap Key (Trie α) := {}
end Lean.Meta

View File

@@ -37,7 +37,7 @@ def f (a b : Ty.bool.interp) : Ty.bool.interp :=
See comment at `DiscrTree`.
-/
abbrev InstanceKey := DiscrTree.Key (simpleReduce := false)
abbrev InstanceKey := DiscrTree.Key
structure InstanceEntry where
keys : Array InstanceKey
@@ -63,7 +63,7 @@ instance : ToFormat InstanceEntry where
| some n => format n
| _ => "<local>"
abbrev InstanceTree := DiscrTree InstanceEntry (simpleReduce := false)
abbrev InstanceTree := DiscrTree InstanceEntry
structure Instances where
discrTree : InstanceTree := DiscrTree.empty
@@ -71,10 +71,13 @@ structure Instances where
erased : PHashSet Name := {}
deriving Inhabited
/-- Configuration for the discrimination tree module -/
def tcDtConfig : WhnfCoreConfig := {}
def addInstanceEntry (d : Instances) (e : InstanceEntry) : Instances :=
match e.globalName? with
| some n => { d with discrTree := d.discrTree.insertCore e.keys e, instanceNames := d.instanceNames.insert n e, erased := d.erased.erase n }
| none => { d with discrTree := d.discrTree.insertCore e.keys e }
| some n => { d with discrTree := d.discrTree.insertCore e.keys e tcDtConfig, instanceNames := d.instanceNames.insert n e, erased := d.erased.erase n }
| none => { d with discrTree := d.discrTree.insertCore e.keys e tcDtConfig }
def Instances.eraseCore (d : Instances) (declName : Name) : Instances :=
{ d with erased := d.erased.insert declName, instanceNames := d.instanceNames.erase declName }
@@ -94,7 +97,7 @@ private def mkInstanceKey (e : Expr) : MetaM (Array InstanceKey) := do
let type inferType e
withNewMCtxDepth do
let (_, _, type) forallMetaTelescopeReducing type
DiscrTree.mkPath type
DiscrTree.mkPath type tcDtConfig
/--
Compute the order the arguments of `inst` should by synthesized.
@@ -207,7 +210,7 @@ builtin_initialize
modifyEnv fun env => instanceExtension.modifyState env fun _ => s
}
def getGlobalInstancesIndex : CoreM (DiscrTree InstanceEntry (simpleReduce := false)) :=
def getGlobalInstancesIndex : CoreM (DiscrTree InstanceEntry) :=
return Meta.instanceExtension.getState ( getEnv) |>.discrTree
def getErasedInstances : CoreM (PHashSet Name) :=

View File

@@ -194,7 +194,7 @@ def getInstances (type : Expr) : MetaM (Array Instance) := do
| none => throwError "type class instance expected{indentExpr type}"
| some className =>
let globalInstances getGlobalInstancesIndex
let result globalInstances.getUnify type
let result globalInstances.getUnify type tcDtConfig
-- Using insertion sort because it is stable and the array `result` should be mostly sorted.
-- Most instances have default priority.
let result := result.insertionSort fun e₁ e₂ => e₁.priority < e₂.priority

View File

@@ -119,8 +119,8 @@ private def reduceProjFn? (e : Expr) : SimpM (Option Expr) := do
-- `structure` projections
reduceProjCont? ( unfoldDefinition? e)
private def reduceFVar (cfg : Config) (e : Expr) : MetaM Expr := do
if cfg.zeta then
private def reduceFVar (cfg : Config) (thms : SimpTheoremsArray) (e : Expr) : MetaM Expr := do
if cfg.zeta || thms.isLetDeclToUnfold e.fvarId! then
match ( getFVarLocalDecl e).value? with
| some v => return v
| none => return e
@@ -254,8 +254,8 @@ private partial def dsimp (e : Expr) : M Expr := do
if r.expr != e then
return .visit r.expr
let mut eNew reduce e
if cfg.zeta && eNew.isFVar then
eNew reduceFVar cfg eNew
if eNew.isFVar then
eNew reduceFVar cfg ( getSimpTheorems) eNew
if eNew != e then return .visit eNew else return .done e
transform (usedLetOnly := cfg.zeta) e (pre := pre) (post := post)
@@ -363,7 +363,7 @@ where
| Expr.sort .. => return { expr := e }
| Expr.lit .. => simpLit e
| Expr.mvar .. => return { expr := ( instantiateMVars e) }
| Expr.fvar .. => return { expr := ( reduceFVar ( getConfig) e) }
| Expr.fvar .. => return { expr := ( reduceFVar ( getConfig) ( getSimpTheorems) e) }
simpLit (e : Expr) : M Result := do
match e.natLit? with

View File

@@ -142,11 +142,24 @@ def tryTheorem? (e : Expr) (thm : SimpTheorem) (discharge? : Expr → SimpM (Opt
tryTheoremCore lhs xs bis val type e thm (eNumArgs - lhsNumArgs) discharge?
else
return none
/--
Return a WHNF configuration for retrieving `[simp]` from the discrimination tree.
If user has disabled `zeta` and/or `beta` reduction in the simplifier, we must also
disable them when retrieving lemmas from discrimination tree. See issues: #2669 and #2281
-/
def getDtConfig (cfg : Config) : WhnfCoreConfig :=
match cfg.beta, cfg.zeta with
| true, true => simpDtConfig
| true, false => { simpDtConfig with zeta := false }
| false, true => { simpDtConfig with beta := false }
| false, false => { simpDtConfig with beta := false, zeta := false }
/--
Remark: the parameter tag is used for creating trace messages. It is irrelevant otherwise.
-/
def rewrite? (e : Expr) (s : SimpTheoremTree) (erased : PHashSet Origin) (discharge? : Expr SimpM (Option Expr)) (tag : String) (rflOnly : Bool) : SimpM (Option Result) := do
let candidates s.getMatchWithExtra e
let candidates s.getMatchWithExtra e (getDtConfig ( getConfig))
if candidates.isEmpty then
trace[Debug.Meta.Tactic.simp] "no theorems found for {tag}-rewriting {e}"
return none

View File

@@ -61,7 +61,7 @@ If we use `iota`, then the lhs is reduced to `f a`.
See comment at `DiscrTree`.
-/
abbrev SimpTheoremKey := DiscrTree.Key (simpleReduce := true)
abbrev SimpTheoremKey := DiscrTree.Key
/--
The fields `levelParams` and `proof` are used to encode the proof of the simp theorem.
@@ -151,22 +151,29 @@ def ppSimpTheorem [Monad m] [MonadLiftT IO m] [MonadEnv m] [MonadError m] (s : S
instance : BEq SimpTheorem where
beq e₁ e₂ := e₁.proof == e₂.proof
abbrev SimpTheoremTree := DiscrTree SimpTheorem (simpleReduce := true)
abbrev SimpTheoremTree := DiscrTree SimpTheorem
structure SimpTheorems where
pre : SimpTheoremTree := DiscrTree.empty
post : SimpTheoremTree := DiscrTree.empty
lemmaNames : PHashSet Origin := {}
/--
Constants (and let-declaration `FVarId`) to unfold.
When `zeta := false`, the simplifier will expand a let-declaration if it is in this set.
-/
toUnfold : PHashSet Name := {}
erased : PHashSet Origin := {}
toUnfoldThms : PHashMap Name (Array Name) := {}
deriving Inhabited
/-- Configuration for the discrimination tree. -/
def simpDtConfig : WhnfCoreConfig := { iota := false, proj := .no }
def addSimpTheoremEntry (d : SimpTheorems) (e : SimpTheorem) : SimpTheorems :=
if e.post then
{ d with post := d.post.insertCore e.keys e, lemmaNames := updateLemmaNames d.lemmaNames }
{ d with post := d.post.insertCore e.keys e simpDtConfig, lemmaNames := updateLemmaNames d.lemmaNames }
else
{ d with pre := d.pre.insertCore e.keys e, lemmaNames := updateLemmaNames d.lemmaNames }
{ d with pre := d.pre.insertCore e.keys e simpDtConfig, lemmaNames := updateLemmaNames d.lemmaNames }
where
updateLemmaNames (s : PHashSet Origin) : PHashSet Origin :=
s.insert e.origin
@@ -174,10 +181,17 @@ where
def SimpTheorems.addDeclToUnfoldCore (d : SimpTheorems) (declName : Name) : SimpTheorems :=
{ d with toUnfold := d.toUnfold.insert declName }
def SimpTheorems.addLetDeclToUnfold (d : SimpTheorems) (fvarId : FVarId) : SimpTheorems :=
-- A small hack that relies on the fact that constants and `FVarId` names should be disjoint.
{ d with toUnfold := d.toUnfold.insert fvarId.name }
/-- Return `true` if `declName` is tagged to be unfolded using `unfoldDefinition?` (i.e., without using equational theorems). -/
def SimpTheorems.isDeclToUnfold (d : SimpTheorems) (declName : Name) : Bool :=
d.toUnfold.contains declName
def SimpTheorems.isLetDeclToUnfold (d : SimpTheorems) (fvarId : FVarId) : Bool :=
d.toUnfold.contains fvarId.name -- See comment at `addLetDeclToUnfold`
def SimpTheorems.isLemma (d : SimpTheorems) (thmId : Origin) : Bool :=
d.lemmaNames.contains thmId
@@ -218,7 +232,7 @@ private partial def isPerm : Expr → Expr → MetaM Bool
| s, t => return s == t
private def checkBadRewrite (lhs rhs : Expr) : MetaM Unit := do
let lhs DiscrTree.reduceDT lhs (root := true) (simpleReduce := true)
let lhs DiscrTree.reduceDT lhs (root := true) simpDtConfig
if lhs == rhs && lhs.isFVar then
throwError "invalid `simp` theorem, equation is equivalent to{indentExpr (← mkEq lhs rhs)}"
@@ -305,7 +319,7 @@ private def mkSimpTheoremCore (origin : Origin) (e : Expr) (levelParams : Array
let type whnfR type
let (keys, perm)
match type.eq? with
| some (_, lhs, rhs) => pure ( DiscrTree.mkPath lhs, isPerm lhs rhs)
| some (_, lhs, rhs) => pure ( DiscrTree.mkPath lhs simpDtConfig, isPerm lhs rhs)
| none => throwError "unexpected kind of 'simp' theorem{indentExpr type}"
return { origin, keys, perm, post, levelParams, proof, priority := prio, rfl := ( isRflProof proof) }
@@ -467,6 +481,9 @@ def SimpTheoremsArray.isErased (thmsArray : SimpTheoremsArray) (thmId : Origin)
def SimpTheoremsArray.isDeclToUnfold (thmsArray : SimpTheoremsArray) (declName : Name) : Bool :=
thmsArray.any fun thms => thms.isDeclToUnfold declName
def SimpTheoremsArray.isLetDeclToUnfold (thmsArray : SimpTheoremsArray) (fvarId : FVarId) : Bool :=
thmsArray.any fun thms => thms.isLetDeclToUnfold fvarId
macro (name := _root_.Lean.Parser.Command.registerSimpAttr) doc:(docComment)?
"register_simp_attr" id:ident : command => do
let str := id.getId.toString

View File

@@ -80,8 +80,8 @@ def post (e : Expr) : M Step := do
def discharge? (e : Expr) : M (Option Expr) := do
( read).discharge? e
def getConfig : M Config :=
return ( readThe Context).config
def getConfig : SimpM Config :=
return ( read).config
@[inline] def withParent (parent : Expr) (f : M α) : M α :=
withTheReader Context (fun ctx => { ctx with parent? := parent }) f

View File

@@ -10,14 +10,14 @@ import Lean.Meta.SynthInstance
namespace Lean.Meta
abbrev UnificationHintKey := DiscrTree.Key (simpleReduce := true)
abbrev UnificationHintKey := DiscrTree.Key
structure UnificationHintEntry where
keys : Array UnificationHintKey
val : Name
deriving Inhabited
abbrev UnificationHintTree := DiscrTree Name (simpleReduce := true)
abbrev UnificationHintTree := DiscrTree Name
structure UnificationHints where
discrTree : UnificationHintTree := DiscrTree.empty
@@ -26,8 +26,10 @@ structure UnificationHints where
instance : ToFormat UnificationHints where
format h := format h.discrTree
def UnificationHints.config : WhnfCoreConfig := { iota := false, proj := .no }
def UnificationHints.add (hints : UnificationHints) (e : UnificationHintEntry) : UnificationHints :=
{ hints with discrTree := hints.discrTree.insertCore e.keys e.val }
{ hints with discrTree := hints.discrTree.insertCore e.keys e.val config }
builtin_initialize unificationHintExtension : SimpleScopedEnvExtension UnificationHintEntry UnificationHints
registerSimpleScopedEnvExtension {
@@ -78,7 +80,7 @@ def addUnificationHint (declName : Name) (kind : AttributeKind) : MetaM Unit :=
match decodeUnificationHint body with
| Except.error msg => throwError msg
| Except.ok hint =>
let keys DiscrTree.mkPath hint.pattern.lhs
let keys DiscrTree.mkPath hint.pattern.lhs UnificationHints.config
validateHint hint
unificationHintExtension.add { keys := keys, val := declName } kind
@@ -98,7 +100,7 @@ def tryUnificationHints (t s : Expr) : MetaM Bool := do
if t.isMVar then
return false
let hints := unificationHintExtension.getState ( getEnv)
let candidates hints.discrTree.getMatch t
let candidates hints.discrTree.getMatch t UnificationHints.config
for candidate in candidates do
if ( tryCandidate candidate) then
return true

View File

@@ -59,33 +59,32 @@ def isAuxDef (constName : Name) : MetaM Bool := do
let env getEnv
return isAuxRecursor env constName || isNoConfusion env constName
@[inline] private def matchConstAux {α} (e : Expr) (failK : Unit MetaM α) (k : ConstantInfo List Level MetaM α) : MetaM α :=
match e with
| Expr.const name lvls => do
let (some cinfo) getUnfoldableConst? name | failK ()
k cinfo lvls
| _ => failK ()
@[inline] private def matchConstAux {α} (e : Expr) (failK : Unit MetaM α) (k : ConstantInfo List Level MetaM α) : MetaM α := do
let .const name lvls := e
| failK ()
let (some cinfo) getUnfoldableConst? name
| failK ()
k cinfo lvls
-- ===========================
/-! # Helper functions for reducing recursors -/
-- ===========================
private def getFirstCtor (d : Name) : MetaM (Option Name) := do
let some (ConstantInfo.inductInfo { ctors := ctor::_, ..}) getUnfoldableConstNoEx? d | pure none
let some (ConstantInfo.inductInfo { ctors := ctor::_, ..}) getUnfoldableConstNoEx? d |
return none
return some ctor
private def mkNullaryCtor (type : Expr) (nparams : Nat) : MetaM (Option Expr) := do
match type.getAppFn with
| Expr.const d lvls =>
let (some ctor) getFirstCtor d | pure none
return mkAppN (mkConst ctor lvls) (type.getAppArgs.shrink nparams)
| _ =>
return none
let .const d lvls := type.getAppFn
| return none
let (some ctor) getFirstCtor d | pure none
return mkAppN (mkConst ctor lvls) (type.getAppArgs.shrink nparams)
private def getRecRuleFor (recVal : RecursorVal) (major : Expr) : Option RecursorRule :=
match major.getAppFn with
| Expr.const fn _ => recVal.rules.find? fun r => r.ctor == fn
| _ => none
| .const fn _ => recVal.rules.find? fun r => r.ctor == fn
| _ => none
private def toCtorWhenK (recVal : RecursorVal) (major : Expr) : MetaM Expr := do
let majorType inferType major
@@ -165,7 +164,7 @@ private def reduceRec (recVal : RecursorVal) (recLvls : List Level) (recArgs : A
let majorIdx := recVal.getMajorIdx
if h : majorIdx < recArgs.size then do
let major := recArgs.get majorIdx, h
let mut major if isWFRec recVal.name && ( getTransparency) == TransparencyMode.default then
let mut major if isWFRec recVal.name && ( getTransparency) == .default then
-- If recursor is `Acc.rec` or `WellFounded.rec` and transparency is default,
-- then we bump transparency to .all to make sure we can unfold defs defined by WellFounded recursion.
-- We use this trick because we abstract nested proofs occurring in definitions.
@@ -307,8 +306,55 @@ end
/-! # Weak Head Normal Form auxiliary combinators -/
-- ===========================
/--
Configuration for projection reduction. See `whnfCore`.
-/
inductive ProjReductionKind where
/-- Projections `s.i` are not reduced at `whnfCore`. -/
| no
/--
Projections `s.i` are reduced at `whnfCore`, and `whnfCore` is used at `s` during the process.
Recall that `whnfCore` does not perform `delta` reduction (i.e., it will not unfold constant declarations).
-/
| yes
/--
Projections `s.i` are reduced at `whnfCore`, and `whnf` is used at `s` during the process.
Recall that `whnfCore` does not perform `delta` reduction (i.e., it will not unfold constant declarations), but `whnf` does.
-/
| yesWithDelta
deriving DecidableEq, Inhabited, Repr
/--
Configuration options for `whnfEasyCases` and `whnfCore`.
-/
structure WhnfCoreConfig where
/-- If `true`, reduce recursor/matcher applications, e.g., `Nat.rec true (fun _ _ => false) Nat.zero` reduces to `true` -/
iota : Bool := true
/-- If `true`, reduce terms such as `(fun x => t[x]) a` into `t[a]` -/
beta : Bool := true
/-- Control projection reduction at `whnfCore`. -/
proj : ProjReductionKind := .yesWithDelta
/--
Zeta reduction.
It includes two kinds of reduction:
- `let x := v; e[x]` reduces to `e[v]`.
- Given a local context containing entry `x : t := e`, free variable `x` reduces to `e`.
We say a let-declaration `let x := v; e` is non dependent if it is equivalent to `(fun x => e) v`.
Recall that
```
fun x : BitVec 5 => let n := 5; fun y : BitVec n => x = y
```
is type correct, but
```
fun x : BitVec 5 => (fun n => fun y : BitVec n => x = y) 5
```
is not.
-/
zeta : Bool := true
/-- Auxiliary combinator for handling easy WHNF cases. It takes a function for handling the "hard" cases as an argument -/
@[specialize] partial def whnfEasyCases (e : Expr) (k : Expr MetaM Expr) : MetaM Expr := do
@[specialize] partial def whnfEasyCases (e : Expr) (k : Expr MetaM Expr) (config : WhnfCoreConfig := {}) : MetaM Expr := do
match e with
| .forallE .. => return e
| .lam .. => return e
@@ -319,22 +365,19 @@ end
| .const .. => k e
| .app .. => k e
| .proj .. => k e
| .mdata _ e => whnfEasyCases e k
| .mdata _ e => whnfEasyCases e k config
| .fvar fvarId =>
let decl fvarId.getDecl
match decl with
| .cdecl .. => return e
| .ldecl (value := v) (nonDep := nonDep) .. =>
let cfg getConfig
if nonDep && !cfg.zetaNonDep then
return e
else
if cfg.trackZeta then
modify fun s => { s with zetaFVarIds := s.zetaFVarIds.insert fvarId }
whnfEasyCases v k
| .ldecl (value := v) .. =>
unless config.zeta do return e
if ( getConfig).trackZeta then
modify fun s => { s with zetaFVarIds := s.zetaFVarIds.insert fvarId }
whnfEasyCases v k config
| .mvar mvarId =>
match ( getExprMVarAssignment? mvarId) with
| some v => whnfEasyCases v k
| some v => whnfEasyCases v k config
| none => return e
@[specialize] private def deltaDefinition (c : ConstantInfo) (lvls : List Level)
@@ -389,8 +432,8 @@ inductive ReduceMatcherResult where
-/
def canUnfoldAtMatcher (cfg : Config) (info : ConstantInfo) : CoreM Bool := do
match cfg.transparency with
| TransparencyMode.all => return true
| TransparencyMode.default => return !( isIrreducible info.name)
| .all => return true
| .default => return !( isIrreducible info.name)
| _ =>
if ( isReducible info.name) || isGlobalInstance ( getEnv) info.name then
return true
@@ -429,31 +472,29 @@ private def whnfMatcher (e : Expr) : MetaM Expr := do
whnf e
def reduceMatcher? (e : Expr) : MetaM ReduceMatcherResult := do
match e.getAppFn with
| Expr.const declName declLevels =>
let some info getMatcherInfo? declName
| return ReduceMatcherResult.notMatcher
let args := e.getAppArgs
let prefixSz := info.numParams + 1 + info.numDiscrs
if args.size < prefixSz + info.numAlts then
return ReduceMatcherResult.partialApp
else
let constInfo getConstInfo declName
let f instantiateValueLevelParams constInfo declLevels
let auxApp := mkAppN f args[0:prefixSz]
let auxAppType inferType auxApp
forallBoundedTelescope auxAppType info.numAlts fun hs _ => do
let auxApp whnfMatcher (mkAppN auxApp hs)
let auxAppFn := auxApp.getAppFn
let mut i := prefixSz
for h in hs do
if auxAppFn == h then
let result := mkAppN args[i]! auxApp.getAppArgs
let result := mkAppN result args[prefixSz + info.numAlts:args.size]
return ReduceMatcherResult.reduced result.headBeta
i := i + 1
return ReduceMatcherResult.stuck auxApp
| _ => pure ReduceMatcherResult.notMatcher
let .const declName declLevels := e.getAppFn
| return .notMatcher
let some info getMatcherInfo? declName
| return .notMatcher
let args := e.getAppArgs
let prefixSz := info.numParams + 1 + info.numDiscrs
if args.size < prefixSz + info.numAlts then
return ReduceMatcherResult.partialApp
let constInfo getConstInfo declName
let f instantiateValueLevelParams constInfo declLevels
let auxApp := mkAppN f args[0:prefixSz]
let auxAppType inferType auxApp
forallBoundedTelescope auxAppType info.numAlts fun hs _ => do
let auxApp whnfMatcher (mkAppN auxApp hs)
let auxAppFn := auxApp.getAppFn
let mut i := prefixSz
for h in hs do
if auxAppFn == h then
let result := mkAppN args[i]! auxApp.getAppArgs
let result := mkAppN result args[prefixSz + info.numAlts:args.size]
return ReduceMatcherResult.reduced result.headBeta
i := i + 1
return ReduceMatcherResult.stuck auxApp
private def projectCore? (e : Expr) (i : Nat) : MetaM (Option Expr) := do
let e := e.toCtorIfLit
@@ -471,8 +512,8 @@ def project? (e : Expr) (i : Nat) : MetaM (Option Expr) := do
/-- Reduce kernel projection `Expr.proj ..` expression. -/
def reduceProj? (e : Expr) : MetaM (Option Expr) := do
match e with
| Expr.proj _ i c => project? c i
| _ => return none
| .proj _ i c => project? c i
| _ => return none
/--
Auxiliary method for reducing terms of the form `?m t_1 ... t_n` where `?m` is delayed assigned.
@@ -509,51 +550,47 @@ then delta reduction is used to reduce `s` (i.e., `whnf` is used), otherwise `wh
If `simpleReduceOnly`, then `iota` and projection reduction are not performed.
Note that the value of `deltaAtProj` is irrelevant if `simpleReduceOnly = true`.
-/
partial def whnfCore (e : Expr) (deltaAtProj : Bool := true) (simpleReduceOnly := false) : MetaM Expr :=
partial def whnfCore (e : Expr) (config : WhnfCoreConfig := {}): MetaM Expr :=
go e
where
go (e : Expr) : MetaM Expr :=
whnfEasyCases e fun e => do
whnfEasyCases e (config := config) fun e => do
trace[Meta.whnf] e
match e with
| Expr.const .. => pure e
| Expr.letE _ _ v b _ => go <| b.instantiate1 v
| Expr.app f .. =>
| .const .. => pure e
| .letE _ _ v b _ => if config.zeta then go <| b.instantiate1 v else return e
| .app f .. =>
let f := f.getAppFn
let f' go f
if f'.isLambda then
if config.beta && f'.isLambda then
let revArgs := e.getAppRevArgs
go <| f'.betaRev revArgs
else if let some eNew whnfDelayedAssigned? f' e then
go eNew
else
let e := if f == f' then e else e.updateFn f'
if simpleReduceOnly then
return e
else
match ( reduceMatcher? e) with
| ReduceMatcherResult.reduced eNew => go eNew
| ReduceMatcherResult.partialApp => pure e
| ReduceMatcherResult.stuck _ => pure e
| ReduceMatcherResult.notMatcher =>
matchConstAux f' (fun _ => return e) fun cinfo lvls =>
match cinfo with
| ConstantInfo.recInfo rec => reduceRec rec lvls e.getAppArgs (fun _ => return e) go
| ConstantInfo.quotInfo rec => reduceQuotRec rec lvls e.getAppArgs (fun _ => return e) go
| c@(ConstantInfo.defnInfo _) => do
if ( isAuxDef c.name) then
deltaBetaDefinition c lvls e.getAppRevArgs (fun _ => return e) go
else
return e
| _ => return e
| Expr.proj _ i c =>
if simpleReduceOnly then
return e
else
let c if deltaAtProj then whnf c else whnfCore c
match ( projectCore? c i) with
| some e => go e
| none => return e
unless config.iota do return e
match ( reduceMatcher? e) with
| .reduced eNew => go eNew
| .partialApp => pure e
| .stuck _ => pure e
| .notMatcher =>
matchConstAux f' (fun _ => return e) fun cinfo lvls =>
match cinfo with
| .recInfo rec => reduceRec rec lvls e.getAppArgs (fun _ => return e) go
| .quotInfo rec => reduceQuotRec rec lvls e.getAppArgs (fun _ => return e) go
| c@(.defnInfo _) => do
if ( isAuxDef c.name) then
deltaBetaDefinition c lvls e.getAppRevArgs (fun _ => return e) go
else
return e
| _ => return e
| .proj _ i c =>
if config.proj == .no then return e
let c if config.proj == .yesWithDelta then whnf c else go c
match ( projectCore? c i) with
| some e => go e
| none => return e
| _ => unreachable!
/--
@@ -591,11 +628,11 @@ partial def smartUnfoldingReduce? (e : Expr) : MetaM (Option Expr) :=
where
go (e : Expr) : OptionT MetaM Expr := do
match e with
| Expr.letE n t v b _ => withLetDecl n t ( go v) fun x => do mkLetFVars #[x] ( go (b.instantiate1 x))
| Expr.lam .. => lambdaTelescope e fun xs b => do mkLambdaFVars xs ( go b)
| Expr.app f a .. => return mkApp ( go f) ( go a)
| Expr.proj _ _ s => return e.updateProj! ( go s)
| Expr.mdata _ b =>
| .letE n t v b _ => withLetDecl n t ( go v) fun x => do mkLetFVars #[x] ( go (b.instantiate1 x))
| .lam .. => lambdaTelescope e fun xs b => do mkLambdaFVars xs ( go b)
| .app f a .. => return mkApp ( go f) ( go a)
| .proj _ _ s => return e.updateProj! ( go s)
| .mdata _ b =>
if let some m := smartUnfoldingMatch? e then
goMatch m
else
@@ -625,7 +662,7 @@ mutual
-/
partial def unfoldProjInst? (e : Expr) : MetaM (Option Expr) := do
match e.getAppFn with
| Expr.const declName .. =>
| .const declName .. =>
match ( getProjectionFnInfo? declName) with
| some { fromClass := true, .. } =>
match ( withDefault <| unfoldDefinition? e) with
@@ -651,7 +688,7 @@ mutual
/-- Unfold definition using "smart unfolding" if possible. -/
partial def unfoldDefinition? (e : Expr) : MetaM (Option Expr) :=
match e with
| Expr.app f _ =>
| .app f _ =>
matchConstAux f.getAppFn (fun _ => unfoldProjInstWhenIntances? e) fun fInfo fLvls => do
if fInfo.levelParams.length != fLvls.length then
return none
@@ -663,7 +700,7 @@ mutual
return none
if smartUnfolding.get ( getOptions) then
match (( getEnv).find? (mkSmartUnfoldingNameFor fInfo.name)) with
| some fAuxInfo@(ConstantInfo.defnInfo _) =>
| some fAuxInfo@(.defnInfo _) =>
-- We use `preserveMData := true` to make sure the smart unfolding annotation are not erased in an over-application.
deltaBetaDefinition fAuxInfo fLvls e.getAppRevArgs (preserveMData := true) (fun _ => pure none) fun e₁ => do
let some r smartUnfoldingReduce? e₁ | return none
@@ -719,7 +756,7 @@ mutual
unfoldDefault ()
else
unfoldDefault ()
| Expr.const declName lvls => do
| .const declName lvls => do
if smartUnfolding.get ( getOptions) && ( getEnv).contains (mkSmartUnfoldingNameFor declName) then
return none
else
@@ -757,12 +794,12 @@ def reduceRecMatcher? (e : Expr) : MetaM (Option Expr) := do
if !e.isApp then
return none
else match ( reduceMatcher? e) with
| ReduceMatcherResult.reduced e => return e
| .reduced e => return e
| _ => matchConstAux e.getAppFn (fun _ => pure none) fun cinfo lvls => do
match cinfo with
| ConstantInfo.recInfo «rec» => reduceRec «rec» lvls e.getAppArgs (fun _ => pure none) (fun e => pure (some e))
| ConstantInfo.quotInfo «rec» => reduceQuotRec «rec» lvls e.getAppArgs (fun _ => pure none) (fun e => pure (some e))
| c@(ConstantInfo.defnInfo _) =>
| .recInfo «rec» => reduceRec «rec» lvls e.getAppArgs (fun _ => pure none) (fun e => pure (some e))
| .quotInfo «rec» => reduceQuotRec «rec» lvls e.getAppArgs (fun _ => pure none) (fun e => pure (some e))
| c@(.defnInfo _) =>
if ( isAuxDef c.name) then
deltaBetaDefinition c lvls e.getAppRevArgs (fun _ => pure none) (fun e => pure (some e))
else
@@ -812,12 +849,12 @@ def reduceNat? (e : Expr) : MetaM (Option Expr) :=
if e.hasFVar || e.hasMVar then
return none
else match e with
| Expr.app (Expr.const fn _) a =>
| .app (.const fn _) a =>
if fn == ``Nat.succ then
reduceUnaryNatOp Nat.succ a
else
return none
| Expr.app (Expr.app (Expr.const fn _) a1) a2 =>
| .app (.app (.const fn _) a1) a2 =>
if fn == ``Nat.add then reduceBinNatOp Nat.add a1 a2
else if fn == ``Nat.sub then reduceBinNatOp Nat.sub a1 a2
else if fn == ``Nat.mul then reduceBinNatOp Nat.mul a1 a2
@@ -839,25 +876,25 @@ def reduceNat? (e : Expr) : MetaM (Option Expr) :=
return false
else
match ( getConfig).transparency with
| TransparencyMode.default => return true
| TransparencyMode.all => return true
| _ => return false
| .default => return true
| .all => return true
| _ => return false
@[inline] private def cached? (useCache : Bool) (e : Expr) : MetaM (Option Expr) := do
if useCache then
match ( getConfig).transparency with
| TransparencyMode.default => return ( get).cache.whnfDefault.find? e
| TransparencyMode.all => return ( get).cache.whnfAll.find? e
| _ => unreachable!
| .default => return ( get).cache.whnfDefault.find? e
| .all => return ( get).cache.whnfAll.find? e
| _ => unreachable!
else
return none
private def cache (useCache : Bool) (e r : Expr) : MetaM Expr := do
if useCache then
match ( getConfig).transparency with
| TransparencyMode.default => modify fun s => { s with cache.whnfDefault := s.cache.whnfDefault.insert e r }
| TransparencyMode.all => modify fun s => { s with cache.whnfAll := s.cache.whnfAll.insert e r }
| _ => unreachable!
| .default => modify fun s => { s with cache.whnfDefault := s.cache.whnfDefault.insert e r }
| .all => modify fun s => { s with cache.whnfAll := s.cache.whnfAll.insert e r }
| _ => unreachable!
return r
@[export lean_whnf]
@@ -884,7 +921,7 @@ def reduceProjOf? (e : Expr) (p : Name → Bool) : MetaM (Option Expr) := do
if !e.isApp then
pure none
else match e.getAppFn with
| Expr.const name .. => do
| .const name .. => do
let env getEnv
match env.getProjectionStructureName? name with
| some structName =>

15
tests/lean/run/2669.lean Normal file
View File

@@ -0,0 +1,15 @@
def f : Nat Nat := fun x => x - x
@[simp] theorem f_zero (n : Nat) : f n = 0 :=
Nat.sub_self n
example (n : Nat) : False := by
let g := f n
have : g + n = n := by
fail_if_success simp (config := { zeta := false }) [Nat.zero_add] -- Should not succeed
simp
sorry
example (h : a = b) : (fun x => a + x) 0 = b := by
fail_if_success simp (config := { beta := false })
simp [*]

View File

@@ -0,0 +1,12 @@
example (a : Nat) : let n := 0; n + a = a := by
intro n
fail_if_success simp (config := { zeta := false })
simp (config := { zeta := false }) [n]
example (a b : Nat) (h : a = b) : let n := 0; n + a = b := by
intro n
fail_if_success simp (config := { zeta := false })
trace_state
simp (config := { zeta := false }) [n]
trace_state
simp [h]

View File

@@ -233,7 +233,7 @@ do print "----- tst14 -----";
print stateM;
let monad mkMonad stateM;
let globalInsts getGlobalInstancesIndex;
let insts globalInsts.getUnify monad;
let insts globalInsts.getUnify monad {};
print (insts.map (·.val));
pure ()

View File

@@ -32,27 +32,27 @@ def succ := mkConst `Nat.succ
def add := mkAppN (mkConst `Add.add [levelZero]) #[nat, mkConst `Nat.add]
def tst1 : MetaM Unit :=
do let d : DiscrTree Nat true := {};
do let d : DiscrTree Nat := {};
let mvar mkFreshExprMVar nat;
let d d.insert (mkAppN add #[mvar, mkNatLit 10]) 1;
let d d.insert (mkAppN add #[mkNatLit 0, mkNatLit 10]) 2;
let d d.insert (mkAppN (mkConst `Nat.add) #[mkNatLit 0, mkNatLit 20]) 3;
let d d.insert (mkAppN add #[mvar, mkNatLit 20]) 4;
let d d.insert mvar 5;
let d d.insert (mkAppN add #[mvar, mkNatLit 10]) 1 {};
let d d.insert (mkAppN add #[mkNatLit 0, mkNatLit 10]) 2 {};
let d d.insert (mkAppN (mkConst `Nat.add) #[mkNatLit 0, mkNatLit 20]) 3 {};
let d d.insert (mkAppN add #[mvar, mkNatLit 20]) 4 {};
let d d.insert mvar 5 {};
print (format d);
let vs d.getMatch (mkAppN add #[mkNatLit 1, mkNatLit 10]);
let vs d.getMatch (mkAppN add #[mkNatLit 1, mkNatLit 10]) {};
print (format vs);
let t := mkAppN add #[mvar, mvar];
print t;
let vs d.getMatch t;
let vs d.getMatch t {};
print (format vs);
let vs d.getUnify t;
let vs d.getUnify t {};
print (format vs);
let vs d.getUnify mvar;
let vs d.getUnify mvar {};
print (format vs);
let vs d.getUnify $ mkAppN add #[mkNatLit 0, mvar];
let vs d.getUnify (mkAppN add #[mkNatLit 0, mvar]) {};
print (format vs);
let vs d.getUnify $ mkAppN add #[mvar, mkNatLit 20];
let vs d.getUnify (mkAppN add #[mvar, mkNatLit 20]) {};
print (format vs);
pure ()

View File

@@ -44,7 +44,7 @@ def tst2 : MetaM Unit := do
| some (_, lhs, _) =>
trace[Meta.debug] "lhs: {lhs}"
let s Meta.getSimpTheorems
let m s.post.getMatch lhs
let m s.post.getMatch lhs {}
trace[Meta.debug] "result: {m}"
assert! m.any fun s => s.origin == .decl `ex2