Compare commits

...

4 Commits

Author SHA1 Message Date
Leonardo de Moura
b3b07f1b13 test: 2026-01-11 18:14:27 -08:00
Leonardo de Moura
1e01ed29c9 fix: type check pattern vars 2026-01-11 18:11:00 -08:00
Leonardo de Moura
9cf9829ae9 fix: missing check 2026-01-11 16:25:29 -08:00
Leonardo de Moura
14cf40d6d0 feat: helper functions for debugging 2026-01-11 16:25:03 -08:00
6 changed files with 242 additions and 36 deletions

View File

@@ -42,6 +42,10 @@ framework (`Sym`). The design prioritizes performance by using a two-phase appro
- `instantiateRevS` ensures maximal sharing of result expressions
-/
/-- Helper function for checking whether types `α` and `β` are definitionally equal during unification/matching. -/
def isDefEqTypes (α β : Expr) : MetaM Bool := do
withReducible <| isDefEq α β
/--
Collects `ProofInstInfo` for all function symbols occurring in `pattern`.
@@ -56,11 +60,18 @@ def mkProofInstInfoMapFor (pattern : Expr) : MetaM (AssocList Name ProofInstInfo
return fnInfos
public structure Pattern where
levelParams : List Name
varTypes : Array Expr
isInstance : Array Bool
pattern : Expr
fnInfos : AssocList Name ProofInstInfo
levelParams : List Name
varTypes : Array Expr
isInstance : Array Bool
pattern : Expr
fnInfos : AssocList Name ProofInstInfo
/--
If `checkTypeMask? = some mask`, then we must check the type of pattern variable `i`
if `mask[i]` is true.
Moreover `mask.size == varTypes.size`.
See `mkCheckTypeMask`
-/
checkTypeMask? : Option (Array Bool)
deriving Inhabited
def uvarPrefix : Name := `_uvar
@@ -79,6 +90,65 @@ def preprocessPattern (declName : Name) : MetaM (List Name × Expr) := do
let type preprocessType type
return (levelParams, type)
/--
Creates a mask indicating which pattern variables require type checking during matching.
When matching a pattern against a target expression, we must ensure that pattern variable
assignments are type-correct. However, checking types for every variable is expensive.
This function identifies which variables actually need type checking.
**Key insight**: A pattern variable appearing as an argument to a function application
does not need its type checked separately, because the type information is already
encoded in the application structure, and we assume the input is type correct.
**Variables that need type checking**:
- Variables in function position: `f x` where `f` is a pattern variable
- Variables in binder domains or bodies: `∀ x : α, β` or `fun x : α => b`
- Variables appearing alone (not as part of any application)
**Variables that skip type checking**:
- Variables appearing only as arguments to applications: in `f x`, the variable `x`
does not need checking because the type of `f` constrains the type of `x`
**Examples**:
- `bv0_eq (x : BitVec 0) : x = 0`: pattern is just `x`, must check type to ensure `BitVec 0`
- `forall_true : (∀ _ : α, True) = True`: `α` appears in binder domain, must check
- `Nat.add_zero (x : Nat) : x + 0 = x`: `x` is argument to `HAdd.hAdd`, no check needed
**Note**: This analysis is conservative. It may mark some variables for checking even when
the type information is redundant (already determined by other constraints). This is
harmless—just extra work, not incorrect behavior.
Returns an array of booleans parallel to the pattern's `varTypes`, where `true` indicates
the variable's type must be checked against the matched subterm's type.
-/
def mkCheckTypeMask (pattern : Expr) (numPatternVars : Nat) : Array Bool :=
let mask := Array.replicate numPatternVars false
go pattern 0 false mask
where
go (e : Expr) (offset : Nat) (isArg : Bool) : Array Bool Array Bool :=
match e with
| .app f a => go f offset isArg go a offset true
| .letE .. => unreachable! -- We zeta-reduce at `preprocessType`
| .const .. | .fvar _ | .sort _ | .mvar _ | .lit _ => id
| .mdata _ b => go b offset isArg
| .proj .. => id -- Should not occur in patterns
| .forallE _ d b _
| .lam _ d b _ => go d offset false go b (offset+1) false
| .bvar idx => fun mask =>
if idx >= offset && !isArg then
let idx := idx - offset
mask.set! (mask.size - idx - 1) true
else
mask
def mkPatternCore (levelParams : List Name) (varTypes : Array Expr) (isInstance : Array Bool)
(pattern : Expr) : MetaM Pattern := do
let fnInfos mkProofInstInfoMapFor pattern
let checkTypeMask := mkCheckTypeMask pattern varTypes.size
let checkTypeMask? := if checkTypeMask.all (· == false) then none else some checkTypeMask
return { levelParams, varTypes, isInstance, pattern, fnInfos, checkTypeMask? }
/--
Creates a `Pattern` from the type of a theorem.
@@ -100,9 +170,7 @@ public def mkPatternFromDecl (declName : Name) (num? : Option Nat := none) : Met
if i < num then
if let .forallE _ d b _ := type then
return ( go (i+1) b (varTypes.push d) (isInstance.push (isClass? ( getEnv) d).isSome))
let pattern := type
let fnInfos mkProofInstInfoMapFor pattern
return { levelParams, varTypes, isInstance, pattern, fnInfos }
mkPatternCore levelParams varTypes isInstance type
go 0 type #[] #[]
/--
@@ -123,9 +191,8 @@ public def mkEqPatternFromDecl (declName : Name) : MetaM (Pattern × Expr) := do
return ( go b (varTypes.push d) (isInstance.push (isClass? ( getEnv) d).isSome))
else
let_expr Eq _ lhs rhs := type | throwError "resulting type for `{.ofConstName declName}` is not an equality"
let pattern := lhs
let fnInfos mkProofInstInfoMapFor pattern
return ({ levelParams, varTypes, isInstance, pattern, fnInfos }, rhs)
let pattern mkPatternCore levelParams varTypes isInstance lhs
return (pattern, rhs)
go type #[] #[]
structure UnifyM.Context where
@@ -139,6 +206,11 @@ structure UnifyM.State where
ePending : Array (Expr × Expr) := #[]
uPending : Array (Level × Level) := #[]
iPending : Array (Expr × Expr) := #[]
/--
Contains the index of the pattern variables that we must check whether its type
matches the type of the value assigned to it.
-/
tPending : Array Nat := #[]
us : List Level := []
args : Array Expr := #[]
@@ -153,6 +225,14 @@ def pushLevelPending (u : Level) (v : Level) : UnifyM Unit :=
def pushInstPending (p : Expr) (e : Expr) : UnifyM Unit :=
modify fun s => { s with iPending := s.iPending.push (p, e) }
/--
Mark pattern variable `i` for type checking. That is, at the end of phase 1
we must check whether the type of this pattern variable is compatible with the type of
the value assigned to it.
-/
def pushCheckTypePending (i : Nat) : UnifyM Unit :=
modify fun s => { s with tPending := s.tPending.push i }
def assignExprIfUnassigned (bidx : Nat) (e : Expr) : UnifyM Unit := do
let s get
let i := s.eAssignment.size - bidx - 1
@@ -169,6 +249,8 @@ def assignExpr (bidx : Nat) (e : Expr) : UnifyM Bool := do
return true
else
modify fun s => { s with eAssignment := s.eAssignment.set! i (some e) }
if ( read).pattern.checkTypeMask?.isSome then
pushCheckTypePending i
return true
def assignLevel (uidx : Nat) (u : Level) : UnifyM Bool := do
@@ -369,6 +451,11 @@ structure DefEqM.Context where
If `unify` is `false`, it contains which variables can be assigned.
-/
mvarsNew : Array MVarId := #[]
/--
If a metavariable is in this collection, when we perform the assignment `?m := v`,
we must check whether their types are compatible.
-/
mvarsToCheckType : Array MVarId := #[]
abbrev DefEqM := ReaderT DefEqM.Context SymM
@@ -481,6 +568,12 @@ def mayAssign (t s : Expr) : SymM Bool := do
let tMaxFVarDecl tMaxFVarId.getDecl
return tMaxFVarDecl.index sMaxFVarDecl.index
@[inline] def whenUndefDo (x : DefEqM LBool) (k : DefEqM Bool) : DefEqM Bool := do
match ( x) with
| .true => return true
| .false => return false
| .undef => k
/--
Attempts to solve a unification constraint `t =?= s` where `t` has the form `?m a₁ ... aₙ`
and satisfies the Miller pattern condition (all `aᵢ` are distinct, newly-introduced free variables).
@@ -495,17 +588,20 @@ The `tFn` parameter must equal `t.getAppFn` (enforced by the proof argument).
Remark: `t` may be of the form `?m`.
-/
def tryAssignMillerPattern (tFn : Expr) (t : Expr) (s : Expr) (_ : tFn = t.getAppFn) : DefEqM Bool := do
let .mvar mvarId := tFn | return false
if !( isAssignableMVar mvarId) then return false
if !( isMillerPatternArgs t) then return false
def tryAssignMillerPattern (tFn : Expr) (t : Expr) (s : Expr) (_ : tFn = t.getAppFn) : DefEqM LBool := do
let .mvar mvarId := tFn | return .undef
if !( isAssignableMVar mvarId) then return .undef
if !( isMillerPatternArgs t) then return .undef
let s if t.isApp then
mkLambdaFVarsS t.getAppArgs s
else
pure s
if !( mayAssign tFn s) then return false
if !( mayAssign tFn s) then return .undef
if ( read).mvarsToCheckType.contains mvarId then
unless ( Sym.isDefEqTypes ( mvarId.getDecl).type ( inferType s)) do
return .false
mvarId.assign s
return true
return .true
/--
Structural definitional equality for applications without `ProofInstInfo`.
@@ -531,6 +627,11 @@ where
if ( mvarId.isAssigned) then return false
if !( isAssignableMVar mvarId) then return false
if !( mayAssign t s) then return false
/-
**Note**: we don't need to check the type of `mvarId` here even if the variable is marked for
checking. This is the case because `tryAssignUnassigned` is invoked only from a context where `t` and `s` are the arguments
of function applications.
-/
mvarId.assign s
return true
@@ -619,11 +720,10 @@ def isDefEqMainImpl (t : Expr) (s : Expr) : DefEqM Bool := do
isDefEqMain ( instantiateMVarsS t) s
else if ( isAssignedMVar sFn) then
isDefEqMain t ( instantiateMVarsS s)
else if ( tryAssignMillerPattern tFn t s rfl) then
return true
else if ( tryAssignMillerPattern sFn s t rfl) then
return true
else if let .fvar fvarId₁ := t then
else
whenUndefDo (tryAssignMillerPattern tFn t s rfl) do
whenUndefDo (tryAssignMillerPattern sFn s t rfl) do
if let .fvar fvarId₁ := t then
unless ( read).zetaDelta do return false
let some val₁ fvarId₁.getValue? | return false
isDefEqMain val₁ s
@@ -634,17 +734,19 @@ def isDefEqMainImpl (t : Expr) (s : Expr) : DefEqM Bool := do
else
isDefEqApp tFn t s rfl
abbrev DefEqM.run (unify := true) (zetaDelta := true) (mvarsNew : Array MVarId := #[]) (x : DefEqM α) : SymM α := do
abbrev DefEqM.run (unify := true) (zetaDelta := true) (mvarsNew : Array MVarId := #[])
(mvarsToCheckType : Array MVarId := #[]) (x : DefEqM α) : SymM α := do
let lctx getLCtx
let lctxInitialNextIndex := lctx.decls.size
x { zetaDelta, lctxInitialNextIndex, unify, mvarsNew }
x { zetaDelta, lctxInitialNextIndex, unify, mvarsNew, mvarsToCheckType }
/--
A lightweight structural definitional equality for the symbolic simulation framework.
Unlike the full `isDefEq`, it avoids expensive operations while still supporting Miller pattern unification.
-/
public def isDefEqS (t : Expr) (s : Expr) (unify := true) (zetaDelta := true) (mvarsNew : Array MVarId := #[]) : SymM Bool := do
DefEqM.run (unify := unify) (zetaDelta := zetaDelta) (mvarsNew := mvarsNew) do
public def isDefEqS (t : Expr) (s : Expr) (unify := true) (zetaDelta := true)
(mvarsNew : Array MVarId := #[]) (mvarsToCheckType : Array MVarId := #[]): SymM Bool := do
DefEqM.run (unify := unify) (zetaDelta := zetaDelta) (mvarsNew := mvarsNew) (mvarsToCheckType := mvarsToCheckType) do
isDefEqMain t s
def noPending : UnifyM Bool := do
@@ -655,7 +757,11 @@ def instantiateLevelParamsS (e : Expr) (paramNames : List Name) (us : List Level
-- We do not assume `e` is maximally shared
shareCommon (e.instantiateLevelParams paramNames us)
def mkPreResult : UnifyM Unit := do
inductive MkPreResultResult where
| failed
| success (mvarsToCheckType : Array MVarId)
def mkPreResult : UnifyM MkPreResultResult := do
let us ( get).uAssignment.toList.mapM fun
| some val => pure val
| none => mkFreshLevelMVar
@@ -663,9 +769,20 @@ def mkPreResult : UnifyM Unit := do
let varTypes := pattern.varTypes
let isInstance := pattern.isInstance
let eAssignment := ( get).eAssignment
let tPending := ( get).tPending
let mut args := #[]
let mut mvarsToCheckType := #[]
for h : i in *...eAssignment.size do
if let .some val := eAssignment[i] then
if tPending.contains i then
let type := varTypes[i]!
let type instantiateLevelParamsS type pattern.levelParams us
let type instantiateRevBetaS type args
let valType inferType val
-- **Note**: we have to use the default `isDefEq` because the type of `val`
-- is not necessarily normalized.
unless ( isDefEqTypes type valType) do
return .failed
args := args.push val
else
let type := varTypes[i]!
@@ -677,8 +794,12 @@ def mkPreResult : UnifyM Unit := do
continue
let mvar mkFreshExprMVar type
let mvar shareCommon mvar
if let some mask := ( read).pattern.checkTypeMask? then
if mask[i]! then
mvarsToCheckType := mvarsToCheckType.push mvar.mvarId!
args := args.push mvar
modify fun s => { s with args, us }
return .success mvarsToCheckType
def processPendingLevel : UnifyM Bool := do
let uPending := ( get).uPending
@@ -704,7 +825,7 @@ def processPendingInst : UnifyM Bool := do
return false
return true
def processPendingExpr : UnifyM Bool := do
def processPendingExpr (mvarsToCheckType : Array MVarId) : UnifyM Bool := do
let ePending := ( get).ePending
if ePending.isEmpty then return true
let pattern := ( read).pattern
@@ -715,7 +836,7 @@ def processPendingExpr : UnifyM Bool := do
let mvarsNew := if unify then #[] else args.filterMap fun
| .mvar mvarId => some mvarId
| _ => none
DefEqM.run unify zetaDelta mvarsNew do
DefEqM.run unify zetaDelta mvarsNew mvarsToCheckType do
for (t, s) in ePending do
let t instantiateLevelParamsS t pattern.levelParams us
let t instantiateRevBetaS t args
@@ -723,11 +844,11 @@ def processPendingExpr : UnifyM Bool := do
return false
return true
def processPending : UnifyM Bool := do
def processPending (mvarsToCheckType : Array MVarId) : UnifyM Bool := do
if ( noPending) then
return true
else
processPendingLevel <&&> processPendingInst <&&> processPendingExpr
processPendingLevel <&&> processPendingInst <&&> processPendingExpr mvarsToCheckType
abbrev UnifyM.run (pattern : Pattern) (unify : Bool) (zetaDelta : Bool) (k : UnifyM α) : SymM α := do
let eAssignment := pattern.varTypes.map fun _ => none
@@ -745,9 +866,11 @@ def mkResult : UnifyM MatchUnifyResult := do
def main (p : Pattern) (e : Expr) (unify : Bool) (zetaDelta : Bool) : SymM (Option (MatchUnifyResult)) :=
UnifyM.run p unify zetaDelta do
unless ( process p.pattern e) do return none
mkPreResult
unless ( processPending) do return none
return some ( mkResult)
match ( mkPreResult) with
| .failed => return none
| .success mvarsToCheckType =>
unless ( processPending mvarsToCheckType) do return none
return some ( mkResult)
/--
Attempts to match expression `e` against pattern `p` using purely syntactic matching.

View File

@@ -17,3 +17,4 @@ public import Lean.Meta.Sym.Simp.Theorems
public import Lean.Meta.Sym.Simp.Have
public import Lean.Meta.Sym.Simp.Lambda
public import Lean.Meta.Sym.Simp.Forall
public import Lean.Meta.Sym.Simp.Debug

View File

@@ -0,0 +1,46 @@
/-
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Sym.Simp.SimpM
import Lean.Meta.Sym.Simp.Theorems
import Lean.Meta.Sym.Simp.Rewrite
import Lean.Meta.Sym.Util
import Lean.Meta.Tactic.Util
import Lean.Meta.AppBuilder
namespace Lean.Meta.Sym
open Simp
/-!
Helper functions for debugging purposes and creating tests.
-/
public def mkMethods (declNames : Array Name) : MetaM Methods := do
let mut thms : Theorems := {}
for declName in declNames do
thms := thms.insert ( mkTheoremFromDecl declName)
return { post := thms.rewrite }
public def simpWith (k : Expr SymM Result) (mvarId : MVarId) : MetaM (Option MVarId) := SymM.run do
let mvarId preprocessMVar mvarId
let decl mvarId.getDecl
let target := decl.type
match ( k target) with
| .rfl _ => throwError "`Sym.simp` made no progress "
| .step target' h _ =>
let mvarNew mkFreshExprSyntheticOpaqueMVar target' decl.userName
let h mkAppM ``Eq.mpr #[h, mvarNew]
mvarId.assign h
if target'.isTrue then
mvarNew.mvarId!.assign (mkConst ``True.intro)
return none
else
return some mvarNew.mvarId!
public def simpGoal (declNames : Array Name) (mvarId : MVarId) : MetaM (Option MVarId) := SymM.run do
let methods mkMethods declNames
simpWith (simp · methods) mvarId
end Lean.Meta.Sym

View File

@@ -33,7 +33,10 @@ public def Theorem.rewrite (thm : Theorem) (e : Expr) : SimpM Result := do
let rhs := thm.rhs.instantiateLevelParams thm.pattern.levelParams result.us
let rhs shareCommonInc rhs
let expr instantiateRevBetaS rhs result.args
return .step expr proof
if isSameExpr e expr then
return .rfl
else
return .step expr proof
else
return .rfl

View File

@@ -101,7 +101,7 @@ invalidating the cache and causing O(2^n) behavior on conditional trees.
/-- Configuration options for the structural simplifier. -/
structure Config where
/-- Maximum number of steps that can be performed by the simplifier. -/
maxSteps : Nat := 0
maxSteps : Nat := 1000
-- **TODO**: many are still missing
/--

View File

@@ -0,0 +1,33 @@
import Lean
open Lean Meta Elab Tactic
theorem bv0_eq (x : BitVec 0) : x = 0 := BitVec.of_length_zero
set_option warn.sorry false
elab "sym_simp" "[" declNames:ident,* "]" : tactic => do
let declNames declNames.getElems.mapM resolveGlobalConstNoOverload
liftMetaTactic1 <| Sym.simpGoal declNames
theorem heq_self : (x x) = True := by simp
theorem forall_true {α : Sort u} : ( _ : α, True) = True := by simp
example : x + 0 x := by
fail_if_success sym_simp []
sym_simp [Nat.add_zero, heq_self]
example : 0 + x + 0 = x := by
sym_simp [Nat.add_zero, Nat.zero_add, eq_self]
example : x = x := by
sym_simp [bv0_eq, eq_self]
example (x y : BitVec 0) : x = y := by
sym_simp [bv0_eq, eq_self]
example : x, 0 + x + 0 = x := by
sym_simp [Nat.add_zero, Nat.zero_add, eq_self]
sym_simp [forall_true]
example : x, 0 + x + 0 = x := by
sym_simp [Nat.add_zero, Nat.zero_add, eq_self, forall_true]