Compare commits

...

1 Commits

Author SHA1 Message Date
Joe Hendrix
22695af20b chore: migrate find functionality into LazyDiscrTree.
This is designed to simplify process of creating library search-like computations
2024-03-14 13:36:43 -07:00
2 changed files with 129 additions and 95 deletions

View File

@@ -25,13 +25,11 @@ elaborated additional parts of the tree.
-/
namespace Lean.Meta.LazyDiscrTree
-- This namespace contains definitions copied from Lean.Meta.DiscrTree.
namespace MatchClone
/--
Discrimination tree key.
-/
private inductive Key where
inductive Key where
| const : Name Nat Key
| fvar : FVarId Nat Key
| lit : Literal Key
@@ -57,6 +55,9 @@ instance : Hashable Key := ⟨Key.hash⟩
end Key
-- This namespace contains definitions copied from Lean.Meta.DiscrTree.
namespace MatchClone
private def tmpMVarId : MVarId := { name := `_discr_tree_tmp }
private def tmpStar := mkMVar tmpMVarId
@@ -265,8 +266,6 @@ private abbrev getMatchKeyArgs (e : Expr) (root : Bool) (config : WhnfCoreConfig
end MatchClone
export MatchClone (Key Key.const)
/--
An unprocessed entry in the lazy discrimination tree.
-/
@@ -290,7 +289,7 @@ private structure Trie (α : Type) where
/-- Following matches based on key of trie. -/
children : HashMap Key TrieIndex
/-- Lazy entries at this trie that are not processed. -/
pending : Array (LazyEntry α)
pending : Array (LazyEntry α) := #[]
deriving Inhabited
instance : EmptyCollection (Trie α) := .node #[] 0 {} #[]
@@ -484,14 +483,29 @@ private def evalNode (c : TrieIndex) :
setTrie c <| .node vs star cs #[]
pure (vs, star, cs)
/--
Return the information about the trie at the given idnex.
def dropKeyAux (next : TrieIndex) (rest : List Key) :
MatchM α Unit :=
if next = 0 then
pure ()
else do
let (_, star, children) evalNode next
match rest with
| [] =>
modify (·.set! next {values := #[], star, children})
| k :: r => do
let next := if k == .star then star else children.findD k 0
dropKeyAux next r
Used for internal debugging purposes.
/--
This drops a specific key from the lazy discrimination tree so that
all the entries matching that key exactly are removed.
-/
private def getTrie (d : LazyDiscrTree α) (idx : TrieIndex) :
MetaM ((Array α × TrieIndex × HashMap Key TrieIndex) × LazyDiscrTree α) :=
runMatch d (evalNode idx)
def dropKey (t : LazyDiscrTree α) (path : List LazyDiscrTree.Key) : MetaM (LazyDiscrTree α) :=
match path with
| [] => pure t
| rootKey :: rest => do
let idx := t.roots.findD rootKey 0
Prod.snd <$> runMatch t (dropKeyAux idx rest)
/--
A match result contains the terms formed from matching a term against
@@ -638,7 +652,9 @@ private def push (d : PreDiscrTree α) (k : Key) (e : LazyEntry α) : PreDiscrTr
/-- Convert a pre-discrimination tree to a lazy discrimination tree. -/
private def toLazy (d : PreDiscrTree α) (config : WhnfCoreConfig := {}) : LazyDiscrTree α :=
let { roots, tries } := d
{ config, roots, tries := tries.map (.node {} 0 {}) }
-- Adjust trie indices so the first value is reserved (so 0 is never a valid trie index)
let roots := roots.fold (init := roots) (fun m k n => m.insert k (n+1))
{ config, roots, tries := #[default] ++ tries.map (.node {} 0 {}) }
/-- Merge two discrimination trees. -/
protected def append (x y : PreDiscrTree α) : PreDiscrTree α :=
@@ -810,6 +826,38 @@ private def createImportedEnvironmentSeq (ngen : NameGenerator) (env : Environme
private def combineGet [Append α] (z : α) (tasks : Array (Task α)) : α :=
tasks.foldl (fun x t => x ++ t.get) (init := z)
def getChildNgen [Monad M] [MonadNameGenerator M] : M NameGenerator := do
let ngen getNGen
let (cngen, ngen) := ngen.mkChild
setNGen ngen
pure cngen
def createLocalPreDiscrTree
(ngen : NameGenerator)
(env : Environment)
(d : ImportData)
(act : Name ConstantInfo MetaM (Array (InitEntry α))) :
BaseIO (PreDiscrTree α) := do
let modName := env.header.mainModule
let cacheRef IO.mkRef (Cache.empty ngen)
let act (t : PreDiscrTree α) (n : Name) (c : ConstantInfo) : BaseIO (PreDiscrTree α) :=
addConstImportData env modName d cacheRef t act n c
let r (env.constants.map₂.foldlM (init := {}) act : BaseIO (PreDiscrTree α))
pure r
/-- Create an imported environment for tree. -/
def createLocalEnvironment
(act : Name ConstantInfo MetaM (Array (InitEntry α))) :
CoreM (LazyDiscrTree α) := do
let env getEnv
let ngen getChildNgen
let d ImportData.new
let t createLocalPreDiscrTree ngen env d act
let errors d.errors.get
if p : errors.size > 0 then
throw errors[0].exception
pure <| t.toLazy
/-- Create an imported environment for tree. -/
def createImportedEnvironment (ngen : NameGenerator) (env : Environment)
(act : Name ConstantInfo MetaM (Array (InitEntry α)))
@@ -840,3 +888,39 @@ def createImportedEnvironment (ngen : NameGenerator) (env : Environment)
if p : r.errors.size > 0 then
throw r.errors[0].exception
pure <| r.tree.toLazy
def dropKeys (t : LazyDiscrTree α) (keys : List (List LazyDiscrTree.Key)) : MetaM (LazyDiscrTree α) := do
keys.foldlM (init := t) (·.dropKey ·)
/--
`findCandidates` searches for entries in a lazily initialized discriminator tree.
* `ext` should be an environment extension with an IO.Ref for caching the import lazy
discriminator tree.
* `addEntry` is the function for creating discriminator tree entries from constants.
* `droppedKeys` contains keys we do not want to consider when searching for matches.
It is used for dropping very general keys.
-/
def findCandidates (ext : EnvExtension (IO.Ref (Option (LazyDiscrTree α))))
(addEntry : Name ConstantInfo MetaM (Array (InitEntry α)))
(droppedKeys : List (List LazyDiscrTree.Key) := [])
(constantsPerTask : Nat := 1000)
(ty : Expr) : MetaM (Array α) := do
let ngen getNGen
let (cNGen, ngen) := ngen.mkChild
setNGen ngen
let dummy : IO.Ref (Option (LazyDiscrTree α)) IO.mkRef none
let ref := @EnvExtension.getState _ dummy ext (getEnv)
let importTree (ref.get).getDM $ do
profileitM Exception "lazy discriminator import initialization" (getOptions) $ do
let t createImportedEnvironment cNGen (getEnv) addEntry
(constantsPerTask := constantsPerTask)
dropKeys t droppedKeys
let (localCandidates, _)
profileitM Exception "lazy discriminator local search" (getOptions) $ do
let t createLocalEnvironment addEntry
let t dropKeys t droppedKeys
t.getMatch ty
let (importCandidates, importTree) importTree.getMatch ty
ref.set importTree
pure (localCandidates ++ importCandidates)

View File

@@ -28,6 +28,9 @@ example : Nat := by exact?
namespace Lean.Meta.LibrarySearch
builtin_initialize registerTraceClass `Tactic.librarySearch
builtin_initialize registerTraceClass `Tactic.librarySearch.lemmas
open SolveByElim
/--
@@ -64,44 +67,7 @@ to find candidate lemmas.
@[reducible]
def CandidateFinder := Expr MetaM (Array (Name × DeclMod))
namespace DiscrTreeFinder
/-- Adds a path to a discrimination tree. -/
private def addPath [BEq α] (config : WhnfCoreConfig) (tree : DiscrTree α) (tp : Expr) (v : α) :
MetaM (DiscrTree α) := do
let k DiscrTree.mkPath tp config
pure <| tree.insertCore k v
/-- Adds a constant with given name to tree. -/
private def updateTree (config : WhnfCoreConfig) (tree : DiscrTree (Name × DeclMod))
(name : Name) (constInfo : ConstantInfo) : MetaM (DiscrTree (Name × DeclMod)) := do
if constInfo.isUnsafe then return tree
if !allowCompletion (getEnv) name then return tree
withReducible do
let (_, _, type) forallMetaTelescope constInfo.type
let tree addPath config tree type (name, DeclMod.none)
match type.getAppFnArgs with
| (``Iff, #[lhs, rhs]) => do
let tree addPath config tree rhs (name, DeclMod.mp)
let tree addPath config tree lhs (name, DeclMod.mpr)
return tree
| _ =>
return tree
end DiscrTreeFinder
namespace IncDiscrTreeFinder
open LazyDiscrTree (InitEntry createImportedEnvironment)
/--
The maximum number of constants an individual task may perform.
The value was picked because it roughly correponded to 50ms of work on the machine this was
developed on. Smaller numbers did not seem to improve performance when importing Std and larger
numbers (<10k) seemed to degrade initialization performance.
-/
private def constantsPerTask : Nat := 6500
open LazyDiscrTree (InitEntry findCandidates)
private def addImport (name : Name) (constInfo : ConstantInfo) :
MetaM (Array (InitEntry (Name × DeclMod))) :=
@@ -115,37 +81,42 @@ private def addImport (name : Name) (constInfo : ConstantInfo) :
else
pure a
def findCandidates (ref : IO.Ref (Option (LazyDiscrTree (Name × DeclMod))))
(ty : Expr) : MetaM (Array (Name × DeclMod)) := do
let ngen getNGen
let (childNGen, ngen) := ngen.mkChild
setNGen ngen
let importTree (ref.get).getDM $ do
profileitM Exception "librarySearch launch" (getOptions) $
createImportedEnvironment childNGen (getEnv) (constantsPerTask := constantsPerTask) addImport
let (imports, importTree) importTree.getMatch ty
ref.set importTree
pure imports
end IncDiscrTreeFinder
builtin_initialize registerTraceClass `Tactic.librarySearch
builtin_initialize registerTraceClass `Tactic.librarySearch.lemmas
/-- State for resolving imports -/
/-- Stores import discrimination tree. -/
private def LibSearchState := IO.Ref (Option (LazyDiscrTree (Name × DeclMod)))
private builtin_initialize LibSearchState.default : IO.Ref (Option (LazyDiscrTree (Name × DeclMod))) do
private builtin_initialize defaultLibSearchState : IO.Ref (Option (LazyDiscrTree (Name × DeclMod))) do
IO.mkRef .none
private instance : Inhabited LibSearchState where
default := LibSearchState.default
default := defaultLibSearchState
private builtin_initialize ext : EnvExtension LibSearchState
registerEnvExtension (IO.mkRef .none)
/--
Return an action that returns true when the remaining heartbeats is less
We drop `.star` and `Eq * * *` from the discriminator trees because
they match too much.
-/
def droppedKeys : List (List LazyDiscrTree.Key) := [[.star], [.const `Eq 3, .star, .star, .star]]
/--
The maximum number of constants an individual task may perform.
The value was picked because it roughly correponded to 50ms of work on the
machine this was developed on. Smaller numbers did not seem to improve
performance when importing Std and larger numbers (<10k) seemed to degrade
initialization performance.
-/
private def constantsPerImportTask : Nat := 6500
/-- Create function for finding relevant declarations. -/
def libSearchFindDecls : Expr MetaM (Array (Name × DeclMod)) :=
findCandidates ext addImport
(droppedKeys := droppedKeys)
(constantsPerTask := constantsPerImportTask)
/--
Return an action that returns true when the remaining heartbeats is less
than the currently remaining heartbeats * leavePercent / 100.
-/
def mkHeartbeatCheck (leavePercent : Nat) : MetaM (MetaM Bool) := do
@@ -246,19 +217,6 @@ private def isVar (e : Expr) : Bool :=
| .mvar _ => true
| _ => false
private def isNonspecific (type : Expr) : MetaM Bool := do
forallTelescope type fun _ tp =>
match tp.getAppFn with
| .bvar _ => pure true
| .fvar _ => pure true
| .mvar _ => pure true
| .const nm _ =>
if nm = ``Eq then
pure (tp.getAppArgsN 3 |>.all isVar)
else
pure false
| _ => pure false
/--
Tries to apply the given lemma (with symmetry modifier) to the goal,
then tries to close subsequent goals using `solveByElim`.
@@ -273,9 +231,6 @@ private def librarySearchLemma (cfg : ApplyConfig) (act : List MVarId → MetaM
withTraceNode `Tactic.librarySearch (return m!"{emoji ·} trying {name}{ppMod mod} ") do
setMCtx mctx
let lem mkLibrarySearchLemma name mod
let lemType instantiateMVars ( inferType lem)
if isNonspecific lemType then
failure
let newGoals goal.apply lem cfg
try
act newGoals
@@ -323,15 +278,10 @@ private def librarySearch' (goal : MVarId)
MetaM (Option (Array (List MVarId × MetavarContext))) := do
withTraceNode `Tactic.librarySearch (return m!"{librarySearchEmoji ·} {← goal.getType}") do
profileitM Exception "librarySearch" ( getOptions) do
let importTreeRef := ext.getState (getEnv)
let searchFn (ty : Expr) := do
let localMap ( getEnv).constants.map₂.foldlM (init := {}) (DiscrTreeFinder.updateTree {})
let locals := ( localMap.getMatch ty {}).reverse
pure <| locals ++ ( IncDiscrTreeFinder.findCandidates importTreeRef ty)
-- Create predicate that returns true when running low on heartbeats.
let shouldAbort mkHeartbeatCheck leavePercentHeartbeats
let candidates librarySearchSymm searchFn goal
let candidates librarySearchSymm libSearchFindDecls goal
let cfg : ApplyConfig := { allowSynthFailures := true }
let shouldAbort mkHeartbeatCheck leavePercentHeartbeats
let act := fun cand => do
if shouldAbort then
abortSpeculation