Compare commits

...

2 Commits

Author SHA1 Message Date
Leonardo de Moura
21dcd1048d feat: simplify lambdas in Sym.simp
This PR adds support for simplifying lambda expressions in `Sym.simp`.
It is much more efficient than standard simp for very large lambda
expressions with many binders. The key idea is to generate a custom
function extensionality theorem for the type of the lambda being simplified.
2026-01-04 16:50:17 -08:00
Leonardo de Moura
e113fff3bf feat: add mkFunextFor 2026-01-04 16:50:17 -08:00
5 changed files with 203 additions and 3 deletions

View File

@@ -0,0 +1,54 @@
/-
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Basic
import Lean.Meta.InferType
namespace Lean.Meta.Sym.Simp
/--
Given `xs` containing free variables
`(x₁ : α₁) (x₂ : α₂[x₁]) ... (xₙ : αₙ[x₁, ..., x_{n-1}])`
and `β` a type of the form `β[x₁, ..., xₙ]`,
creates the custom function extensionality theorem
```
∀ (f g : (x₁ : α₁) → (x₂ : α₂[x₁]) → ... → (xₙ : αₙ[x₁, ..., x_{n-1}]) → β[x₁, ..., xₙ])
(h : ∀ x₁ ... xₙ, f x₁ ... xₙ = g x₁ ... xₙ),
f = g
```
The theorem has three arguments `f`, `g`, and `h`.
This auxiliary theorem is used by the simplifier when visiting lambda expressions.
-/
public def mkFunextFor (xs : Array Expr) (β : Expr) : MetaM Expr := do
let type mkForallFVars xs β
let v getLevel β
withLocalDeclD `f type fun f =>
withLocalDeclD `g type fun g => do
let lhs := mkAppN f xs
let rhs := mkAppN g xs
let p := mkApp3 (mkConst ``Eq [v]) β lhs rhs
let p mkForallFVars xs p
withLocalDeclD `h p fun h => do
let mut result := mkAppN h xs |>.abstract xs
let mut i := xs.size
let mut β := β.abstract xs
let mut v := v
let mut f := mkAppN f xs |>.abstract xs
let mut g := mkAppN g xs |>.abstract xs
while i > 0 do
i := i - 1
let x := xs[i]!
let α_i inferType x
let u_i getLevel α_i
let α_i := α_i.abstractRange i xs
f := f.appFn!.lowerLooseBVars 1 1
g := g.appFn!.lowerLooseBVars 1 1
result := mkLambda `x default α_i result
result := mkApp5 (mkConst ``funext [u_i, v]) α_i (mkLambda `x .default α_i β) f g result
β := mkForall `x .default α_i β
v := mkLevelIMax' u_i v
mkLambdaFVars #[f, g, h] result
end Lean.Meta.Sym.Simp

View File

@@ -7,15 +7,35 @@ module
prelude
public import Lean.Meta.Sym.Simp.SimpM
import Lean.Meta.Tactic.Grind.AlphaShareBuilder
import Lean.Meta.Sym.InferType
import Lean.Meta.Sym.Simp.Result
import Lean.Meta.Sym.Simp.Simproc
import Lean.Meta.Sym.Simp.Congr
import Lean.Meta.Sym.Simp.Funext
namespace Lean.Meta.Sym.Simp
open Grind
def simpLambda (_ : Expr) : SimpM Result := do
-- **TODO**
return .rfl
def simpLambda (e : Expr) : SimpM Result := do
-- **TODO**: Add free variable reuse
lambdaTelescope e fun xs b => do
match ( simp b) with
| .rfl => return .rfl
| .step b' h =>
let h mkLambdaFVars xs h
-- **TODO**: Add `mkLambdaFVarsS`?
let e' shareCommonInc ( mkLambdaFVars xs b')
let funext getFunext xs b
return .step e' (mkApp3 funext e e' h)
where
getFunext (xs : Array Expr) (b : Expr) : SimpM Expr := do
let key inferType e
if let some h := ( get).funext.find? { expr := key } then
return h
else
let β inferType b
let h mkFunextFor xs β
modify fun s => { s with funext := s.funext.insert { expr := key } h }
return h
def simpForall (_ : Expr) : SimpM Result := do
-- **TODO**

View File

@@ -138,6 +138,8 @@ structure State where
binderStack : List (ExprPtr × FVarId) := []
/-- Number of steps performed so far. -/
numSteps := 0
/-- Cache for generated funext theorems -/
funext : PHashMap ExprPtr Expr := {}
/-- Monad for the structural simplifier, layered on top of `SymM`. -/
abbrev SimpM := ReaderT MethodsRef $ ReaderT Context StateRefT State SymM

