Compare commits

...

3 Commits

Author SHA1 Message Date
Joe Hendrix
9af3bc52a9 backtrack from attempted refactor 2024-03-22 18:23:43 -04:00
Joe Hendrix
cb0ff5ee8a Apply suggestions from code review
Co-authored-by: Scott Morrison <scott.morrison@gmail.com>
2024-03-22 18:23:43 -04:00
Joe Hendrix
94ca9e4921 feat: upstream rw? tactic 2024-03-22 18:23:42 -04:00
9 changed files with 737 additions and 71 deletions

View File

@@ -1318,6 +1318,22 @@ used when closing the goal.
-/
syntax (name := apply?) "apply?" (" using " (colGt term),+)? : tactic
/--
Syntax for excluding some names, e.g. `[-my_lemma, -my_theorem]`.
-/
syntax rewrites_forbidden := " [" (("-" ident),*,?) "]"
/--
`rw?` tries to find a lemma which can rewrite the goal.
`rw?` should not be left in proofs; it is a search tool, like `apply?`.
Suggestions are printed as `rw [h]` or `rw [← h]`.
You can use `rw? [-my_lemma, -my_theorem]` to prevent `rw?` using the named lemmas.
-/
syntax (name := rewrites?) "rw?" (ppSpace location)? (rewrites_forbidden)? : tactic
/--
`show_term tac` runs `tac`, then prints the generated term in the form
"exact X Y Z" or "refine X ?_ Z" if there are remaining subgoals.

View File

@@ -39,3 +39,4 @@ import Lean.Elab.Tactic.SolveByElim
import Lean.Elab.Tactic.LibrarySearch
import Lean.Elab.Tactic.ShowTerm
import Lean.Elab.Tactic.Rfl
import Lean.Elab.Tactic.Rewrites

View File

