Compare commits

...

1 Commits

Author SHA1 Message Date
Leonardo de Moura
4c1e0c354c feat: add extensible state mechanism for SymM
This PR add `SymExtension`, a typed extensible state mechanism for `SymM`,
following the same pattern as `Grind.SolverExtension`. Extensions are
registered at initialization time via `registerSymExtension` and provide
typed `getState`/`modifyState` accessors. Extension state persists across
`simp` invocations within a `sym =>` block and is re-initialized on each
`SymM.run`.

This enables modules (e.g., the upcoming arithmetic normalizer) to
register persistent state without modifying `Sym.State` directly.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-23 20:44:25 -07:00
6 changed files with 124 additions and 1 deletions

View File

@@ -15,6 +15,48 @@ register_builtin_option sym.debug : Bool := {
descr := "check invariants"
}
/-!
## Sym Extensions
Extensible state mechanism for `SymM`, allowing modules to register persistent state
that lives across `simp` invocations within a `sym =>` block. Follows the same pattern
as `Grind.SolverExtension` in `Lean/Meta/Tactic/Grind/Types.lean`.
-/
/-- Opaque extension state type used to store type-erased extension values. -/
opaque SymExtensionStateSpec : (α : Type) × Inhabited α := Unit, ()
@[expose] def SymExtensionState : Type := SymExtensionStateSpec.fst
instance : Inhabited SymExtensionState := SymExtensionStateSpec.snd
/--
A registered extension for `SymM`. Each extension gets a unique index into the
extensions array in `Sym.State`. Can only be created via `registerSymExtension`.
-/
structure SymExtension (σ : Type) where private mk ::
id : Nat
mkInitial : IO σ
deriving Inhabited
private builtin_initialize symExtensionsRef : IO.Ref (Array (SymExtension SymExtensionState)) IO.mkRef #[]
/--
Registers a new `SymM` state extension. Extensions can only be registered during initialization.
Returns a handle for typed access to the extension's state.
-/
def registerSymExtension {σ : Type} (mkInitial : IO σ) : IO (SymExtension σ) := do
unless ( initializing) do
throw (IO.userError "failed to register `Sym` extension, extensions can only be registered during initialization")
let exts symExtensionsRef.get
let id := exts.size
let ext : SymExtension σ := { id, mkInitial }
symExtensionsRef.modify fun exts => exts.push (unsafe unsafeCast ext)
return ext
/-- Returns initial state for all registered extensions. -/
def SymExtensions.mkInitialStates : IO (Array SymExtensionState) := do
let exts symExtensionsRef.get
exts.mapM fun ext => ext.mkInitial
/--
Information about a single argument position in a function's type signature.
@@ -133,6 +175,8 @@ structure State where
congrInfo : PHashMap ExprPtr CongrInfo := {}
/-- Cache for `isDefEqI` results -/
defEqI : PHashMap (ExprPtr × ExprPtr) Bool := {}
/-- State for registered `SymExtension`s, indexed by extension id. -/
extensions : Array SymExtensionState := #[]
debug : Bool := false
abbrev SymM := ReaderT Context <| StateRefT State MetaM
@@ -150,7 +194,8 @@ private def mkSharedExprs : AlphaShareCommonM SharedExprs := do
def SymM.run (x : SymM α) : MetaM α := do
let (sharedExprs, share) := mkSharedExprs |>.run {}
let debug := sym.debug.get ( getOptions)
x { sharedExprs } |>.run' { debug, share }
let extensions SymExtensions.mkInitialStates
x { sharedExprs } |>.run' { debug, share, extensions }
/-- Returns maximally shared commonly used terms -/
def getSharedExprs : SymM SharedExprs :=
@@ -230,4 +275,26 @@ def isDefEqI (s t : Expr) : SymM Bool := do
modify fun s => { s with defEqI := s.defEqI.insert key result }
return result
instance : Inhabited (SymM α) where
default := throwError "<SymM default value>"
/-! ### SymExtension accessors -/
private unsafe def SymExtension.getStateCoreImpl (ext : SymExtension σ) (extensions : Array SymExtensionState) : IO σ :=
return unsafeCast extensions[ext.id]!
@[implemented_by SymExtension.getStateCoreImpl]
opaque SymExtension.getStateCore (ext : SymExtension σ) (extensions : Array SymExtensionState) : IO σ
def SymExtension.getState (ext : SymExtension σ) : SymM σ := do
ext.getStateCore ( get).extensions
private unsafe def SymExtension.modifyStateImpl (ext : SymExtension σ) (f : σ σ) : SymM Unit := do
modify fun s => { s with
extensions := s.extensions.modify ext.id fun state => unsafeCast (f (unsafeCast state))
}
@[implemented_by SymExtension.modifyStateImpl]
opaque SymExtension.modifyState (ext : SymExtension σ) (f : σ σ) : SymM Unit
end Lean.Meta.Sym

View File

@@ -0,0 +1,28 @@
/-
Tests for `SymExtension` mechanism.
-/
module
public meta import SymExt.Decl
meta import Lean.Meta.Sym
open Lean Elab Tactic Meta Sym
run_meta SymM.run do
let c0 getMyCounter
assert! c0 == 0
incrementMyCounter
let c1 getMyCounter
assert! c1 == 1
incrementMyCounter
let c2 getMyCounter
assert! c2 == 2
run_meta do
SymM.run do
incrementMyCounter
let c getMyCounter
assert! c == 1
-- Should reset the counter
SymM.run do
let c getMyCounter
assert! c == 0

View File

@@ -0,0 +1,20 @@
/-
Declare a SymExtension and register it.
-/
module
public import Lean.Meta.Sym.SymM
open Lean Meta Sym
/-- Simple counter state for testing. -/
public structure MyExtState where
counter : Nat := 0
deriving Inhabited
initialize myExt : SymExtension MyExtState
registerSymExtension (return {})
public def getMyCounter : SymM Nat := do
return ( myExt.getState).counter
public def incrementMyCounter : SymM Unit := do
myExt.modifyState fun s => { s with counter := s.counter + 1 }

View File

@@ -0,0 +1,5 @@
import Lake
open System Lake DSL
package sym_ext
@[default_target] lean_lib SymExt

View File

@@ -0,0 +1 @@
../../../build/release/stage1

View File

@@ -0,0 +1,2 @@
rm -rf .lake/build
lake build