Compare commits

...

7 Commits

Author SHA1 Message Date
Leonardo de Moura
ed15c06a2a chore: dead code 2026-01-19 20:46:41 -08:00
Leonardo de Moura
c1f14642d9 feat: offset terms 2026-01-19 20:44:56 -08:00
Leonardo de Moura
7827bf5a2d feat: add BackwardRule.apply? 2026-01-19 20:44:56 -08:00
Leonardo de Moura
9dd1e81eb9 refactor: add Sym/LitValues 2026-01-19 20:44:56 -08:00
Leonardo de Moura
594681c5d7 fix: missing withContext 2026-01-19 20:44:56 -08:00
Leonardo de Moura
9a80ea6bc9 chore: helper function 2026-01-19 20:44:56 -08:00
Leonardo de Moura
7045b69e26 chore: use simp only in benchmark 2026-01-19 20:44:56 -08:00
9 changed files with 297 additions and 107 deletions

View File

@@ -103,7 +103,19 @@ Applies a backward rule to a goal, returning new subgoals.
2. Assigns the goal metavariable to the theorem application
3. Returns new goals for unassigned arguments (per `resultPos`)
Throws an error if unification fails.
Returns `none` if unification fails.
-/
public def BackwardRule.apply? (mvarId : MVarId) (rule : BackwardRule) : SymM (Option (List MVarId)) := mvarId.withContext do
let decl mvarId.getDecl
if let some result rule.pattern.unify? decl.type then
mvarId.assign (mkValue rule.expr rule.pattern result)
return some <| rule.resultPos.map fun i =>
result.args[i]!.mvarId!
else
return none
/--
Similar to `BackwardRule.apply?`, but throws an error if unification fails.
-/
public def BackwardRule.apply (mvarId : MVarId) (rule : BackwardRule) : SymM (List MVarId) := mvarId.withContext do
let decl mvarId.getDecl

View File

@@ -0,0 +1,76 @@
/-
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Expr
public import Init.Data.Rat
public section
namespace Lean.Meta.Sym
/-!
Pure functions for extracting values. They are pure (`OptionT Id`) rather than monadic (`MetaM`).
This is possible because `Sym` assumes terms are in canonical form, no `whnf` or
reduction is needed to recognize literals.
-/
def getNatValue? (e : Expr) : OptionT Id Nat := do
let_expr OfNat.ofNat _ n _ := e | failure
let .lit (.natVal n) := n | failure
return n
def getIntValue? (e : Expr) : OptionT Id Int := do
let_expr Neg.neg _ _ a := e | getNatValue? e
let v : Int getNatValue? a
return -v
def getRatValue? (e : Expr) : OptionT Id Rat := do
let_expr HDiv.hDiv _ _ _ _ n d := e | getIntValue? e
let n : Rat getIntValue? n
let d : Rat getNatValue? d
return n / d
structure BitVecValue where
n : Nat
val : BitVec n
def getBitVecValue? (e : Expr) : OptionT Id BitVecValue :=
match_expr e with
| BitVec.ofNat nExpr vExpr => do
let n getNatValue? nExpr
let v getNatValue? vExpr
return n, BitVec.ofNat n v
| BitVec.ofNatLT nExpr vExpr _ => do
let n getNatValue? nExpr
let v getNatValue? vExpr
return n, BitVec.ofNat n v
| OfNat.ofNat α v _ => do
let_expr BitVec n := α | failure
let n getNatValue? n
let .lit (.natVal v) := v | failure
return n, BitVec.ofNat n v
| _ => failure
def getUInt8Value? (e : Expr) : OptionT Id UInt8 := return UInt8.ofNat ( getNatValue? e)
def getUInt16Value? (e : Expr) : OptionT Id UInt16 := return UInt16.ofNat ( getNatValue? e)
def getUInt32Value? (e : Expr) : OptionT Id UInt32 := return UInt32.ofNat ( getNatValue? e)
def getUInt64Value? (e : Expr) : OptionT Id UInt64 := return UInt64.ofNat ( getNatValue? e)
def getInt8Value? (e : Expr) : OptionT Id Int8 := return Int8.ofInt ( getIntValue? e)
def getInt16Value? (e : Expr) : OptionT Id Int16 := return Int16.ofInt ( getIntValue? e)
def getInt32Value? (e : Expr) : OptionT Id Int32 := return Int32.ofInt ( getIntValue? e)
def getInt64Value? (e : Expr) : OptionT Id Int64 := return Int64.ofInt ( getIntValue? e)
structure FinValue where
n : Nat
val : Fin n
def getFinValue? (e : Expr) : OptionT Id FinValue := do
let_expr OfNat.ofNat α v _ := e | failure
let_expr Fin n := α | failure
let n getNatValue? n
let .lit (.natVal v) := v | failure
if h : n = 0 then failure else
let : NeZero n := h
return { n, val := Fin.ofNat n v }
end Lean.Meta.Sym