@@ -0,0 +1,69 @@
/-
Copyright (c) 2023 Scott Morrison. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Scott Morrison
-/
prelude
import Lean.Elab.Tactic.Location
import Lean.Meta.Tactic.Replace
import Lean.Meta.Tactic.Rewrites
/-!
# The `rewrites` tactic.
`rw?` tries to find a lemma which can rewrite the goal.
`rw?` should not be left in proofs; it is a search tool, like `apply?`.
Suggestions are printed as `rw [h]` or `rw [← h]`.
-/
namespace Lean.Elab.Rewrites
open Lean Meta Rewrites
open Lean.Parser.Tactic
open Lean Elab Tactic
@[builtin_tactic Lean.Parser.Tactic.rewrites?]
def evalExact : Tactic := fun stx => do
let `(tactic| rw?%$tk $[$loc]? $[[ $[-$forbidden],* ]]?) := stx
| throwUnsupportedSyntax
let moduleRef createModuleTreeRef
let forbidden : NameSet :=
((forbidden.getD #[]).map Syntax.getId).foldl (init := ) fun s n => s.insert n
reportOutOfHeartbeats `findRewrites tk
let goal getMainGoal
withLocation (expandOptLocation (Lean.mkOptionalNode loc))
fun f => do
let some a f.findDecl? | return
if a.isImplementationDetail then return
let target instantiateMVars ( f.getType)
let hyps localHypotheses (except := [f])
let results findRewrites hyps moduleRef goal target (stopAtRfl := false) forbidden
reportOutOfHeartbeats `rewrites tk
if results.isEmpty then
throwError "Could not find any lemmas which can rewrite the hypothesis {← f.getUserName}"
for r in results do withMCtx r.mctx do
Tactic.TryThis.addRewriteSuggestion tk [(r.expr, r.symm)]
r.result.eNew (loc? := .some (.fvar f)) (origSpan? := getRef)
if let some r := results[0]? then
setMCtx r.mctx
let replaceResult goal.replaceLocalDecl f r.result.eNew r.result.eqProof
replaceMainGoal (replaceResult.mvarId :: r.result.mvarIds)
do
let target instantiateMVars ( goal.getType)
let hyps localHypotheses
let results findRewrites hyps moduleRef goal target (stopAtRfl := true) forbidden
reportOutOfHeartbeats `rewrites tk
if results.isEmpty then
throwError "Could not find any lemmas which can rewrite the goal"
results.forM (·.addSuggestion tk)
if let some r := results[0]? then
setMCtx r.mctx
replaceMainGoal
(( goal.replaceTargetEq r.result.eNew r.result.eqProof) :: r.result.mvarIds)
evalTactic ( `(tactic| try rfl))
(fun _ => throwError "Failed to find a rewrite for some location")
end Lean.Elab.Rewrites

View File

@@ -1881,6 +1881,22 @@ def letFunAppArgs? (e : Expr) : Option (Array Expr × Name × Expr × Expr × Ex
| .lam n _ b _ => some (rest, n, t, v, b)
| _ => some (rest, .anonymous, t, v, .app f (.bvar 0))
/-- Maps `f` on each immediate child of the given expression. -/
@[specialize]
def traverseChildren [Applicative M] (f : Expr M Expr) : Expr M Expr
| e@(forallE _ d b _) => pure e.updateForallE! <*> f d <*> f b
| e@(lam _ d b _) => pure e.updateLambdaE! <*> f d <*> f b
| e@(mdata _ b) => e.updateMData! <$> f b
| e@(letE _ t v b _) => pure e.updateLet! <*> f t <*> f v <*> f b
| e@(app l r) => pure e.updateApp! <*> f l <*> f r
| e@(proj _ _ b) => e.updateProj! <$> f b
| e => pure e
/-- `e.foldlM f a` folds the monadic function `f` over the subterms of the expression `e`,
with initial value `a`. -/
def foldlM {α : Type} {m} [Monad m] (f : α Expr m α) (init : α) (e : Expr) : m α :=
Prod.snd <$> StateT.run (e.traverseChildren (fun e' => fun a => Prod.mk e' <$> f a e')) init
end Expr
/--

View File

@@ -393,26 +393,37 @@ Get the root key and rest of terms of an expression using the specified config.
private def rootKey (cfg: WhnfCoreConfig) (e : Expr) : MetaM (Key × Array Expr) :=
pushArgs true (Array.mkEmpty initCapacity) e cfg
private partial def mkPathAux (root : Bool) (todo : Array Expr) (keys : Array Key)
(config : WhnfCoreConfig) : MetaM (Array Key) := do
private partial def buildPath (op : Bool Array Expr Expr MetaM (Key × Array Expr)) (root : Bool) (todo : Array Expr) (keys : Array Key) : MetaM (Array Key) := do
if todo.isEmpty then
return keys
else
let e := todo.back
let todo := todo.pop
let (k, todo) pushArgs root todo e config
mkPathAux false todo (keys.push k) config
let (k, todo) op root todo e
buildPath op false todo (keys.push k)
/--
Create a path from an expression.
Create a key path from an expression using the function used for patterns.
This differs from Lean.Meta.DiscrTree.mkPath in that the expression
This differs from Lean.Meta.DiscrTree.mkPath and targetPath in that the expression
should uses free variables rather than meta-variables for holes.
-/
private def mkPath (e : Expr) (config : WhnfCoreConfig) : MetaM (Array Key) := do
def patternPath (e : Expr) (config : WhnfCoreConfig) : MetaM (Array Key) := do
let todo : Array Expr := .mkEmpty initCapacity
let keys : Array Key := .mkEmpty initCapacity
mkPathAux (root := true) (todo.push e) keys config
let op root todo e := pushArgs root todo e config
buildPath op (root := true) (todo.push e) (.mkEmpty initCapacity)
/--
Create a key path from an expression we are matching against.
This should have mvars instantiated where feasible.
-/
def targetPath (e : Expr) (config : WhnfCoreConfig) : MetaM (Array Key) := do
let todo : Array Expr := .mkEmpty initCapacity
let op root todo e := do
let (k, args) MatchClone.getMatchKeyArgs e root config
pure (k, todo ++ args)
buildPath op (root := true) (todo.push e) (.mkEmpty initCapacity)
/- Monad for finding matches while resolving deferred patterns. -/
@[reducible]
@@ -512,7 +523,7 @@ A match result contains the terms formed from matching a term against
patterns in the discrimination tree.
-/
private structure MatchResult (α : Type) where
structure MatchResult (α : Type) where
/--
The elements in the match result.
@@ -525,7 +536,9 @@ private structure MatchResult (α : Type) where
-/
elts : Array (Array (Array α)) := #[]
private def MatchResult.push (r : MatchResult α) (score : Nat) (e : Array α) : MatchResult α :=
namespace MatchResult
private def push (r : MatchResult α) (score : Nat) (e : Array α) : MatchResult α :=
if e.isEmpty then
r
else if score < r.elts.size then
@@ -539,14 +552,28 @@ private def MatchResult.push (r : MatchResult α) (score : Nat) (e : Array α) :
termination_by score - a.size
loop r.elts
private partial def MatchResult.toArray (mr : MatchResult α) : Array α :=
loop (Array.mkEmpty n) mr.elts
where n := mr.elts.foldl (fun i a => a.foldl (fun n a => n + a.size) i) 0
loop (r : Array α) (a : Array (Array (Array α))) :=
if a.isEmpty then
r
else
loop (a.back.foldl (init := r) (fun r a => r ++ a)) a.pop
/--
Number of elements in result
-/
partial def size (mr : MatchResult α) : Nat :=
mr.elts.foldl (fun i a => a.foldl (fun n a => n + a.size) i) 0
/--
Append results to array
-/
@[specialize]
partial def appendResultsAux (mr : MatchResult α) (a : Array β) (f : Nat α β) : Array β :=
let aa := mr.elts
let n := aa.size
Nat.fold (n := n) (init := a) fun i r =>
let j := n-1-i
let b := aa[j]!
b.foldl (init := r) (· ++ ·.map (f j))
partial def appendResults (mr : MatchResult α) (a : Array α) : Array α :=
mr.appendResultsAux a (fun _ a => a)
end MatchResult
private partial def getMatchLoop (todo : Array Expr) (score : Nat) (c : TrieIndex)
(result : MatchResult α) : MatchM α (MatchResult α) := do
@@ -619,8 +646,8 @@ private def getMatchCore (root : Lean.HashMap Key TrieIndex) (e : Expr) :
The results are ordered so that the longest matches in terms of number of
non-star keys are first with ties going to earlier operators first.
-/
def getMatch (d : LazyDiscrTree α) (e : Expr) : MetaM (Array α × LazyDiscrTree α) :=
withReducible <| runMatch d <| (·.toArray) <$> getMatchCore d.roots e
def getMatch (d : LazyDiscrTree α) (e : Expr) : MetaM (MatchResult α × LazyDiscrTree α) :=
withReducible <| runMatch d <| getMatchCore d.roots e
/--
Structure for quickly initializing a lazy discrimination tree with a large number
@@ -845,21 +872,11 @@ def createLocalPreDiscrTree
let r (env.constants.map₂.foldlM (init := {}) act : BaseIO (PreDiscrTree α))
pure r
/-- Create an imported environment for tree. -/
def createLocalEnvironment
(act : Name ConstantInfo MetaM (Array (InitEntry α))) :
CoreM (LazyDiscrTree α) := do
let env getEnv
let ngen getChildNgen
let d ImportData.new
let t createLocalPreDiscrTree ngen env d act
let errors d.errors.get
if p : errors.size > 0 then
throw errors[0].exception
pure <| t.toLazy
def dropKeys (t : LazyDiscrTree α) (keys : List (List LazyDiscrTree.Key)) : MetaM (LazyDiscrTree α) := do
keys.foldlM (init := t) (·.dropKey ·)
/-- Create an imported environment for tree. -/
def createImportedEnvironment (ngen : NameGenerator) (env : Environment)
/-- Create a discriminator tree for imported environment. -/
def createImportedDiscrTree (ngen : NameGenerator) (env : Environment)
(act : Name ConstantInfo MetaM (Array (InitEntry α)))
(constantsPerTask : Nat := 1000) :
EIO Exception (LazyDiscrTree α) := do
@@ -889,23 +906,12 @@ def createImportedEnvironment (ngen : NameGenerator) (env : Environment)
throw r.errors[0].exception
pure <| r.tree.toLazy
def dropKeys (t : LazyDiscrTree α) (keys : List (List LazyDiscrTree.Key)) : MetaM (LazyDiscrTree α) := do
keys.foldlM (init := t) (·.dropKey ·)
/--
`findCandidates` searches for entries in a lazily initialized discriminator tree.
* `ext` should be an environment extension with an IO.Ref for caching the import lazy
discriminator tree.
* `addEntry` is the function for creating discriminator tree entries from constants.
* `droppedKeys` contains keys we do not want to consider when searching for matches.
It is used for dropping very general keys.
-/
def findCandidates (ext : EnvExtension (IO.Ref (Option (LazyDiscrTree α))))
(addEntry : Name ConstantInfo MetaM (Array (InitEntry α)))
(droppedKeys : List (List LazyDiscrTree.Key) := [])
(constantsPerTask : Nat := 1000)
(ty : Expr) : MetaM (Array α) := do
def findImportMatches
(ext : EnvExtension (IO.Ref (Option (LazyDiscrTree α))))
(addEntry : Name ConstantInfo MetaM (Array (InitEntry α)))
(droppedKeys : List (List LazyDiscrTree.Key) := [])
(constantsPerTask : Nat := 1000)
(ty : Expr) : MetaM (MatchResult α) := do
let ngen getNGen
let (cNGen, ngen) := ngen.mkChild
setNGen ngen
@@ -913,14 +919,106 @@ def findCandidates (ext : EnvExtension (IO.Ref (Option (LazyDiscrTree α))))
let ref := @EnvExtension.getState _ dummy ext (getEnv)
let importTree (ref.get).getDM $ do
profileitM Exception "lazy discriminator import initialization" (getOptions) $ do
let t createImportedEnvironment cNGen (getEnv) addEntry
let t createImportedDiscrTree cNGen (getEnv) addEntry
(constantsPerTask := constantsPerTask)
dropKeys t droppedKeys
let (localCandidates, _)
profileitM Exception "lazy discriminator local search" (getOptions) $ do
let t createLocalEnvironment addEntry
let t dropKeys t droppedKeys
t.getMatch ty
let (importCandidates, importTree) importTree.getMatch ty
ref.set importTree
pure (localCandidates ++ importCandidates)
ref.set (some importTree)
pure importCandidates
/--
A discriminator tree for the current module's declarations only.
Note. We use different discriminator trees for imported and current module
declarations since imported declarations are typically much more numerous but
not changed after the environment is created.
-/
structure ModuleDiscrTreeRef (α : Type _) where
ref : IO.Ref (LazyDiscrTree α)
/-- Create a discriminator tree for current module declarations. -/
def createModuleDiscrTree
(entriesForConst : Name ConstantInfo MetaM (Array (InitEntry α))) :
CoreM (LazyDiscrTree α) := do
let env getEnv
let ngen getChildNgen
let d ImportData.new
let t createLocalPreDiscrTree ngen env d entriesForConst
let errors d.errors.get
if p : errors.size > 0 then
throw errors[0].exception
pure <| t.toLazy
/--
Creates reference for lazy discriminator tree that only contains this module's definitions.
-/
def createModuleTreeRef (entriesForConst : Name ConstantInfo MetaM (Array (InitEntry α)))
(droppedKeys : List (List LazyDiscrTree.Key)) : MetaM (ModuleDiscrTreeRef α) := do
profileitM Exception "build module discriminator tree" (getOptions) $ do
let t createModuleDiscrTree entriesForConst
let t dropKeys t droppedKeys
pure { ref := IO.mkRef t }
/--
Returns candidates from this module in this module that match the expression.
* `moduleRef` is a references to a lazy discriminator tree only containing
this module's definitions.
-/
def findModuleMatches (moduleRef : ModuleDiscrTreeRef α) (ty : Expr) : MetaM (MatchResult α) := do
profileitM Exception "lazy discriminator local search" (getOptions) $ do
let discrTree moduleRef.ref.get
let (localCandidates, localTree) discrTree.getMatch ty
moduleRef.ref.set localTree
pure localCandidates
/--
`findMatchesExt` searches for entries in a lazily initialized discriminator tree.
It provides some additional capabilities beyond `findMatches` to adjust results
based on priority and cache module declarations
* `modulesTreeRef` points to the discriminator tree for local environment.
Used for caching and created by `createLocalTree`.
* `ext` should be an environment extension with an IO.Ref for caching the import lazy
discriminator tree.
* `addEntry` is the function for creating discriminator tree entries from constants.
* `droppedKeys` contains keys we do not want to consider when searching for matches.
It is used for dropping very general keys.
* `constantsPerTask` stores number of constants in imported modules used to
decide when to create new task.
* `adjustResult` takes the priority and value to produce a final result.
* `ty` is the expression type.
-/
def findMatchesExt
(moduleTreeRef : ModuleDiscrTreeRef α)
(ext : EnvExtension (IO.Ref (Option (LazyDiscrTree α))))
(addEntry : Name ConstantInfo MetaM (Array (InitEntry α)))
(droppedKeys : List (List LazyDiscrTree.Key) := [])
(constantsPerTask : Nat := 1000)
(adjustResult : Nat α β)
(ty : Expr) : MetaM (Array β) := do
let moduleMatches findModuleMatches moduleTreeRef ty
let importMatches findImportMatches ext addEntry droppedKeys constantsPerTask ty
return Array.mkEmpty (moduleMatches.size + importMatches.size)
|> moduleMatches.appendResultsAux (f := adjustResult)
|> importMatches.appendResultsAux (f := adjustResult)
/--
`findMatches` searches for entries in a lazily initialized discriminator tree.
* `ext` should be an environment extension with an IO.Ref for caching the import lazy
discriminator tree.
* `addEntry` is the function for creating discriminator tree entries from constants.
* `droppedKeys` contains keys we do not want to consider when searching for matches.
It is used for dropping very general keys.
-/
def findMatches (ext : EnvExtension (IO.Ref (Option (LazyDiscrTree α))))
(addEntry : Name ConstantInfo MetaM (Array (InitEntry α)))
(droppedKeys : List (List LazyDiscrTree.Key) := [])
(constantsPerTask : Nat := 1000)
(ty : Expr) : MetaM (Array α) := do
let moduleTreeRef createModuleTreeRef addEntry droppedKeys
let incPrio _ v := v
findMatchesExt moduleTreeRef ext addEntry droppedKeys constantsPerTask incPrio ty

View File

@@ -39,3 +39,4 @@ import Lean.Meta.Tactic.Backtrack
import Lean.Meta.Tactic.SolveByElim
import Lean.Meta.Tactic.FunInd
import Lean.Meta.Tactic.Rfl
import Lean.Meta.Tactic.Rewrites

View File

@@ -67,7 +67,7 @@ to find candidate lemmas.
@[reducible]
def CandidateFinder := Expr MetaM (Array (Name × DeclMod))
open LazyDiscrTree (InitEntry findCandidates)
open LazyDiscrTree (InitEntry findMatches)
private def addImport (name : Name) (constInfo : ConstantInfo) :
MetaM (Array (InitEntry (Name × DeclMod))) :=
@@ -111,7 +111,7 @@ private def constantsPerImportTask : Nat := 6500
/-- Create function for finding relevant declarations. -/
def libSearchFindDecls : Expr MetaM (Array (Name × DeclMod)) :=
findCandidates ext addImport
findMatches ext addImport
(droppedKeys := droppedKeys)
(constantsPerTask := constantsPerImportTask)
@@ -278,15 +278,15 @@ private def librarySearch' (goal : MVarId)
MetaM (Option (Array (List MVarId × MetavarContext))) := do
withTraceNode `Tactic.librarySearch (return m!"{librarySearchEmoji ·} {← goal.getType}") do
profileitM Exception "librarySearch" ( getOptions) do
-- Create predicate that returns true when running low on heartbeats.
let candidates librarySearchSymm libSearchFindDecls goal
let cfg : ApplyConfig := { allowSynthFailures := true }
let shouldAbort mkHeartbeatCheck leavePercentHeartbeats
let act := fun cand => do
if shouldAbort then
abortSpeculation
librarySearchLemma cfg tactic allowFailure cand
tryOnEach act candidates
-- Create predicate that returns true when running low on heartbeats.
let candidates librarySearchSymm libSearchFindDecls goal
let cfg : ApplyConfig := { allowSynthFailures := true }
let shouldAbort mkHeartbeatCheck leavePercentHeartbeats
let act := fun cand => do
if shouldAbort then
abortSpeculation
librarySearchLemma cfg tactic allowFailure cand
tryOnEach act candidates
/--
Tries to solve the goal either by:

View File

@@ -0,0 +1,339 @@
/-
Copyright (c) 2023 Scott Morrison. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Scott Morrison
-/
prelude
import Lean.Meta.LazyDiscrTree
import Lean.Meta.Tactic.Assumption
import Lean.Meta.Tactic.Rewrite
import Lean.Meta.Tactic.Rfl
import Lean.Meta.Tactic.SolveByElim
import Lean.Meta.Tactic.TryThis
import Lean.Util.Heartbeats
namespace Lean.Meta.Rewrites
open Lean.Meta.LazyDiscrTree (InitEntry MatchResult)
open Lean.Meta.SolveByElim
builtin_initialize registerTraceClass `Tactic.rewrites
builtin_initialize registerTraceClass `Tactic.rewrites.lemmas
/-- Extract the lemma, with arguments, that was used to produce a `RewriteResult`. -/
-- This assumes that `r.eqProof` was constructed as:
-- `mkApp6 (.const ``congrArg _) α eType lhs rhs motive heq`
-- in `Lean.Meta.Tactic.Rewrite` and we want `heq`.
def rewriteResultLemma (r : RewriteResult) : Option Expr :=
if r.eqProof.isAppOfArity ``congrArg 6 then
r.eqProof.getArg! 5
else
none
/-- Weight to multiply the "specificity" of a rewrite lemma by when rewriting forwards. -/
def forwardWeight := 2
/-- Weight to multiply the "specificity" of a rewrite lemma by when rewriting backwards. -/
def backwardWeight := 1
private def addImport (name : Name) (constInfo : ConstantInfo) :
MetaM (Array (InitEntry (Name × Bool × Nat))) := do
if constInfo.isUnsafe then return #[]
if !allowCompletion (getEnv) name then return #[]
-- We now remove some injectivity lemmas which are not useful to rewrite by.
if name matches .str _ "injEq" then return #[]
if name matches .str _ "sizeOf_spec" then return #[]
match name with
| .str _ n => if n.endsWith "_inj" n.endsWith "_inj'" then return #[]
| _ => pure ()
withNewMCtxDepth do withReducible do
forallTelescopeReducing constInfo.type fun _ type => do
match type.getAppFnArgs with
| (``Eq, #[_, lhs, rhs])
| (``Iff, #[lhs, rhs]) => do
let a := Array.mkEmpty 2
let a := a.push ( InitEntry.fromExpr lhs (name, false, forwardWeight))
let a := a.push ( InitEntry.fromExpr rhs (name, true, backwardWeight))
pure a
| _ => return #[]
/-- Configuration for `DiscrTree`. -/
def discrTreeConfig : WhnfCoreConfig := {}
/-- Select `=` and `↔` local hypotheses. -/
def localHypotheses (except : List FVarId := []) : MetaM (Array (Expr × Bool × Nat)) := do
let r getLocalHyps
let mut result := #[]
for h in r do
if except.contains h.fvarId! then continue
let (_, _, type) forallMetaTelescopeReducing ( inferType h)
let type whnfR type
match type.getAppFnArgs with
| (``Eq, #[_, lhs, rhs])
| (``Iff, #[lhs, rhs]) => do
let lhsKey : Array DiscrTree.Key DiscrTree.mkPath lhs discrTreeConfig
let rhsKey : Array DiscrTree.Key DiscrTree.mkPath rhs discrTreeConfig
result := result.push (h, false, forwardWeight * lhsKey.size)
|>.push (h, true, backwardWeight * rhsKey.size)
| _ => pure ()
return result
/--
We drop `.star` and `Eq * * *` from the discriminator trees because
they match too much.
-/
def droppedKeys : List (List LazyDiscrTree.Key) := [[.star], [.const `Eq 3, .star, .star, .star]]
def createModuleTreeRef : MetaM (LazyDiscrTree.ModuleDiscrTreeRef (Name × Bool × Nat)) :=
LazyDiscrTree.createModuleTreeRef addImport droppedKeys
private def ExtState := IO.Ref (Option (LazyDiscrTree (Name × Bool × Nat)))
private builtin_initialize ExtState.default : IO.Ref (Option (LazyDiscrTree (Name × Bool × Nat))) do
IO.mkRef .none
private instance : Inhabited ExtState where
default := ExtState.default
private builtin_initialize ext : EnvExtension ExtState
registerEnvExtension (IO.mkRef .none)
/--
The maximum number of constants an individual task may perform.
The value was picked because it roughly correponded to 50ms of work on the
machine this was developed on. Smaller numbers did not seem to improve
performance when importing Std and larger numbers (<10k) seemed to degrade
initialization performance.
-/
private def constantsPerImportTask : Nat := 6500
def incPrio : Nat Name × Bool × Nat Name × Bool × Nat
| p, (nm, d, prio) => (nm, d, prio * 100 + p)
/-- Create function for finding relevant declarations. -/
def rwFindDecls (moduleRef : LazyDiscrTree.ModuleDiscrTreeRef (Name × Bool × Nat)) : Expr MetaM (Array (Name × Bool × Nat)) :=
LazyDiscrTree.findMatchesExt moduleRef ext addImport
(droppedKeys := droppedKeys)
(constantsPerTask := constantsPerImportTask)
(adjustResult := incPrio)
/-- Data structure recording a potential rewrite to report from the `rw?` tactic. -/
structure RewriteResult where
/-- The lemma we rewrote by.
This is `Expr`, not just a `Name`, as it may be a local hypothesis. -/
expr : Expr
/-- `True` if we rewrote backwards (i.e. with `rw [← h]`). -/
symm : Bool
/-- The "weight" of the rewrite. This is calculated based on how specific the rewrite rule was. -/
weight : Nat
/-- The result from the `rw` tactic. -/
result : Meta.RewriteResult
/-- The metavariable context after the rewrite.
This needs to be stored as part of the result so we can backtrack the state. -/
mctx : MetavarContext
rfl? : Bool
/-- Update a `RewriteResult` by filling in the `rfl?` field if it is currently `none`,
to reflect whether the remaining goal can be closed by `with_reducible rfl`. -/
def computeRfl (mctx : MetavarContext) (res : Meta.RewriteResult) : MetaM Bool := do
try
withoutModifyingState <| withMCtx mctx do
-- We use `withReducible` here to follow the behaviour of `rw`.
withReducible ( mkFreshExprMVar res.eNew).mvarId!.applyRfl
-- We do not need to record the updated `MetavarContext` here.
pure true
catch _e =>
pure false
/--
Pretty print the result of the rewrite.
-/
private def RewriteResult.ppResult (r : RewriteResult) : MetaM String :=
return ( ppExpr r.result.eNew).pretty
/-- Should we try discharging side conditions? If so, using `assumption`, or `solve_by_elim`? -/
inductive SideConditions
| none
| assumption
| solveByElim
/-- Shortcut for calling `solveByElim`. -/
def solveByElim (goals : List MVarId) (depth : Nat := 6) : MetaM PUnit := do
-- There is only a marginal decrease in performance for using the `symm` option for `solveByElim`.
-- (measured via `lake build && time lake env lean test/librarySearch.lean`).
let cfg : SolveByElimConfig := { maxDepth := depth, exfalso := false, symm := true }
let lemmas, ctx mkAssumptionSet false false [] [] #[]
let [] SolveByElim.solveByElim cfg lemmas ctx goals
| failure
def rwLemma (ctx : MetavarContext) (goal : MVarId) (target : Expr) (side : SideConditions := .solveByElim)
(lem : Expr Name) (symm : Bool) (weight : Nat) : MetaM (Option RewriteResult) :=
withMCtx ctx do
let some expr (match lem with
| .inl hyp => pure (some hyp)
| .inr lem => some <$> mkConstWithFreshMVarLevels lem <|> pure none)
| return none
trace[Tactic.rewrites] m!"considering {if symm then " " else ""}{expr}"
let some result some <$> goal.rewrite target expr symm <|> pure none
| return none
if result.mvarIds.isEmpty then
let mctx getMCtx
let rfl? computeRfl mctx result
return some { expr, symm, weight, result, mctx, rfl? }
else
-- There are side conditions, which we try to discharge using local hypotheses.
let discharge
match side with
| .none => pure false
| .assumption => ((fun _ => true) <$> result.mvarIds.mapM fun m => m.assumption) <|> pure false
| .solveByElim => (solveByElim result.mvarIds >>= fun _ => pure true) <|> pure false
match discharge with
| false =>
return none
| true =>
-- If we succeed, we need to reconstruct the expression to report that we rewrote by.
let some expr := rewriteResultLemma result | return none
let expr instantiateMVars expr
let (expr, symm) := if expr.isAppOfArity ``Eq.symm 4 then
(expr.getArg! 3, true)
else
(expr, false)
let mctx getMCtx
let rfl? computeRfl mctx result
return some { expr, symm, weight, result, mctx, rfl? }
/--
Find keys which match the expression, or some subexpression.
Note that repeated subexpressions will be visited each time they appear,
making this operation potentially very expensive.
It would be good to solve this problem!
Implementation: we reverse the results from `getMatch`,
so that we return lemmas matching larger subexpressions first,
and amongst those we return more specific lemmas first.
-/
partial def getSubexpressionMatches (op : Expr MetaM (Array α)) (e : Expr) : MetaM (Array α) := do
match e with
| .bvar _ => return #[]
| .forallE _ _ _ _ =>
forallTelescope e fun args body => do
args.foldlM (fun acc arg => return acc ++ ( getSubexpressionMatches op ( inferType arg)))
( getSubexpressionMatches op body).reverse
| .lam _ _ _ _
| .letE _ _ _ _ _ =>
lambdaLetTelescope e (fun args body => do
args.foldlM (fun acc arg => return acc ++ ( getSubexpressionMatches op ( inferType arg)))
( getSubexpressionMatches op body).reverse)
| _ =>
let init := (( op e).reverse)
e.foldlM (init := init) (fun a f => return a ++ ( getSubexpressionMatches op f))
/--
Find lemmas which can rewrite the goal.
See also `rewrites` for a more convenient interface.
-/
def rewriteCandidates (hyps : Array (Expr × Bool × Nat))
(moduleRef : LazyDiscrTree.ModuleDiscrTreeRef (Name × Bool × Nat))
(target : Expr)
(forbidden : NameSet := ) :
MetaM (Array ((Expr Name) × Bool × Nat)) := do
-- Get all lemmas which could match some subexpression
let candidates getSubexpressionMatches (rwFindDecls moduleRef) target
-- Sort them by our preferring weighting
-- (length of discriminant key, doubled for the forward implication)
let candidates := candidates.insertionSort fun (_, _, rp) (_, _, sp) => rp > sp
-- Now deduplicate. We can't use `Array.deduplicateSorted` as we haven't completely sorted,
-- and in fact want to keep some of the residual ordering from the discrimination tree.
let mut forward : NameSet :=
let mut backward : NameSet :=
let mut deduped := #[]
for (l, s, w) in candidates do
if forbidden.contains l then continue
if s then
if ¬ backward.contains l then
deduped := deduped.push (l, s, w)
backward := backward.insert l
else
if ¬ forward.contains l then
deduped := deduped.push (l, s, w)
forward := forward.insert l
trace[Tactic.rewrites.lemmas] m!"Candidate rewrite lemmas:\n{deduped}"
let hyps := hyps.map fun hyp, symm, weight => (Sum.inl hyp, symm, weight)
let lemmas := deduped.map fun lem, symm, weight => (Sum.inr lem, symm, weight)
pure <| hyps ++ lemmas
def RewriteResult.newGoal (r : RewriteResult) : Option Expr :=
if r.rfl? = true then
some (Expr.lit (.strVal "no goals"))
else
some r.result.eNew
def RewriteResult.addSuggestion (ref : Syntax) (r : RewriteResult) : Elab.TermElabM Unit := do
withMCtx r.mctx do
Tactic.TryThis.addRewriteSuggestion ref [(r.expr, r.symm)] (type? := r.newGoal) (origSpan? := getRef)
structure RewriteResultConfig where
stopAtRfl : Bool
max : Nat
minHeartbeats : Nat
goal : MVarId
target : Expr
side : SideConditions := .solveByElim
mctx : MetavarContext
def takeListAux (cfg : RewriteResultConfig) (seen : HashMap String Unit) (acc : Array RewriteResult)
(xs : List ((Expr Name) × Bool × Nat)) : MetaM (Array RewriteResult) := do
let mut seen := seen
let mut acc := acc
for (lem, symm, weight) in xs do
if ( getRemainingHeartbeats) < cfg.minHeartbeats then
return acc
if acc.size cfg.max then
return acc
let res
withoutModifyingState <| withMCtx cfg.mctx do
rwLemma cfg.mctx cfg.goal cfg.target cfg.side lem symm weight
match res with
| none => continue
| some r =>
let s withoutModifyingState <| withMCtx r.mctx r.ppResult
if seen.contains s then
continue
let rfl? computeRfl r.mctx r.result
if cfg.stopAtRfl then
if rfl? then
return #[r]
else
seen := seen.insert s ()
acc := acc.push r
else
seen := seen.insert s ()
acc := acc.push r
return acc
/-- Find lemmas which can rewrite the goal. -/
def findRewrites (hyps : Array (Expr × Bool × Nat))
(moduleRef : LazyDiscrTree.ModuleDiscrTreeRef (Name × Bool × Nat))
(goal : MVarId) (target : Expr)
(forbidden : NameSet := ) (side : SideConditions := .solveByElim)
(stopAtRfl : Bool) (max : Nat := 20)
(leavePercentHeartbeats : Nat := 10) : MetaM (List RewriteResult) := do
let mctx getMCtx
let candidates rewriteCandidates hyps moduleRef target forbidden
let minHeartbeats : Nat :=
if ( getMaxHeartbeats) = 0 then
0
else
leavePercentHeartbeats * ( getRemainingHeartbeats) / 100
let cfg : RewriteResultConfig :=
{ stopAtRfl, minHeartbeats, max, mctx, goal, target, side }
return ( takeListAux cfg {} (Array.mkEmpty max) candidates.toList).toList
end Lean.Meta.Rewrites

View File

@@ -0,0 +1,126 @@
attribute [refl] Eq.refl
private axiom test_sorry : {α}, α
-- To see the (sorted) list of lemmas that `rw?` will try rewriting by, use:
-- set_option trace.Tactic.rewrites.lemmas true
/--
info: Try this: rw [@List.map_append]
-- "no goals"
-/
#guard_msgs in
example (f : α β) (L M : List α) : (L ++ M).map f = L.map f ++ M.map f := by
rw?
/--
info: Try this: rw [Nat.one_mul]
-- "no goals"
-/
#guard_msgs in
example (h : Nat) : 1 * h = h := by
rw?
#guard_msgs(drop info) in
example (h : Int) (hyp : g * 1 = h) : g = h := by
rw? at hyp
assumption
#guard_msgs(drop info) in
example : (x y : Nat), x y := by
intros x y
rw? -- Used to be an error here https://leanprover.zulipchat.com/#narrow/stream/287929-mathlib4/topic/panic.20and.20error.20with.20rw.3F/near/370495531
exact test_sorry
example : (x y : Nat), x y := by
-- Used to be a panic here https://leanprover.zulipchat.com/#narrow/stream/287929-mathlib4/topic/panic.20and.20error.20with.20rw.3F/near/370495531
fail_if_success rw?
exact test_sorry
axiom K : Type
@[instance] axiom K.hasOne : OfNat K 1
@[instance] axiom K.hasIntCoe : Coe K Int
noncomputable def foo : K K := test_sorry
#guard_msgs(drop info) in
example : foo x = 1 k : Int, x = k := by
rw? -- Used to panic, see https://leanprover.zulipchat.com/#narrow/stream/287929-mathlib4/topic/panic.20and.20error.20with.20rw.3F/near/370598036
exact test_sorry
theorem six_eq_seven : 6 = 7 := test_sorry
-- This test also verifies that we are removing duplicate results;
-- it previously also reported `Nat.cast_ofNat`
#guard_msgs(drop info) in
example : (x : Nat), x 6 := by
rw?
guard_target = (x : Nat), x 7
exact test_sorry
#guard_msgs(drop info) in
example : (x : Nat) (_w : x 6), x 8 := by
rw?
guard_target = (x : Nat) (_w : x 7), x 8
exact test_sorry
-- check we can look inside let expressions
#guard_msgs(drop info) in
example (n : Nat) : let y := 3; n + y = 3 + n := by
rw?
axiom α : Type
axiom f : α α
axiom z : α
axiom f_eq (n) : f n = z
-- Check that the same lemma isn't used multiple times.
-- This used to report two redundant copies of `f_eq`.
-- It be lovely if `rw?` could produce two *different* rewrites by `f_eq` here!
#guard_msgs(drop info) in
theorem test : f n = f m := by
fail_if_success rw? [-f_eq] -- Check that we can forbid lemmas.
rw?
rw [f_eq]
-- Check that we can rewrite by local hypotheses.
#guard_msgs(drop info) in
example (h : 1 = 2) : 2 = 1 := by
rw?
def zero : Nat := 0
-- This used to (incorrectly!) succeed because `rw?` would try `rfl`,
-- rather than `withReducible` `rfl`.
#guard_msgs(drop info) in
example : zero = 0 := by
rw?
exact test_sorry
-- Discharge side conditions from local hypotheses.
/--
info: Try this: rw [h p]
-- "no goals"
-/
#guard_msgs in
example {P : Prop} (p : P) (h : P 1 = 2) : 2 = 1 := by
rw?
-- Use `solve_by_elim` to discharge side conditions.
/--
info: Try this: rw [h (f p)]
-- "no goals"
-/
#guard_msgs in
example {P Q : Prop} (p : P) (f : P Q) (h : Q 1 = 2) : 2 = 1 := by
rw?
-- Rewrite in reverse, discharging side conditions from local hypotheses.
/--
info: Try this: rw [← h₁ p]
-- Q a
-/
#guard_msgs in
example {P : Prop} (p : P) (Q : α Prop) (a b : α) (h₁ : P a = b) (w : Q a) : Q b := by
rw?
exact w