Compare commits

...

1 Commits

Author SHA1 Message Date
Leonardo de Moura
0577032c1f refactor: simplify AlphaShareCommon.State and use mutable hashmap
This PR simplifies `AlphaShareCommon.State` by separating the
persistent and transient parts of the state.

The `map` field caches visited sub-expressions during a single
`shareCommonAlpha` call to handle DAGs efficiently, the input
expression may contain shared sub-expressions that are not yet
maximally shared. However, this cache does not need to persist between
different `shareCommonAlpha` calls.

**Changes:**
- Moved `map` from the persistent `AlphaShareCommon.State` to a private `State` used only within individual `shareCommonAlpha` calls.
- Replaced `PHashMap ExprPtr Expr` with (the more efficient) `Std.HashMap ExprPtr Expr` for `map`, since it is now local to each call and does not need persistence.
- The public `AlphaShareCommon.State` now only contains the `set` of alpha-equivalent expressions that should persist
2025-12-25 09:49:40 -08:00

View File

@@ -59,12 +59,17 @@ instance : BEq AlphaKey where
beq k₁ k₂ := private alphaEq k₁.expr k₂.expr
structure AlphaShareCommon.State where
map : PHashMap ExprPtr Expr := {}
set : PHashSet AlphaKey := {}
abbrev AlphaShareCommonM := StateM AlphaShareCommon.State
private def save (e : Expr) (r : Expr) : AlphaShareCommonM Expr := do
private structure State where
map : Std.HashMap ExprPtr Expr := {}
set : PHashSet AlphaKey := {}
private abbrev M := StateM State
private def save (e : Expr) (r : Expr) : M Expr := do
if let some r := ( get).set.find? { expr := r } then
let r := r.expr
modify fun { set, map } => {
@@ -79,35 +84,50 @@ private def save (e : Expr) (r : Expr) : AlphaShareCommonM Expr := do
}
return r
private abbrev visit (e : Expr) (k : AlphaShareCommonM Expr) : AlphaShareCommonM Expr := do
if let some r := ( get).map.find? { expr := e } then
private abbrev visit (e : Expr) (k : M Expr) : M Expr := do
/-
**Note**: The input may be a DAG, and we don't want to visit the same sub-expression
over and over again.
-/
if let some r := ( get).map[{ expr := e : ExprPtr }]? then
return r
else
/-
**Note**: The input may contain sub-expressions that have already been processed and are
already maximally shared.
-/
if let some r := ( get).set.find? { expr := e } then
return r.expr
else
save e ( k)
private def go (e : Expr) : M Expr := do
match e with
| .bvar .. | .mvar .. | .const .. | .fvar .. | .sort .. | .lit .. =>
if let some r := ( get).set.find? { expr := e } then
return r.expr
else
modify fun { set, map } => { set := set.insert { expr := e }, map }
return e
| .app f a =>
visit e (return mkApp ( go f) ( go a))
| .letE n t v b nd =>
visit e (return mkLet n t ( go v) ( go b) nd)
| .forallE n d b bi =>
visit e (return mkForall n bi ( go d) ( go b))
| .lam n d b bi =>
visit e (return mkLambda n bi ( go d) ( go b))
| .mdata d b =>
visit e (return mkMData d ( go b))
| .proj n i b =>
visit e (return mkProj n i ( go b))
/-- Similar to `shareCommon`, but handles alpha-equivalence. -/
def shareCommonAlpha (e : Expr) : AlphaShareCommonM Expr :=
go e
where
go (e : Expr) : AlphaShareCommonM Expr := do
match e with
| .bvar .. | .mvar .. | .const .. | .fvar .. | .sort .. | .lit .. =>
if let some r := ( get).set.find? { expr := e } then
return r.expr
else
modify fun { set, map } => { set := set.insert { expr := e }, map }
return e
| .app f a =>
visit e (return mkApp ( go f) ( go a))
| .letE n t v b nd =>
visit e (return mkLet n t ( go v) ( go b) nd)
| .forallE n d b bi =>
visit e (return mkForall n bi ( go d) ( go b))
| .lam n d b bi =>
visit e (return mkLambda n bi ( go d) ( go b))
| .mdata d b =>
visit e (return mkMData d ( go b))
| .proj n i b =>
visit e (return mkProj n i ( go b))
@[inline] def shareCommonAlpha (e : Expr) (s : AlphaShareCommon.State) : Expr × AlphaShareCommon.State :=
if let some r := s.set.find? { expr := e } then
(r.expr, s)
else
let (e, { set, .. }) := go e |>.run { map := {}, set := s.set }
(e, set)
end Lean.Meta.Grind