Compare commits

...

1 Commits

Author SHA1 Message Date
Leonardo de Moura
29aa90c790 feat: eta-reduction support in SymM
This PR adds support for eta-reduction in `SymM`.
2026-01-26 13:21:22 -08:00
5 changed files with 129 additions and 3 deletions

View File

@@ -23,6 +23,7 @@ public import Lean.Meta.Sym.Apply
public import Lean.Meta.Sym.InferType
public import Lean.Meta.Sym.Simp
public import Lean.Meta.Sym.Util
public import Lean.Meta.Sym.Eta
public import Lean.Meta.Sym.Grind
/-!

View File

@@ -0,0 +1,53 @@
/-
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Sym.ExprPtr
public import Lean.Meta.Basic
import Lean.Meta.Transform
namespace Lean.Meta.Sym
/--
Checks if `body` is eta-expanded with `n` applications: `f (.bvar (n-1)) ... (.bvar 0)`.
Returns `f` if so and `f` has no loose bvars; otherwise returns `default`.
- `n`: number of remaining applications to check
- `i`: expected bvar index (starts at 0, increments with each application)
- `default`: returned when not eta-reducible (enables pointer equality check)
-/
def etaReduceAux (body : Expr) (n : Nat) (i : Nat) (default : Expr) : Expr := Id.run do
match n with
| 0 => if body.hasLooseBVars then default else body
| n+1 =>
let .app f (.bvar j) := body | default
if j == i then etaReduceAux f n (i+1) default else default
/--
If `e` is of the form `(fun x₁ ... xₙ => f x₁ ... xₙ)` and `f` does not contain `x₁`, ..., `xₙ`,
then returns `f`. Otherwise, returns `e`.
Returns the original expression when not reducible to enable pointer equality checks.
-/
public def etaReduce (e : Expr) : Expr :=
go e 0
where
go (body : Expr) (n : Nat) : Expr :=
match body with
| .lam _ _ b _ => go b (n+1)
| _ => if n == 0 then e else etaReduceAux body n 0 e
/-- Returns `true` if `e` can be eta-reduced. Uses pointer equality for efficiency. -/
public def isEtaReducible (e : Expr) : Bool :=
!isSameExpr e (etaReduce e)
/-- Applies `etaReduce` to all subexpressions. Returns `e` unchanged if no subexpression is eta-reducible. -/
public def etaReduceAll (e : Expr) : MetaM Expr := do
unless Option.isSome <| e.find? isEtaReducible do return e
let pre (e : Expr) : MetaM TransformStep := do
let e' := etaReduce e
if isSameExpr e e' then return .continue
else return .visit e'
Meta.transform e (pre := pre)
end Lean.Meta.Sym

View File

@@ -18,6 +18,7 @@ import Lean.Meta.Sym.ProofInstInfo
import Lean.Meta.Sym.AlphaShareBuilder
import Lean.Meta.Sym.LitValues
import Lean.Meta.Sym.Offset
import Lean.Meta.Sym.Eta
namespace Lean.Meta.Sym
open Internal
@@ -323,7 +324,11 @@ def isAssignedMVar (e : Expr) : MetaM Bool :=
| _ => return false
partial def process (p : Expr) (e : Expr) : UnifyM Bool := do
match p with
let e' := etaReduce e
if !isSameExpr e e' then
-- **Note**: We eagerly eta reduce patterns
process p e'
else match p with
| .bvar bidx => assignExpr bidx e
| .mdata _ p => process p e
| .const declName us =>
@@ -723,7 +728,12 @@ def isDefEqApp (tFn : Expr) (t : Expr) (s : Expr) (_ : tFn = t.getAppFn) : DefEq
@[export lean_sym_def_eq]
def isDefEqMainImpl (t : Expr) (s : Expr) : DefEqM Bool := do
if isSameExpr t s then return true
match t, s with
-- **Note**: `etaReduce` is supposed to be fast, and does not allocate memory
let t' := etaReduce t
let s' := etaReduce s
if !isSameExpr t t' || !isSameExpr s s' then
isDefEqMain t' s'
else match t, s with
| .lit l₁, .lit l₂ => return l₁ == l₂
| .sort u, .sort v => isLevelDefEqS u v
| .lam .., .lam .. => isDefEqBindingS t s

