Compare commits

...

3 Commits

Author SHA1 Message Date
Leonardo de Moura
efb647e7fe chore: remove dead code 2026-01-03 19:38:42 -08:00
Leonardo de Moura
86b2f729c2 chore: cleanup 2026-01-03 19:35:06 -08:00
Leonardo de Moura
539ca1a893 perf: Sym.Simp.DiscrTree retrieval
This PR improves the discrimination tree retrieval performance used by `Sym.simp`.
2026-01-03 19:30:29 -08:00

View File

@@ -44,16 +44,6 @@ Retrieval should use the standard `DiscrTree.getMatch` or similar, which will fi
whose key sequence is compatible with the query term.
-/
/--
Returns the number of child keys for a given discrimination tree key.
**Note**: Unlike the standard `DiscrTree` module, `Key.arrow` has arity 2.
-/
def getKeyArity : Key Nat
| .const _ a => a
| .fvar _ a => a
| .arrow => 2
| _ => 0
/-- Returns `true` if argument at position `i` should be ignored (is a proof or instance). -/
def ignoreArg (infos : Array ProofInstArgInfo) (i : Nat) : Bool :=
if h : i < infos.size then
@@ -132,69 +122,66 @@ public def insertPattern [BEq α] (d : DiscrTree α) (p : Pattern) (v : α) : Di
let keys := p.mkDiscrTreeKeys
d.insertKeyValue keys v
def getKeyArgs (e : Expr) : Key × Array Expr :=
match e.getAppFn with
| .lit v => (.lit v, #[])
| .const declName _ => (.const declName e.getAppNumArgs, e.getAppRevArgs)
| .fvar fvarId => (.fvar fvarId e.getAppNumArgs, e.getAppRevArgs)
| .forallE _ d b _ => (.arrow, #[b, d])
| _ => (.other, #[])
abbrev findKey? (cs : Array (Key × Trie α)) (k : Key) : Option (Key × Trie α) :=
cs.binSearch (k, default) (fun a b => a.1 < b.1)
def getKey (e : Expr) : Key :=
match e.getAppFn with
| .lit v => .lit v
| .const declName _ => .const declName e.getAppNumArgs
| .fvar fvarId => .fvar fvarId e.getAppNumArgs
| .forallE _ _ _ _ => .arrow
| _ => .other
/-- Push `e` arguments/children into the `todo` stack. -/
def pushArgsTodo (todo : Array Expr) (e : Expr) : Array Expr :=
match e with
| .app f a => pushArgsTodo (todo.push a) f
| .forallE _ d b _ => todo.push b |>.push d
| _ => todo
partial def getMatchLoop (todo : Array Expr) (c : Trie α) (result : Array α) : Array α :=
match c with
| .node vs cs =>
let csize := cs.size
if todo.isEmpty then
result ++ vs
else if cs.isEmpty then
else if h : csize = 0 then
result
else
let e := todo.back!
let todo := todo.pop
let first := cs[0]! /- Recall that `Key.star` is the minimal key -/
let (k, args) := getKeyArgs e
/- We must always visit `Key.star` edges since they are wildcards.
Thus, `todo` is not used linearly when there is `Key.star` edge
and there is an edge for `k` and `k != Key.star`. -/
let visitStar (result : Array α) : Array α :=
let first := cs[0] /- Recall that `Key.star` is the minimal key -/
if csize = 1 then
/- Special case: only one child node -/
if first.1 == .star then
getMatchLoop todo first.2 result
else if first.1 == getKey e then
getMatchLoop (pushArgsTodo todo e) first.2 result
else
result
else
/- We must always visit `Key.star` edges since they are wildcards.
Thus, `todo` is not used linearly when there is `Key.star` edge
and there is an edge for `k` and `k != Key.star`. -/
let result := if first.1 == .star then
getMatchLoop todo first.2 result
else
result
let visitNonStar (k : Key) (args : Array Expr) (result : Array α) : Array α :=
match findKey? cs k with
match findKey? cs (getKey e) with
| none => result
| some c => getMatchLoop (todo ++ args) c.2 result
let result := visitStar result
match k with
| .star => result
| _ => visitNonStar k args result
def getMatchRoot (d : DiscrTree α) (k : Key) (args : Array Expr) (result : Array α) : Array α :=
match d.root.find? k with
| none => result
| some c => getMatchLoop args c result
def getStarResult (d : DiscrTree α) : Array α :=
let result : Array α := .mkEmpty initCapacity
match d.root.find? .star with
| none => result
| some (.node vs _) => result ++ vs
def getMatchCore (d : DiscrTree α) (e : Expr) : Key × Array α :=
let result := getStarResult d
let (k, args) := getKeyArgs e
match k with
| .star => (k, result)
| _ => (k, getMatchRoot d k args result)
| some c => getMatchLoop (pushArgsTodo todo e) c.2 result
/--
Retrieves all values whose patterns match the expression `e`.
-/
public def getMatch (d : DiscrTree α) (e : Expr) : Array α :=
getMatchCore d e |>.2
let result := match d.root.find? .star with
| none => .mkEmpty initCapacity
| some (.node vs _) => vs
match d.root.find? (getKey e) with
| none => result
| some c => getMatchLoop (pushArgsTodo #[] e) c result
/--
Retrieves all values whose patterns match a prefix of `e`, along with the number of
@@ -204,11 +191,11 @@ This is useful for rewriting: if a pattern matches `f x` but `e` is `f x y z`, w
still apply the rewrite and return `(value, 2)` indicating 2 extra arguments.
-/
public partial def getMatchWithExtra (d : DiscrTree α) (e : Expr) : Array (α × Nat) :=
let (k, result) := getMatchCore d e
let result := getMatch d e
let result := result.map (·, 0)
if !e.isApp then
result
else if !mayMatchPrefix k then
else if !mayMatchPrefix (getKey e) then
result
else
go e.appFn! 1 result
@@ -225,7 +212,7 @@ where
| _ => false
go (e : Expr) (numExtra : Nat) (result : Array (α × Nat)) : Array (α × Nat) :=
let result := result ++ (getMatchCore d e).2.map (., numExtra)
let result := result ++ (getMatch d e).map (., numExtra)
if e.isApp then
go e.appFn! (numExtra + 1) result
else