View File

@@ -0,0 +1,93 @@
/-
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Sym.LitValues
public section
namespace Lean.Meta.Sym
/-!
# Offset representation for natural number expressions
This module provides utilities for representing `Nat` expressions in the form `e + k`,
where `e` is an arbitrary expression and `k` is a constant.
This normalization is used during pattern matching and unification.
-/
/--
Represents a natural number expression as a base plus a constant offset.
- `.num k` represents the literal `k`
- `.add e k` represents `e + k`
Used for pattern matching and unification.
-/
inductive Offset where
| num (k : Nat)
| add (e : Expr) (k : Nat)
deriving Inhabited
/-- Increments the constant part of the offset by `k'`. -/
def Offset.inc : Offset Nat Offset
| .num k, k' => .num (k+k')
| .add e k, k' => .add e (k+k')
/--
Returns `some offset` if `e` is an offset term. That is, it is of the form
- `Nat.succ a`, OR
- `a + k` where `k` is a numeral.
Assumption: standard instances are used for `OfNat Nat n` and `HAdd Nat Nat Nat`
-/
partial def isOffset? (e : Expr) : OptionT Id Offset :=
match_expr e with
| Nat.succ a => do
return get a |>.inc 1
| HAdd.hAdd α _ _ _ a b => do
guard (α.isConstOf ``Nat)
let n getNatValue? b
return get a |>.inc n
| _ => failure
where
get (e : Expr) : Offset :=
isOffset? e |>.getD (.add e 0)
/-- Variant of `isOffset?` that first checks if `declName` is `Nat.succ` or `HAdd.hAdd`. -/
def isOffset?' (declName : Name) (p : Expr) : OptionT Id Offset := do
guard (declName == ``Nat.succ || declName == ``HAdd.hAdd)
isOffset? p
/-- Returns `true` if `e` is an offset term.-/
partial def isOffset (e : Expr) : Bool :=
match_expr e with
| Nat.succ _ => true
| HAdd.hAdd α _ _ _ _ b =>
α.isConstOf ``Nat &&
match_expr b with
| OfNat.ofNat _ n _ => (n matches .lit (.natVal _))
| _ => false
| _ => false
/-- Variant of `isOffset?` that first checks if `declName` is `Nat.succ` or `HAdd.hAdd`. -/
def isOffset' (declName : Name) (p : Expr) : Bool :=
(declName == ``Nat.succ || declName == ``HAdd.hAdd) && isOffset p
/--
Converts the given expression into an offset.
Assumptions:
- `e` has type `Nat`.
- standard instances are used for `OfNat Nat n` and `HAdd Nat Nat Nat`.
-/
partial def toOffset (e : Expr) : Offset :=
match_expr e with
| Nat.succ a => toOffset a |>.inc 1
| HAdd.hAdd _ _ _ _ a b => Id.run do
let some n := getNatValue? b | .add e 0
toOffset a |>.inc n
| OfNat.ofNat _ n _ => Id.run do
let .lit (.natVal n) := n | .add e 0
.num n
| _ => .add e 0
end Lean.Meta.Sym

View File