View File

@@ -9,6 +9,7 @@ public import Lean.Meta.Sym.SymM
import Lean.Meta.Sym.IsClass
import Lean.Meta.Sym.Util
import Lean.Meta.Transform
import Lean.Meta.Sym.Eta
namespace Lean.Meta.Sym
/--
@@ -17,7 +18,8 @@ Preprocesses types that used for pattern matching and unification.
public def preprocessType (type : Expr) : MetaM Expr := do
let type Sym.unfoldReducible type
let type Core.betaReduce type
zetaReduce type
let type zetaReduce type
etaReduceAll type
/--
Analyzes whether the given free variables (aka arguments) are proofs or instances.

View File

@@ -0,0 +1,60 @@
import Std.Data.HashMap
import Lean.Meta.Sym
import Lean.Meta.DiscrTree.Basic
open Lean Meta Sym Grind
set_option sym.debug true
abbrev S := Nat
abbrev M α := StateM S α
def Exec (s : S) (k : M α) (post : α S Prop) : Prop :=
post (k s).1 (k s).2
theorem Exec.bind (k₁ : M α) (k₂ : α M β) (post : β S Prop) :
Exec s k₁ (fun a s₁ => Exec s₁ (k₂ a) post)
Exec s (k₁ >>= k₂) post := by
simp [Exec, Bind.bind, StateT.bind]
cases k₁ s; simp
def goal := a b, Exec b (set a >>= fun _ => get) fun v _ => v = a
set_option pp.explicit true
/-!
Recall that `SymM` patterns are eagerly eta-reduced.
Goals are not, but the pattern matcher/unifier performs eta whenever it is needed.
-/
/--
info: Pattern:
@Exec #5 #4 (@bind (StateT Nat Id) (@Monad.toBind (StateT Nat Id) (@StateT.instMonad Nat Id Id.instMonad)) #6 #5 #3 #2)
#1
---
info: a b : Nat
⊢ @Exec Nat b
(@bind (fun α => StateT Nat Id α) (@Monad.toBind (fun α => StateT Nat Id α) (@StateT.instMonad Nat Id Id.instMonad))
PUnit Nat (@set Nat (fun α => StateT Nat Id α) (@instMonadStateOfStateTOfMonad Nat Id Id.instMonad) a) fun x =>
@get Nat (fun α => StateT Nat Id α)
(@instMonadStateOfMonadStateOf Nat (fun α => StateT Nat Id α)
(@instMonadStateOfStateTOfMonad Nat Id Id.instMonad)))
fun v x => @Eq Nat v a
---
info: a b : Nat
⊢ @Exec PUnit b (@set Nat (fun α => StateT Nat Id α) (@instMonadStateOfStateTOfMonad Nat Id Id.instMonad) a)
fun a_1 s₁ =>
@Exec Nat s₁
(@get Nat (fun α => StateT Nat Id α)
(@instMonadStateOfMonadStateOf Nat (fun α => StateT Nat Id α)
(@instMonadStateOfStateTOfMonad Nat Id Id.instMonad)))
fun v x => @Eq Nat v a
-/
#guard_msgs in
run_meta SymM.run do
let bindRule mkBackwardRuleFromDecl ``Exec.bind
let a unfoldDefinition (mkConst ``goal)
logInfo m!"Pattern:\n{bindRule.pattern.pattern}"
forallTelescope a fun _ body => do
let mvar mkFreshExprMVar body
let mvarId preprocessMVar mvar.mvarId!
logInfo mvarId
let .goals [mvarId] bindRule.apply mvarId | failure
logInfo mvarId