Compare commits

...

7 Commits

Author SHA1 Message Date
Leonardo de Moura
31cd6bac1b chore: update comments 2025-10-22 17:43:09 -07:00
Leonardo de Moura
e5c6459887 chore: add checkSystem 2025-10-22 17:34:16 -07:00
Leonardo de Moura
acedf24747 chore: missing file 2025-10-22 17:24:43 -07:00
Leonardo de Moura
12df6ce4a3 feat: use parameter optimizer at finish? 2025-10-22 17:23:49 -07:00
Leonardo de Moura
73f6c29d8b chore: 2025-10-22 16:30:03 -07:00
Leonardo de Moura
a93800ac19 chore: move 2025-10-22 16:30:03 -07:00
Leonardo de Moura
5642496dbe feat: simple parameter minimizer 2025-10-22 16:30:03 -07:00
5 changed files with 305 additions and 31 deletions

View File

@@ -7,8 +7,10 @@ module
prelude
public import Lean.Meta.Tactic.Grind.Action
public import Lean.Meta.Tactic.Grind.Intro
import Lean.Util.ParamMinimizer
import Lean.Meta.Tactic.Grind.EMatch
import Lean.Meta.Tactic.Grind.EMatchTheoremParam
import Lean.Meta.Tactic.Grind.EMatchTheoremPtr
import Lean.Meta.Tactic.Grind.MarkNestedSubsingletons
namespace Lean.Meta.Grind.Action
@@ -23,20 +25,32 @@ the original order.
The order in which things are asserted affects the proof found by `grind`.
Thus, preserving the original order should intuitively help ensure that the generated
tactic script for the continuation still closes the goal when combined with the
generated `instantiate` tactic. However, it does not guarantee that the
script can be successfully replayed, since we are filtering out instantiations that do
not appear in the final proof term. Recall that a theorem instance may
contribute to the proof search even if it does not appear in the final proof term.
generated `instantiate` tactic.
**Note**: We use a simple parameter optimizer for computing the `instantiate` tactic parameters.
We have a lower and an upper bound for these parameters.
The lower bound consists of the theorems actually used in the proof term, while the upper
bound includes all the theorems instantiated in a particular theorem instantiation step.
The lower bound is often sufficient to replay the proof, but in some cases, additional
theorems must be included because a theorem instantiation may contribute to the proof by
providing terms and may not be present in the final proof term.
**Note**: If a working `instantiate [...]` tactic cannot be produced, we produce the
tactic `instantiate approx` to indicate that this step is approximate and tweaking is needed.
We currently use an unlimited budget for finding the optimal parameter setting. We will add
a parameter to set the maximum number of iterations. After this is implemented, we may generate
`instantiate only approx [...]` to indicate the parameter search has been interrupted and
a non-minimal set of parameters was used.
-/
structure CollectState where
visited : Std.HashSet ExprPtr := {}
collectedThms : Std.HashSet (Origin × EMatchTheoremKind) := {}
idAndThms : Array (Name × EMatchTheorem) := #[]
thms : Array EMatchTheorem := #[]
def collect (e : Expr) (map : EMatch.InstanceMap) : Array (Name × EMatchTheorem) :=
def collect (e : Expr) (map : EMatch.InstanceMap) : Array EMatchTheorem :=
let (_, s) := go e |>.run {}
s.idAndThms
s.thms
where
go (e : Expr) : StateM CollectState Unit := do
if isMarkedSubsingletonApp e then
@@ -52,7 +66,7 @@ where
if let some thm := map[uniqueId]? then
let key := (thm.origin, thm.kind)
unless ( get).collectedThms.contains key do
modify fun s => { s with collectedThms := s.collectedThms.insert key, idAndThms := s.idAndThms.push (uniqueId, thm) }
modify fun s => { s with collectedThms := s.collectedThms.insert key, thms := s.thms.push thm }
match e with
| .lam _ d b _
| .forallE _ d b _ => go d; go b
@@ -91,7 +105,7 @@ def mkInstantiateTactic (goal : Goal) (usedThms : Array EMatchTheorem) (approx :
| true, false => `(grind| instantiate only)
| false, false => `(grind| instantiate only [$params,*])
| true, true => `(grind| instantiate approx)
| false, true => `(grind| instantiate approx [$params,*])
| false, true => `(grind| instantiate only approx [$params,*])
def mkNewSeq (goal : Goal) (thms : Array EMatchTheorem) (seq : List TGrind) (approx : Bool) : GrindM (List TGrind) := do
if thms.isEmpty then
@@ -99,8 +113,33 @@ def mkNewSeq (goal : Goal) (thms : Array EMatchTheorem) (seq : List TGrind) (app
else
return (( mkInstantiateTactic goal thms approx) :: seq)
def getAllTheorems (map : EMatch.InstanceMap) : Array EMatchTheorem :=
map.toArray.map (·.2)
abbrev EMatchTheoremIds := Std.HashMap EMatchTheoremPtr Nat
/--
Returns the theorems stored in `map`, ordered by instance id and with duplicates
(entries that are the same `EMatchTheoremPtr` key) removed, together with a table
assigning to each returned theorem its position in the resulting array.
-/
def getAllTheorems (map : EMatch.InstanceMap) : Array EMatchTheorem × EMatchTheoremIds := Id.run do
  -- **Note**: Sorting by instance id is meant to reproduce the original instantiation order.
  let sorted := map.toArray.qsort fun (id₁, _) (id₂, _) => id₁.lt id₂
  let mut ids : EMatchTheoremIds := {}
  let mut thms := #[]
  for (_, thm) in sorted do
    let key : EMatchTheoremPtr := { thm }
    if !ids.contains key then
      -- The next free position in `thms` becomes the index of this theorem.
      ids := ids.insert key thms.size
      thms := thms.push thm
  return (thms, ids)
/--
Builds a bitmask with `map.size` entries, initially all `false`, and sets to `true`
the position that `map` assigns to each theorem in `thms`.
Throws an internal error if some theorem in `thms` has no entry in `map`.
-/
def mkMask (map : EMatchTheoremIds) (thms : Array EMatchTheorem) : CoreM (Array Bool) := do
  let mut mask := Array.replicate map.size false
  for thm in thms do
    match map.get? { thm } with
    | some idx => mask := mask.set! idx true
    | none => throwError "`grind` internal error, theorem index not found"
  return mask
/--
Selects the theorems of `thms` whose position is marked `true` in `mask`,
preserving their relative order.
-/
def maskToThms (thms : Array EMatchTheorem) (mask : Array Bool) : Array EMatchTheorem := Id.run do
  let mut selected := #[]
  let mut idx := 0
  for keep in mask do
    if keep then
      selected := selected.push thms[idx]!
    idx := idx + 1
  return selected
public def instantiate' : Action := fun goal kna kp => do
let saved? saveStateIfTracing
@@ -110,17 +149,21 @@ public def instantiate' : Action := fun goal kna kp => do
| .closed seq =>
if ( getConfig).trace then
let proof instantiateMVars (mkMVar goal.mvarId)
let usedIdAndThms := collect proof map
-- **Note**: See note above. We want to sort here to reproduce the original instantiation order.
let usedIdAndThms := usedIdAndThms.qsort fun (id₁, _) (id₂, _) => id₁.lt id₂
let usedThms := usedIdAndThms.map (·.2)
let newSeq mkNewSeq goal usedThms seq (approx := false)
if ( checkSeqAt saved? goal newSeq) then
return .closed newSeq
else
let allThms := getAllTheorems map
let newSeq mkNewSeq goal allThms seq (approx := true)
return .closed newSeq
let usedThms := collect proof map
let (allThms, map) := getAllTheorems map
-- We must have at least the ones used in the proof
let initMask mkMask map usedThms
let testMask (mask : Array Bool) : GrindM Bool := do
checkSystem "`grind` `instantiate` tactic parameter optimizer"
let thms := maskToThms allThms mask
let newSeq mkNewSeq goal thms seq (approx := false)
checkSeqAt saved? goal newSeq
let r Util.ParamMinimizer.search initMask testMask
let newSeq match r.status with
| .missing => mkNewSeq goal #[] seq (approx := true)
| .approx => mkNewSeq goal (maskToThms allThms r.paramMask) seq (approx := true)
| .precise => mkNewSeq goal (maskToThms allThms r.paramMask) seq (approx := false)
return .closed newSeq
else
return .closed []
| r => return r

View File

@@ -0,0 +1,27 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Tactic.Grind.EMatchTheorem
public section
namespace Lean.Meta.Grind
/-- Pointer-equality test for `EMatchTheorem` values. -/
@[inline] def isSameEMatchTheoremPtr (a b : EMatchTheorem) : Bool :=
  unsafe ptrEq a b

/--
Wrapper around `EMatchTheorem` whose `BEq` and `Hashable` instances are based on
the pointer of the wrapped value instead of its structure.
-/
structure EMatchTheoremPtr where
  thm : EMatchTheorem

/-- Hash code computed from the address of `thm`. -/
abbrev hashEMatchTheoremPtr (thm : EMatchTheorem) : UInt64 :=
  unsafe (ptrAddrUnsafe thm >>> 3).toUInt64

instance : Hashable EMatchTheoremPtr :=
  ⟨fun k => hashEMatchTheoremPtr k.thm⟩

instance : BEq EMatchTheoremPtr :=
  ⟨fun k₁ k₂ => isSameEMatchTheoremPtr k₁.thm k₂.thm⟩
end Lean.Meta.Grind

View File

@@ -4,7 +4,6 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Util.CollectFVars
public import Lean.Util.CollectLevelParams
@@ -38,3 +37,4 @@ public import Lean.Util.NumApps
public import Lean.Util.FVarSubset
public import Lean.Util.SortExprs
public import Lean.Util.Reprove
public import Lean.Util.ParamMinimizer

View File

@@ -0,0 +1,205 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Init.Data.Array.Basic
public import Init.While
public import Init.Data.Range.Polymorphic
namespace Lean.Util.ParamMinimizer
/-!
A very simple parameter minimizer.
-/
/-- Outcome reported by the parameter minimizer procedure. -/
public inductive Status where
  /-- Has not found a solution. -/
  | missing
  /-- Found a non minimal solution. -/
  | approx
  /-- Found a precise solution. -/
  | precise
  deriving Inhabited, Repr
/-- What `search` hands back to its caller once the minimization procedure stops. -/
public structure Result where
  /-- How the search ended: `missing`, `approx`, or `precise`. -/
  status : Status
  /-- The parameter bitmask the search ended with. -/
  paramMask : Array Bool
  /-- How many times `test` was invoked. -/
  numCalls : Nat
/-- Read-only configuration of the parameter minimizer. -/
structure Context (m : Type → Type) where
  /-- Initial parameter selection -/
  initialMask : Array Bool
  /--
  An expensive monotonic predicate for testing whether a given parameter
  configuration works or not.
  -/
  test : Array Bool → m Bool
  /--
  Budget. That is, the maximum number of calls to `test` that we are willing to perform.
  `0` means unbounded.
  -/
  maxCalls : Nat
/-- Mutable state of the parameter minimizer. -/
structure State where
  /-- Bitmask currently under consideration. -/
  cur      : Array Bool
  /-- Indices switched on by `add`, in the order they were added. -/
  added    : Array Nat := #[]
  /-- Number of `test` invocations performed so far. -/
  numCalls : Nat := 0
  /-- Set by `markFound` once `test` has returned `true` for some bitmask. -/
  found    : Bool := false
/-!
We use `throw ()` to interrupt the search.
-/
abbrev M (m : Type → Type) := ReaderT (Context m) (ExceptT Unit (StateT State m))
/--
Records that a solution exists, i.e., that `test` returned `true` for some bitmask.
-/
def markFound [Monad m] : M m Unit :=
  modify (fun st => { st with found := true })
/-- Bumps the counter of `test` invocations by one. -/
def incNumCalls [Monad m] : M m Unit :=
  modify (fun st => { st with numCalls := st.numCalls + 1 })
/--
Activates parameter `i`: bit `i` of the current mask becomes `true`, and `i` is
appended to the list of added indices.
-/
def add (i : Nat) [Monad m] : M m Unit :=
  modify (fun st => { st with cur := st.cur.set! i true, added := st.added.push i })
/--
Deactivates parameter `i`: bit `i` of the current mask becomes `false`.
The list of added indices is left untouched.
-/
def erase (i : Nat) [Monad m] : M m Unit :=
  modify (fun st => { st with cur := st.cur.set! i false })
/--
Undoes an unsuccessful removal of parameter `i`: bit `i` of the current mask
becomes `true` again.
-/
def restore (i : Nat) [Monad m] : M m Unit :=
  modify (fun st => { st with cur := st.cur.set! i true })
/--
Evaluates `test` on the current mask.

If a budget is set (`maxCalls > 0`) and it is already exhausted, the search is
interrupted with `throw ()` *without* calling `test`. Otherwise the call counter is
incremented, `test` is invoked, and `found` is set when it succeeds.
-/
def tryCur [Monad m] : M m Bool := do
  let maxCalls := (← read).maxCalls
  if maxCalls > 0 && (← get).numCalls ≥ maxCalls then
    throw ()
  else
    modify fun s => { s with numCalls := s.numCalls + 1 }
    if (← (← read).test (← get).cur) then
      markFound (m := m)
      return true
    else
      return false
/--
**Initialization (growing phase).**
Starting from `initialMask`, this procedure sequentially activates parameters
(i.e., flips `false` bits to `true`) until `test` first returns `true`.
For each inactive parameter index `i`, it:
1. sets `cur[i] := true` and records `i` in `added`;
2. calls `tryCur` to evaluate the updated mask;
3. stops immediately once `test` succeeds.

This phase exploits the assumption that `test` is *monotonic* and that the
minimal true configuration is *close* to `initialMask`. If the call budget is not
exhausted, then at completion the current mask `cur` satisfies `test` whenever some
superset of `initialMask` satisfies it, and `(← get).added.back!` is the element
whose inclusion first made `test` true.
If the budget is exhausted, `tryCur` interrupts the phase with `throw ()`.
-/
def init [Monad m] : M m Unit := do
  let initialMask := (← read).initialMask
  for h : i in *...initialMask.size do
    unless initialMask[i] do
      add i
      if (← tryCur) then return
/--
**Pruning (minimization phase).**
Starting from a configuration `cur` known to satisfy `test`, this procedure
iterates through the indices stored in `added` **in reverse order**, removing
each one temporarily to check if it is necessary.
For each recorded index `i` (except the last one added, which is known to be
required since `test` failed without it during `init`):
1. sets `cur[i] := false`;
2. re-evaluates `tryCur`;
3. if `test` remains true, keeps `i` cleared;
   otherwise restores `cur[i] := true`.

After this phase, `cur` is *1-minimal* with respect to the bits added by `init`:
clearing any single one of them would make `test` return `false`
(assuming `test` is monotonic).

If `tryCur` interrupts the search because the budget is exhausted, the tentative
removal is undone before the interruption is propagated. Thus `cur` is always a
mask for which `test` has succeeded, even when the search stops early.
-/
def prune [Monad m] : M m Unit := do
  -- **Note**: We skip the last added element because removing it
  -- would necessarily make `test` fail — that's the one that flipped it to true.
  let mut k := (← get).added.size - 1
  while k > 0 do
    k := k - 1
    let i := (← get).added[k]!
    erase i
    let stillWorks ←
      try
        tryCur
      catch _ =>
        -- Budget exhausted before `test` could validate the removal of `i`:
        -- roll it back so that the reported mask is a tested one.
        restore i
        throw ()
    unless stillWorks do
      restore i
/--
Runs the growing phase (`init`) and, only if it found a mask accepted by `test`,
the minimization phase (`prune`).
-/
def main [Monad m] : M m Unit := do
  init
  if (← get).found then
    prune
/--
**Runs the parameter minimization procedure.**
Given an initial bitmask `initialMask` representing the active parameters,
and a monotonic predicate `test : Array Bool → m Bool`, this function searches
for the smallest (1-minimal) superset of `initialMask` that satisfies `test`.

Execution proceeds in two phases:
1. **`init`** gradually activates parameters until `test` first succeeds;
2. **`prune`** removes superfluous active parameters while preserving success.

If `test initialMask` already succeeds, `initialMask` is returned as `.precise`
after a single call. If the search completes without exceeding `maxCalls`, the
result is marked as `.precise`. If the budget is exceeded but a valid configuration
was found, the result is `.approx`. If no configuration satisfying `test` was found,
the result is `.missing`.

`maxCalls = 0` disables the call budget limit.
-/
public def search
    [Monad m]
    (initialMask : Array Bool)
    (test : Array Bool → m Bool)
    (maxCalls : Nat := 0) -- 0 means unbounded
    : m Result := do
  if (← test initialMask) then
    return { paramMask := initialMask, numCalls := 1, status := .precise }
  -- The call above already consumed one unit of budget, hence `numCalls := 1`.
  let (r, s) ← main { initialMask, test, maxCalls } |>.run |>.run { cur := initialMask, numCalls := 1 }
  let status := if s.found then
    match r with
    | .ok _ => .precise
    -- `throw ()` was used to interrupt the search after a solution had been found.
    | .error _ => .approx
  else
    .missing
  return { paramMask := s.cur, numCalls := s.numCalls, status }
end Lean.Util.ParamMinimizer

View File

@@ -237,7 +237,7 @@ example (m : IndexMap α β) (a a' : α) (b : β) :
/--
info: Try this:
[apply] ⏎
instantiate approx [= getElem_def, = mem_indices_of_mem, insert]
instantiate only [= mem_indices_of_mem, insert, = getElem_def]
instantiate only [= getElem?_neg, = getElem?_pos]
cases #f590
next =>
@@ -269,8 +269,7 @@ example (m : IndexMap α β) (a a' : α) (b : β) (h : a' ∈ m.insert a b) :
example (m : IndexMap α β) (a a' : α) (b : β) (h : a' m.insert a b) :
(m.insert a b)[a'] = if h' : a' == a then b else m[a'] := by
grind =>
-- **TODO**: Check approx here
instantiate approx [= getElem_def, = mem_indices_of_mem, insert]
instantiate only [= mem_indices_of_mem, insert, = getElem_def]
instantiate only [= getElem?_neg, = getElem?_pos]
cases #f590
next =>
@@ -280,20 +279,20 @@ example (m : IndexMap α β) (a a' : α) (b : β) (h : a' ∈ m.insert a b) :
instantiate only [= Array.getElem_set]
next =>
instantiate only
instantiate only [= Array.getElem_push, size, = HashMap.getElem_insert,
= HashMap.mem_insert]
instantiate only [size, = HashMap.mem_insert, = HashMap.getElem_insert,
= Array.getElem_push]
next =>
instantiate only [= getElem_def, = mem_indices_of_mem]
instantiate only [= mem_indices_of_mem, = getElem_def]
instantiate only [usr getElem_indices_lt]
instantiate only [size]
cases #ffdf
next =>
instantiate only [=_ WF]
instantiate only [= Array.getElem_set, = getElem?_neg, = getElem?_pos]
instantiate only [= getElem?_neg, = getElem?_pos, = Array.getElem_set]
instantiate only [WF']
next =>
instantiate only
instantiate only [= Array.getElem_push, = HashMap.mem_insert, = HashMap.getElem_insert]
instantiate only [= HashMap.mem_insert, = HashMap.getElem_insert, = Array.getElem_push]
/--
info: Try this: