Compare commits

...

3 Commits

Author SHA1 Message Date
Leonardo de Moura
9478c9a1ec chore: add todo 2026-01-03 12:20:25 -08:00
Leonardo de Moura
9afa6a74d3 feat: use DiscrTree 2026-01-03 12:15:07 -08:00
Leonardo de Moura
b93b3f045b feat: add getMatch 2026-01-03 12:02:20 -08:00
4 changed files with 115 additions and 9 deletions

View File

@@ -132,8 +132,103 @@ public def insertPattern [BEq α] (d : DiscrTree α) (p : Pattern) (v : α) : Di
let keys := p.mkDiscrTreeKeys
d.insertKeyValue keys v
/-!
**TODO** Retrieval.
def getKeyArgs (e : Expr) : Key × Array Expr :=
match e.getAppFn with
| .lit v => (.lit v, #[])
| .const declName _ => (.const declName e.getAppNumArgs, e.getAppRevArgs)
| .fvar fvarId => (.fvar fvarId e.getAppNumArgs, e.getAppRevArgs)
| .forallE _ d b _ => (.arrow, #[b, d])
| _ => (.other, #[])
abbrev findKey? (cs : Array (Key × Trie α)) (k : Key) : Option (Key × Trie α) :=
cs.binSearch (k, default) (fun a b => a.1 < b.1)
partial def getMatchLoop (todo : Array Expr) (c : Trie α) (result : Array α) : Array α :=
match c with
| .node vs cs =>
if todo.isEmpty then
result ++ vs
else if cs.isEmpty then
result
else
let e := todo.back!
let todo := todo.pop
let first := cs[0]! /- Recall that `Key.star` is the minimal key -/
let (k, args) := getKeyArgs e
/- We must always visit `Key.star` edges since they are wildcards.
Thus, `todo` is not used linearly when there is `Key.star` edge
and there is an edge for `k` and `k != Key.star`. -/
let visitStar (result : Array α) : Array α :=
if first.1 == .star then
getMatchLoop todo first.2 result
else
result
let visitNonStar (k : Key) (args : Array Expr) (result : Array α) : Array α :=
match findKey? cs k with
| none => result
| some c => getMatchLoop (todo ++ args) c.2 result
let result := visitStar result
match k with
| .star => result
| _ => visitNonStar k args result
def getMatchRoot (d : DiscrTree α) (k : Key) (args : Array Expr) (result : Array α) : Array α :=
match d.root.find? k with
| none => result
| some c => getMatchLoop args c result
def getStarResult (d : DiscrTree α) : Array α :=
let result : Array α := .mkEmpty initCapacity
match d.root.find? .star with
| none => result
| some (.node vs _) => result ++ vs
def getMatchCore (d : DiscrTree α) (e : Expr) : Key × Array α :=
let result := getStarResult d
let (k, args) := getKeyArgs e
match k with
| .star => (k, result)
| _ => (k, getMatchRoot d k args result)
/--
Retrieves all values whose patterns match the expression `e`.
-/
public def getMatch (d : DiscrTree α) (e : Expr) : Array α :=
getMatchCore d e |>.2
/--
Retrieves all values whose patterns match a prefix of `e`, along with the number of
extra (ignored) arguments.
This is useful for rewriting: if a pattern matches `f x` but `e` is `f x y z`, we can
still apply the rewrite and return `(value, 2)` indicating 2 extra arguments.
-/
public partial def getMatchWithExtra (d : DiscrTree α) (e : Expr) : Array (α × Nat) :=
let (k, result) := getMatchCore d e
let result := result.map (·, 0)
if !e.isApp then
result
else if !mayMatchPrefix k then
result
else
go e.appFn! 1 result
where
mayMatchPrefix (k : Key) : Bool :=
let cont (k : Key) : Bool :=
if d.root.find? k |>.isSome then
true
else
mayMatchPrefix k
match k with
| .const f (n+1) => cont (.const f n)
| .fvar f (n+1) => cont (.fvar f n)
| _ => false
go (e : Expr) (numExtra : Nat) (result : Array (α × Nat)) : Array (α × Nat) :=
let result := result ++ (getMatchCore d e).2.map (., numExtra)
if e.isApp then
go e.appFn! (numExtra + 1) result
else
result
end Lean.Meta.Sym

View File

@@ -8,6 +8,7 @@ prelude
public import Lean.Meta.Sym.SimpM
public import Lean.Meta.Sym.SimpFun
import Lean.Meta.Sym.InstantiateS
import Lean.Meta.Sym.DiscrTree
namespace Lean.Meta.Sym.Simp
open Grind
@@ -41,8 +42,8 @@ public def Theorem.rewrite? (thm : Theorem) (e : Expr) : SimpM (Option Result) :
return none
public def rewrite : SimpFun := fun e => do
-- **TODO**: use indexing
for thm in ( read).thms.thms do
-- **TODO**: over-applied terms
for thm in ( read).thms.getMatch e do
if let some result thm.rewrite? e then
return result
return { expr := e }

View File

@@ -7,6 +7,7 @@ module
prelude
public import Lean.Meta.Sym.SymM
public import Lean.Meta.Sym.Pattern
import Lean.Meta.Sym.DiscrTree
public section
namespace Lean.Meta.Sym.Simp
@@ -129,10 +130,18 @@ structure Theorem where
/-- Right-hand side of the equation. -/
rhs : Expr
instance : BEq Theorem where
beq thm₁ thm₂ := thm₁.expr == thm₂.expr
/-- Collection of simplification theorems available to the simplifier. -/
structure Theorems where
/-- **TODO**: No indexing for now. We will add a structural discrimination tree later. -/
thms : Array Theorem := #[]
thms : DiscrTree Theorem := {}
def Theorems.insert (thms : Theorems) (thm : Theorem) : Theorems :=
{ thms with thms := insertPattern thms.thms thm.pattern thm }
def Theorems.getMatch (thms : Theorems) (e : Expr) : Array Theorem :=
Sym.getMatch thms.thms e
/-- Read-only context for the simplifier. -/
structure Context where
@@ -182,7 +191,7 @@ abbrev getCache : SimpM Cache :=
end Simp
public def simp (e : Expr) (thms : Simp.Theorems := {}) (config : Simp.Config := {}) : SymM Simp.Result := do
def simp (e : Expr) (thms : Simp.Theorems := {}) (config : Simp.Config := {}) : SymM Simp.Result := do
Simp.SimpM.run (Simp.simp e) thms config
end Lean.Meta.Sym

View File

@@ -8,8 +8,9 @@ namespace SimpBench
-/
def mkSimpTheorems : MetaM Sym.Simp.Theorems := do
let thm Sym.Simp.mkTheoremFromDecl ``Nat.zero_add
return { thms := #[thm] }
let result : Sym.Simp.Theorems := {}
let result := result.insert ( Sym.Simp.mkTheoremFromDecl ``Nat.zero_add)
return result
def simp (e : Expr) : MetaM (Sym.Simp.Result × Float) := Sym.SymM.run' do
let e Grind.shareCommon e