@@ -16,6 +16,8 @@ import Lean.Meta.Sym.IsClass
import Lean.Meta.Sym.MaxFVar
import Lean.Meta.Sym.ProofInstInfo
import Lean.Meta.Sym.AlphaShareBuilder
import Lean.Meta.Sym.LitValues
import Lean.Meta.Sym.Offset
namespace Lean.Meta.Sym
open Internal
@@ -347,13 +349,43 @@ where
let some value fvarId.getValue? | return false
process p value
processApp (p : Expr) (e : Expr) : UnifyM Bool := do
let f := p.getAppFn
let .const declName _ := f | processAppDefault p e
processOffset (p : Offset) (e : Offset) : UnifyM Bool := do
-- **Note** Recall that we don't assume patterns are maximally shared terms.
match p, e with
| .num _, .num _ => unreachable!
| .num k₁, .add e k₂ =>
if k₁ < k₂ then return false
process (mkNatLit (k₁ - k₂)) e
| .add p k₁, .num k₂ =>
if k₂ < k₁ then return false
process p ( share (mkNatLit (k₂ - k₁)))
| .add p k₁, .add e k₂ =>
if k₁ == k₂ then
process p e
else if k₁ < k₂ then
if k₁ == 0 then return false
process p ( share (mkNatAdd e (mkNatLit (k₂ - k₁))))
else
if k₂ == 0 then return false
process (mkNatAdd p (mkNatLit (k₁ - k₂))) e
processConstApp (declName : Name) (p : Expr) (e : Expr) : UnifyM Bool := do
let some info := ( read).pattern.fnInfos.find? declName | process.processAppDefault p e
let numArgs := p.getAppNumArgs
processAppWithInfo p e (numArgs - 1) info
processApp (p : Expr) (e : Expr) : UnifyM Bool := withIncRecDepth do
let f := p.getAppFn
let .const declName _ := f | processAppDefault p e
if ( processConstApp declName p e) then
return true
else if let some p' := isOffset?' declName p then
processOffset p' (toOffset e)
else if let some e' := isOffset? e then
processOffset (toOffset p) e'
else
return false
processAppWithInfo (p : Expr) (e : Expr) (i : Nat) (info : ProofInstInfo) : UnifyM Bool := do
let .app fp ap := p | if e.isApp then return false else process p e
let .app fe ae := e | checkLetVar p e

View File

