fix: eta-reduce expressions in sym discrimination tree lookup (#12920)

This PR adds eta reduction to the sym discrimination tree lookup
functions (`getMatch`, `getMatchWithExtra`, `getMatchLoop`). Without
this, expressions like `StateM Nat` that unfold to eta-expanded forms
`(fun α => StateT Nat Id α)` fail to match discrimination tree entries
for the eta-reduced form `(StateT Nat Id)`.

Also optimizes `etaReduce` with an early exit for non-lambda expressions
and removes a redundant `n == 0` check.
Includes a test verifying that `P (StateM Nat)` matches a disc tree
entry for `P (StateT Nat Id)`.

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Leonardo de Moura
2026-03-14 09:57:10 -07:00
committed by GitHub
parent c2d4079193
commit 7120d9aef5
3 changed files with 26 additions and 3 deletions

View File

@@ -30,12 +30,12 @@ then returns `f`. Otherwise, returns `e`.
Returns the original expression when not reducible to enable pointer equality checks.
-/
public def etaReduce (e : Expr) : Expr :=
go e 0
if e.isLambda then go e 0 else e
where
go (body : Expr) (n : Nat) : Expr :=
match body with
| .lam _ _ b _ => go b (n+1)
| _ => if n == 0 then e else etaReduceAux body n 0 e
| _ => etaReduceAux body n 0 e
/-- Returns `true` if `e` can be eta-reduced. Uses pointer equality for efficiency. -/
public def isEtaReducible (e : Expr) : Bool :=

View File

@@ -8,6 +8,7 @@ prelude
public import Lean.Meta.Sym.Pattern
public import Lean.Meta.DiscrTree.Basic
import Lean.Meta.Sym.Offset
import Lean.Meta.Sym.Eta
import Init.Omega
namespace Lean.Meta.Sym
open DiscrTree
@@ -154,7 +155,7 @@ partial def getMatchLoop (todo : Array Expr) (c : Trie α) (result : Array α) :
else if h : csize = 0 then
result
else
let e := todo.back!
let e := etaReduce todo.back!
let todo := todo.pop
let first := cs[0] /- Recall that `Key.star` is the minimal key -/
if csize = 1 then
@@ -184,6 +185,7 @@ public def getMatch (d : DiscrTree α) (e : Expr) : Array α :=
let result := match d.root.find? .star with
| none => .mkEmpty initCapacity
| some (.node vs _) => vs
let e := etaReduce e
match d.root.find? (getKey e) with
| none => result
| some c => getMatchLoop (pushArgsTodo #[] e) c result
@@ -196,6 +198,7 @@ This is useful for rewriting: if a pattern matches `f x` but `e` is `f x y z`, w
still apply the rewrite and return `(value, 2)` indicating 2 extra arguments.
-/
public partial def getMatchWithExtra (d : DiscrTree α) (e : Expr) : Array (α × Nat) :=
let e := etaReduce e
let result := getMatch d e
let result := result.map (·, 0)
if !e.isApp then

View File

@@ -0,0 +1,20 @@
import Lean.Meta.Sym
open Lean Meta Sym
opaque P : (Type Type) Prop
axiom stateT_P : P (StateT Nat Id)
/--
info: disc tree lookup match count: 1
-/
#guard_msgs in
#eval show MetaM Unit from do
-- `StateM Nat` unfolds to `fun α => StateT Nat Id α` (an eta redex)
let e Sym.unfoldReducible (mkApp (mkConst ``P) (mkApp (mkConst ``StateM [0]) (mkConst ``Nat)))
-- Verify the eta redex is present
assert! e.appArg!.isLambda
let pat mkPatternFromDecl ``stateT_P
let dt := Sym.insertPattern (Lean.Meta.DiscrTree.empty (α := Unit)) pat ()
let nMatches := (Sym.getMatch dt e).size
logInfo m!"disc tree lookup match count: {nMatches}"