Compare commits

...

2 Commits

Author SHA1 Message Date
Leonardo de Moura
832f5a3ef8 perf: precise cache for foldConsts
It addresses a performance issue at https://github.com/leanprover/LNSym/blob/proof_size_expt/Proofs/SHA512/Experiments/Sym20.lean
2024-07-30 11:18:09 -07:00
Leonardo de Moura
bdd4d42559 chore: remove unnecessary import 2024-07-30 11:16:34 -07:00
2 changed files with 23 additions and 41 deletions

View File

@@ -5,7 +5,6 @@ Authors: Leonardo de Moura
-/
prelude
import Lean.Expr
import Lean.Util.PtrSet
namespace Lean
namespace Expr

View File

@@ -11,52 +11,35 @@ namespace Lean
namespace Expr
namespace FoldConstsImpl
abbrev cacheSize : USize := 8192 - 1
unsafe structure State where
visited : PtrSet Expr := mkPtrSet
visitedConsts : NameHashSet := {}
structure State where
visitedTerms : Array Expr -- Remark: cache based on pointer address. Our "unsafe" implementation relies on the fact that `()` is not a valid Expr
visitedConsts : NameHashSet -- cache based on structural equality
unsafe abbrev FoldM := StateM State
abbrev FoldM := StateM State
unsafe def visited (e : Expr) (size : USize) : FoldM Bool := do
let s get
let h := ptrAddrUnsafe e
let i := h % size
let k := s.visitedTerms.uget i lcProof
if ptrAddrUnsafe k == h then pure true
else do
modify fun s => { s with visitedTerms := s.visitedTerms.uset i e lcProof }
pure false
unsafe def fold {α : Type} (f : Name α α) (size : USize) (e : Expr) (acc : α) : FoldM α :=
unsafe def fold {α : Type} (f : Name α α) (e : Expr) (acc : α) : FoldM α :=
let rec visit (e : Expr) (acc : α) : FoldM α := do
if ( visited e size) then
pure acc
else
match e with
| Expr.forallE _ d b _ => visit b ( visit d acc)
| Expr.lam _ d b _ => visit b ( visit d acc)
| Expr.mdata _ b => visit b acc
| Expr.letE _ t v b _ => visit b ( visit v ( visit t acc))
| Expr.app f a => visit a ( visit f acc)
| Expr.proj _ _ b => visit b acc
| Expr.const c _ =>
let s get
if s.visitedConsts.contains c then
pure acc
else do
modify fun s => { s with visitedConsts := s.visitedConsts.insert c };
pure $ f c acc
| _ => pure acc
if ( get).visited.contains e then
return acc
modify fun s => { s with visited := s.visited.insert e }
match e with
| .forallE _ d b _ => visit b ( visit d acc)
| .lam _ d b _ => visit b ( visit d acc)
| .mdata _ b => visit b acc
| .letE _ t v b _ => visit b ( visit v ( visit t acc))
| .app f a => visit a ( visit f acc)
| .proj _ _ b => visit b acc
| .const c _ =>
if ( get).visitedConsts.contains c then
return acc
else
modify fun s => { s with visitedConsts := s.visitedConsts.insert c };
return f c acc
| _ => return acc
visit e acc
unsafe def initCache : State :=
{ visitedTerms := mkArray cacheSize.toNat (cast lcProof ()),
visitedConsts := {} }
@[inline] unsafe def foldUnsafe {α : Type} (e : Expr) (init : α) (f : Name α α) : α :=
(fold f cacheSize e init).run' initCache
(fold f e init).run' {}
end FoldConstsImpl