Compare commits

...

1 Commits

Author SHA1 Message Date
Leonardo de Moura
077045ef5d test: local rewrite with Sym.simp
This PR adds a new API for helping users write focused rewrites.
2026-01-24 17:24:59 -08:00
3 changed files with 41 additions and 7 deletions

View File

@@ -224,7 +224,7 @@ position. However, the type is only meaningful (non-`default`) when `Result` is
`.step`, since we only need types for constructing congruence proofs. This avoids
unnecessary type inference when no rewriting occurs.
-/
def simpFixedPrefix (e : Expr) (prefixSize : Nat) (suffixSize : Nat) : SimpM Result := do
public def simpFixedPrefix (e : Expr) (prefixSize : Nat) (suffixSize : Nat) : SimpM Result := do
let numArgs := e.getAppNumArgs
if numArgs prefixSize then
-- Nothing to be done
@@ -274,7 +274,7 @@ Uses `rewritable[i]` to determine whether argument `i` should be simplified.
For rewritable arguments, calls `simp` and uses `congrFun'`, `congrArg`, and `congr`; for fixed arguments,
uses `congrFun` to propagate changes from earlier arguments.
-/
def simpInterlaced (e : Expr) (rewritable : Array Bool) : SimpM Result := do
public def simpInterlaced (e : Expr) (rewritable : Array Bool) : SimpM Result := do
let numArgs := e.getAppNumArgs
if h : numArgs = 0 then
-- Nothing to be done

View File

@@ -22,4 +22,8 @@ public abbrev mkEqTransResult (e₁ : Expr) (e₂ : Expr) (h₁ : Expr) (r₂ :
| .rfl done => return .step e₂ h₁ done
| .step e₃ h₂ done => return .step e₃ ( mkEqTrans e₁ e₂ h₁ e₃ h₂) done
public def Result.markAsDone : Result Result
| .rfl _ => .rfl true
| .step e h _ => .step e h true
end Lean.Meta.Sym.Simp

View File

@@ -104,6 +104,29 @@ def isBind (goal : Goal) : MetaM Bool := do
let_expr Exec _ _ _ k _ := target | return false
return k.isAppOf ``Bind.bind
def simpExecState : Sym.Simp.Simproc := fun e =>
/-
**Remark**: This simproc demonstrates how to perform targeted simplification steps using `Sym.simp`.
We only want to simplify the third argument of an `Exec`-application. We accomplished that
by using this simproc as a pre-method, using `simpInterlaced` where the `rewritable` mask
instructs the function to rewrite only the third argument, and then mark the resulting term
as simplified.
-/
let_expr Exec _ _ _ _ _ := e | return .rfl
-- Simplifies only the state (the third argument)
return ( Simp.simpInterlaced e #[false, false, true, false, false]).markAsDone
theorem add_assoc_rev (a b c : Nat) : a + (b + c) = (a + b) + c := by simp +arith
def mkSimpExecStateMethods : MetaM Sym.Simp.Methods := do
-- **Note**: we don't have `simp +arith` in `Sym.simp` yet. This is just a cheap hack
-- allow `Sym.simp` to simplify terms such as `2 + (1 + s)`.
let thm Sym.Simp.mkTheoremFromDecl ``add_assoc_rev
return {
pre := simpExecState
post := Sym.Simp.evalGround.andThen thm.rewrite
}
partial def solve (mvarId : MVarId) : GrindM Unit := do
/-
Creates an `BackwardRule` for each theorem `T` we want to use `apply T`.
@@ -120,6 +143,7 @@ partial def solve (mvarId : MVarId) : GrindM Unit := do
-/
let preMethods mkSimpMethods #[``step.eq_1, ``loop.eq_1, ``loop.eq_2,
``Nat.add_zero, ``Nat.sub_zero, ``bind_pure_comp, ``map_bind, ``id_map', ``unit_map, ``bind_assoc]
let execStateMethods mkSimpExecStateMethods
-- ## Initialize
let goal mkGoal mvarId
let .goal _ goal goal.introN 1 | failure
@@ -127,9 +151,10 @@ partial def solve (mvarId : MVarId) : GrindM Unit := do
let goal goal.internalizeAll -- Internalize all hypotheses
-- ## Loop
-- We simulate the `repeat` block using a tail-recursive function `loop`
let rec loop (goal : Goal) : GrindM Goal := do
-- logInfo goal₀.mvarId
let .goals [goal] goal₀.apply execBindRule | return goal₀
let rec loop (goal : Goal) : GrindM Goal := do
let .goal goal goal.simpIgnoringNoProgress execStateMethods | failure
-- logInfo goal.mvarId
let .goals [goal] goal.apply execBindRule | return goal
let .goals [goal] goal.apply execGetRule | failure
let .goals [goal] goal.apply execBindRule | failure
let .goals [goal] goal.apply execSetRule | failure
@@ -155,5 +180,10 @@ def solveUsingGrind (n : Nat) (check := true) : MetaM Unit := do
driver n check fun mvarId => SymM.run <| GrindM.run (params := params) do
solve mvarId
-- **TODO**: the proof term grows quadratically because we are not simplifying the state
#eval solveUsingGrind 50
def runBenchUsingGrind : MetaM Unit := do
IO.println "=== Symbolic Simulation Tests ==="
IO.println ""
for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150] do
solveUsingGrind n
#eval runBenchUsingGrind