View File

@@ -0,0 +1,60 @@
import Lean
open Lean Meta
opaque f : Nat Nat
namespace SimpBench
/-!
## `MetaM` Simplifier benchmarks
-/
def getProofSize (r : Simp.Result) : MetaM Nat := do
( r.getProof).numObjs
def mkSimpContext (config : Simp.Config := {}) : MetaM Simp.Context := do
let s : SimpTheorems := {}
let s s.addConst ``Nat.zero_add
let config := { config with implicitDefEqProofs := false }
Simp.mkContext config #[s] {}
def simp (e : Expr) : MetaM (Simp.Result × Float) := Sym.SymM.run' do
-- let e ← Grind.shareCommon e
let startTime IO.monoNanosNow
let (r, _) Meta.simp e ( mkSimpContext)
let endTime IO.monoNanosNow
-- logInfo e
-- logInfo r.expr
-- check (← r.getProof)
let timeMs := (endTime - startTime).toFloat / 1000000.0
return (r, timeMs)
def mkLambdaBench (n : Nat) : MetaM Expr := do
let zero := mkNatLit 0
let rec go (n : Nat) (xs : Array Expr) (e : Expr) : MetaM Expr := do
match n with
| 0 => mkLambdaFVars xs e
| n+1 =>
withLocalDeclD `x (mkConst ``Nat) fun x =>
go n (xs.push x) (mkNatAdd zero (mkNatAdd e x))
go n #[] zero
def benchLambda (n : Nat) : MetaM Unit := do
let e mkLambdaBench n
let (r, timeMs) simp e
let proofSize getProofSize r
IO.println s!"lambda_{n}: {timeMs}ms, proof_size={proofSize}"
set_option maxRecDepth 100000
/-! ## Run all benchmarks -/
def runAllBenchmarks : MetaM Unit := do
IO.println "=== Simplifier Stress Tests ==="
IO.println ""
IO.println ""
IO.println "--- Benchmark 1: Transitivity chain ---"
for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200] do
benchLambda n
#eval runAllBenchmarks
end SimpBench

View File

@@ -0,0 +1,64 @@
import Lean
open Lean Meta
opaque f : Nat Nat
namespace SimpBench
/-!
## `SymM` Simplifier benchmarks
-/
def getProofSize (r : Sym.Simp.Result) : MetaM Nat :=
match r with
| .rfl => return 0
| .step _ p => p.numObjs
def mkSimpMethods : MetaM Sym.Simp.Methods := do
let thms : Sym.Simp.Theorems := {}
let thm Sym.Simp.mkTheoremFromDecl ``Nat.zero_add
let thms := thms.insert thm
return { post := thms.rewrite }
def simp (e : Expr) : MetaM (Sym.Simp.Result × Float) := Sym.SymM.run' do
let e Grind.shareCommon e
let methods mkSimpMethods
let startTime IO.monoNanosNow
let r Sym.simp e methods { maxSteps := 100000000 }
let endTime IO.monoNanosNow
-- logInfo e
-- match r with
-- | .rfl => logInfo "rfl"
-- | .step e' h => logInfo e'; logInfo h; check h
let timeMs := (endTime - startTime).toFloat / 1000000.0
return (r, timeMs)
def mkLambdaBench (n : Nat) : MetaM Expr := do
let zero := mkNatLit 0
let rec go (n : Nat) (xs : Array Expr) (e : Expr) : MetaM Expr := do
match n with
| 0 => mkLambdaFVars xs e
| n+1 =>
withLocalDeclD `x (mkConst ``Nat) fun x =>
go n (xs.push x) (mkNatAdd zero (mkNatAdd e x))
go n #[] zero
def benchLambda (n : Nat) : MetaM Unit := do
let e mkLambdaBench n
let (r, timeMs) simp e
let proofSize getProofSize r
IO.println s!"lambda_{n}: {timeMs}ms, proof_size={proofSize}"
set_option maxRecDepth 100000
/-! ## Run all benchmarks -/
def runAllBenchmarks : MetaM Unit := do
IO.println "=== Simplifier Stress Tests ==="
IO.println ""
IO.println ""
IO.println "--- Benchmark 1: Lambda block ---"
for n in [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, 170, 180, 190, 200] do
benchLambda n
#eval runAllBenchmarks
end SimpBench