@@ -43,7 +43,7 @@ public def simpWith (k : Expr → SymM Result) (mvarId : MVarId) : MetaM (Option
else
return some mvarNew.mvarId!
public def simpGoal (declNames : Array Name) (mvarId : MVarId) : MetaM (Option MVarId) := SymM.run do
public def simpGoal (declNames : Array Name) (mvarId : MVarId) : MetaM (Option MVarId) := SymM.run do mvarId.withContext do
let methods mkMethods declNames
simpWith (simp · methods) mvarId

View File

@@ -7,6 +7,7 @@ module
prelude
public import Lean.Meta.Sym.Pattern
public import Lean.Meta.DiscrTree.Basic
import Lean.Meta.Sym.Offset
namespace Lean.Meta.Sym
open DiscrTree
@@ -77,7 +78,7 @@ def pushArgsUsingInfo (infos : Array ProofInstArgInfo) (i : Nat) (e : Expr) (tod
Computes the discrimination tree key for an expression and pushes its subterms onto the todo stack.
Returns `Key.star` for bound variables and `noindex`-annotated terms.
-/
def pushArgs (fnInfos : AssocList Name ProofInstInfo) (todo : Array Expr) (e : Expr) : Key × Array Expr :=
def pushArgs (root : Bool) (fnInfos : AssocList Name ProofInstInfo) (todo : Array Expr) (e : Expr) : Key × Array Expr :=
if hasNoindexAnnotation e then
(.star, todo)
else
@@ -87,12 +88,15 @@ def pushArgs (fnInfos : AssocList Name ProofInstInfo) (todo : Array Expr) (e : E
| .bvar _ => (.star, todo)
| .forallE _ d b _ => (.arrow, todo.push b |>.push d)
| .const declName _ =>
let numArgs := e.getAppNumArgs
let todo := if let some info := fnInfos.find? declName then
pushArgsUsingInfo info.argsInfo (numArgs - 1) e todo
if !root && isOffset' declName e then
(.star, todo)
else
pushAllArgs e todo
(.const declName numArgs, todo)
let numArgs := e.getAppNumArgs
let todo := if let some info := fnInfos.find? declName then
pushArgsUsingInfo info.argsInfo (numArgs - 1) e todo
else
pushAllArgs e todo
(.const declName numArgs, todo)
| .fvar fvarId =>
let numArgs := e.getAppNumArgs
let todo := pushAllArgs e todo
@@ -100,14 +104,14 @@ def pushArgs (fnInfos : AssocList Name ProofInstInfo) (todo : Array Expr) (e : E
| _ => (.other, todo)
/-- Work-list based traversal that builds the key sequence for a pattern. -/
partial def mkPathAux (fnInfos : AssocList Name ProofInstInfo) (todo : Array Expr) (keys : Array Key) : Array Key :=
partial def mkPathAux (root : Bool) (fnInfos : AssocList Name ProofInstInfo) (todo : Array Expr) (keys : Array Key) : Array Key :=
if todo.isEmpty then
keys
else
let e := todo.back!
let todo := todo.pop
let (k, todo) := pushArgs fnInfos todo e
mkPathAux fnInfos todo (keys.push k)
let (k, todo) := pushArgs root fnInfos todo e
mkPathAux false fnInfos todo (keys.push k)
def initCapacity := 8
@@ -115,7 +119,7 @@ def initCapacity := 8
public def Pattern.mkDiscrTreeKeys (p : Pattern) : Array Key :=
let todo : Array Expr := .mkEmpty initCapacity
let keys : Array Key := .mkEmpty initCapacity
mkPathAux p.fnInfos (todo.push p.pattern) keys
mkPathAux true p.fnInfos (todo.push p.pattern) keys
/-- Inserts a pattern into a discrimination tree, associating it with value `v`. -/
public def insertPattern [BEq α] (d : DiscrTree α) (p : Pattern) (v : α) : DiscrTree α :=

View File

@@ -8,6 +8,7 @@ prelude
public import Lean.Meta.Sym.Simp.SimpM
import Init.Sym.Lemmas
import Init.Data.Int.Gcd
import Lean.Meta.Sym.LitValues
namespace Lean.Meta.Sym.Simp
/-!
@@ -21,9 +22,7 @@ performance issues in the standard `Meta.Simp` simprocs.
### 1. Pure Value Extraction
The `getValue?` functions are pure (`OptionT Id`) rather than monadic (`MetaM`).
This is possible because `Sym` assumes terms are in canonical form, no `whnf` or
reduction is needed to recognize literals.
It uses the pure `getValue?` functions defined in `Lean.Meta.Sym.LitValues`.
### 2. Proof by Definitional Equality
@@ -69,65 +68,6 @@ def skipIfUnchanged (e : Expr) (result : Result) : Result :=
| .step e' _ _ => if isSameExpr e e' then .rfl else result
| _ => result
def getNatValue? (e : Expr) : OptionT Id Nat := do
let_expr OfNat.ofNat _ n _ := e | failure
let .lit (.natVal n) := n | failure
return n
def getIntValue? (e : Expr) : OptionT Id Int := do
let_expr Neg.neg _ _ a := e | getNatValue? e
let v : Int getNatValue? a
return -v
def getRatValue? (e : Expr) : OptionT Id Rat := do
let_expr HDiv.hDiv _ _ _ _ n d := e | getIntValue? e
let n : Rat getIntValue? n
let d : Rat getNatValue? d
return n / d
structure BitVecValue where
n : Nat
val : BitVec n
def getBitVecValue? (e : Expr) : OptionT Id BitVecValue :=
match_expr e with
| BitVec.ofNat nExpr vExpr => do
let n getNatValue? nExpr
let v getNatValue? vExpr
return n, BitVec.ofNat n v
| BitVec.ofNatLT nExpr vExpr _ => do
let n getNatValue? nExpr
let v getNatValue? vExpr
return n, BitVec.ofNat n v
| OfNat.ofNat α v _ => do
let_expr BitVec n := α | failure
let n getNatValue? n
let .lit (.natVal v) := v | failure
return n, BitVec.ofNat n v
| _ => failure
def getUInt8Value? (e : Expr) : OptionT Id UInt8 := return UInt8.ofNat ( getNatValue? e)
def getUInt16Value? (e : Expr) : OptionT Id UInt16 := return UInt16.ofNat ( getNatValue? e)
def getUInt32Value? (e : Expr) : OptionT Id UInt32 := return UInt32.ofNat ( getNatValue? e)
def getUInt64Value? (e : Expr) : OptionT Id UInt64 := return UInt64.ofNat ( getNatValue? e)
def getInt8Value? (e : Expr) : OptionT Id Int8 := return Int8.ofInt ( getIntValue? e)
def getInt16Value? (e : Expr) : OptionT Id Int16 := return Int16.ofInt ( getIntValue? e)
def getInt32Value? (e : Expr) : OptionT Id Int32 := return Int32.ofInt ( getIntValue? e)
def getInt64Value? (e : Expr) : OptionT Id Int64 := return Int64.ofInt ( getIntValue? e)
structure FinValue where
n : Nat
val : Fin n
def getFinValue? (e : Expr) : OptionT Id FinValue := do
let_expr OfNat.ofNat α v _ := e | failure
let_expr Fin n := α | failure
let n getNatValue? n
let .lit (.natVal v) := v | failure
if h : n = 0 then failure else
let : NeZero n := h
return { n, val := Fin.ofNat n v }
abbrev evalUnary [ToExpr α] (toValue? : Expr Option α) (op : α α) (a : Expr) : SimpM Result := do
let some a := toValue? a | return .rfl
let e share <| toExpr (op a)

View File

@@ -177,39 +177,14 @@ def Goal (n : Nat) : Prop :=
set_option maxHeartbeats 0
set_option maxRecDepth 100000
/-!
`MetaM` Solution
-/
open Lean Meta Elab Tactic
/-
A tactic for solving goal `Goal n`
-/
macro "solve" : tactic => `(tactic| {
unfold Goal;
intros m l;
dsimp [generated_cmd, repeated_cmds];
apply Exec.seq_cps;
apply Exec.input;
intros v;
repeat (
apply Exec.seq_cps;
apply Exec.set;
simp [Expr.eval];
simp [PartialMap.get_put_diff, PartialMap.get_put, PartialMap.put_put, Binop.interp_add,
Binop.interp_sub, Word.add_sub_cancel];
try rfl);
apply Exec.skip;
simp [PartialMap.get_put, PartialMap.put_put, PartialMap.get_put_diff]
})
/--
Solves a goal of the form `Goal n` using the `solve` tactic.
-/
def solveUsingMeta (n : Nat) (check := true) : MetaM Unit := do
let mvar mkFreshExprMVar (mkApp (mkConst ``Goal) (mkNatLit n))
/-- Helper function for executing a tactic `k` for solving `Goal n`. -/
def driver (n : Nat) (check := true) (k : MVarId MetaM Unit) : MetaM Unit := do
let some goal unfoldDefinition? (mkApp (mkConst ``Goal) (mkNatLit n)) | throwError "UNFOLD FAILED!"
let mvar mkFreshExprMVar goal
let startTime IO.monoNanosNow
let ([], _) runTactic mvar.mvarId! ( `(tactic| solve)).raw {} {} | throwError "FAILED!"
k mvar.mvarId!
let endTime IO.monoNanosNow
let ms := (endTime - startTime).toFloat / 1000000.0
if check then
@@ -221,6 +196,38 @@ def solveUsingMeta (n : Nat) (check := true) : MetaM Unit := do
else
IO.println s!"goal_{n}: {ms} ms"
/-!
`MetaM` Solution
-/
/-
A tactic for solving goal `Goal n`
-/
macro "solve" : tactic => `(tactic| {
intros m l;
simp only [generated_cmd, repeated_cmds];
apply Exec.seq_cps;
apply Exec.input;
intros v;
repeat (
apply Exec.seq_cps;
apply Exec.set;
simp only [Expr.eval];
simp only [PartialMap.get_put_diff, PartialMap.get_put, PartialMap.put_put, Binop.interp_add,
Binop.interp_sub, Word.add_sub_cancel, Option.some.injEq, not_false_eq_true, String.reduceEq, ne_eq];
try rfl);
apply Exec.skip;
simp only [List.cons.injEq, IOEvent.IN.injEq, and_true, PartialMap.put_put, PartialMap.get_put,
Option.some.injEq, and_self, exists_eq']
})
/--
Solves a goal of the form `Goal n` using the `solve` tactic.
-/
def solveUsingMeta (n : Nat) (check := true) : MetaM Unit := do
driver n check fun mvarId => do
let ([], _) runTactic mvarId ( `(tactic| solve)).raw {} {} | throwError "FAILED!"
def runBenchUsingMeta : MetaM Unit := do
IO.println "=== Symbolic Simulation Tests ==="
IO.println ""

View File

@@ -169,3 +169,29 @@ example (f g : Nat → Nat → Nat) (h : a = b) : (bif a + 0 != b then id f else
trace_state -- `cond` branches should not have been simplified
subst h
sym_simp [Nat.add_zero, bne_self_eq_false, id_eq]
def pw (n : Nat) : Nat :=
match n with
| 0 => 1
| n+1 => 2 * pw n
example : pw 0 = 1 := by
sym_simp [pw.eq_1]
example : pw 2 = 4 := by
sym_simp [pw.eq_1, pw.eq_2]
example : pw 4 = 16 := by
sym_simp [pw.eq_1, pw.eq_2]
example : pw (a + 2) = 2 * (2 * pw a) := by
sym_simp [pw.eq_2]
example : pw (Nat.succ a) = 2 * pw a := by
sym_simp [pw.eq_2]
example : pw (a + 3) = 2 * (2 * (2 * pw a)) := by
sym_simp [pw.eq_2]
example : pw (Nat.succ (Nat.succ a)) = 2 * (2 * pw a) := by
sym_simp [pw.eq_2]