Compare commits

...

1 Commits

Author SHA1 Message Date
Leonardo de Moura
49a274c0f6 feat: skip proof and instance arguments during pattern matching
This PR optimizes pattern matching by skipping proof and instance
arguments during Phase 1 (syntactic matching).
2025-12-27 21:16:07 -08:00

View File

@@ -6,7 +6,9 @@ Authors: Leonardo de Moura
module
prelude
public import Lean.Meta.Sym.SymM
import Lean.Util.FoldConsts
import Lean.Meta.Sym.InstantiateS
import Lean.Meta.Sym.IsClass
namespace Lean.Meta.Sym
open Grind
@@ -32,10 +34,75 @@ framework (`Sym`). The design prioritizes performance by using a two-phase appro
- `instantiateRevS` ensures maximal sharing of result expressions
-/
def preprocessType (type : Expr) : MetaM Expr := do
let type unfoldReducible type
let type Core.betaReduce type
zetaReduce type
/--
Information about a single argument position in a function's type signature.
This is used during pattern matching to identify arguments that can be skipped
or handled specially (e.g., instance arguments can be synthesized, proof arguments
can be inferred).
-/
public structure PatternArgInfo where
/-- `true` if this argument is a proof (its type is a `Prop`). -/
isProof : Bool
/-- `true` if this argument is a type class instance. -/
isInstance : Bool
/--
Information about a function symbol occurring in a pattern.
Stores which argument positions are proofs or instances, enabling optimizations
during pattern matching such as skipping proof arguments or deferring instance synthesis.
-/
public structure FunPatternInfo where
/-- Information about each argument position. -/
argsInfo : Array PatternArgInfo
/--
Analyzes the type signature of `declName` and returns information about which arguments
are proofs or instances. Returns `none` if no arguments are proofs or instances.
-/
def mkFunPatternInfo? (declName : Name) : MetaM (Option FunPatternInfo) := do
let info getConstInfo declName
let type preprocessType info.type
forallTelescopeReducing type fun xs _ => do
let env getEnv
let mut argsInfo := #[]
let mut found := false
for x in xs do
let type inferType x
let isInstance := isClass? env type |>.isSome
let isProof isProp type
if isInstance || isProof then
found := true
argsInfo := argsInfo.push { isInstance, isProof }
if found then
return some { argsInfo }
else
return none
/--
Collects `FunPatternInfo` for all function symbols occurring in `pattern`.
Only includes functions that have at least one proof or instance argument.
-/
def mkFunInfosFor (pattern : Expr) : MetaM (AssocList Name FunPatternInfo) := do
let cs := pattern.getUsedConstants
let mut fnInfos := {}
for declName in cs do
if let some info mkFunPatternInfo? declName then
fnInfos := fnInfos.insertNew declName info
return fnInfos
public structure Pattern where
levelParams : List Name
varTypes : Array Expr
pattern : Expr
levelParams : List Name
varTypes : Array Expr
pattern : Expr
fnInfos : AssocList Name FunPatternInfo
deriving Inhabited
def uvarPrefix : Name := `_uvar
@@ -45,11 +112,6 @@ def isUVar? (n : Name) : Option Nat := Id.run do
unless p == uvarPrefix do return none
return some idx
def preprocessType (type : Expr) : MetaM Expr := do
let type unfoldReducible type
let type Core.betaReduce type
zetaReduce type
public def mkPatternFromTheorem (declName : Name) : MetaM Pattern := do
let info getConstInfo declName
let levelParams := info.levelParams.mapIdx fun i _ => Name.num uvarPrefix i
@@ -57,11 +119,14 @@ public def mkPatternFromTheorem (declName : Name) : MetaM Pattern := do
let type instantiateTypeLevelParams info.toConstantVal us
let type preprocessType type
-- **TODO**: save position of instance arguments
let rec go (type : Expr) (varTypes : Array Expr) : Pattern :=
let rec go (type : Expr) (varTypes : Array Expr) : MetaM Pattern := do
match type with
| .forallE _ d b _ => go b (varTypes.push d)
| _ => { levelParams, varTypes, pattern := type }
return go type #[]
| _ =>
let pattern := type
let fnInfos mkFunInfosFor pattern
return { levelParams, varTypes, pattern, fnInfos }
go type #[]
structure UnifyM.Context where
pattern : Pattern
@@ -72,6 +137,7 @@ structure UnifyM.State where
uAssignment : Array (Option Level) := #[]
ePending : Array (Expr × Expr) := #[]
uPending : Array (Level × Level) := #[]
iPending : Array (Expr × Expr) := #[]
us : List Level := []
args : Array Expr := #[]
@@ -83,6 +149,15 @@ def pushPending (p : Expr) (e : Expr) : UnifyM Unit :=
def pushLevelPending (u : Level) (v : Level) : UnifyM Unit :=
modify fun s => { s with uPending := s.uPending.push (u, v) }
def pushInstPending (p : Expr) (e : Expr) : UnifyM Unit :=
modify fun s => { s with iPending := s.iPending.push (p, e) }
def assignExprIfUnassigned (bidx : Nat) (e : Expr) : UnifyM Unit := do
let s get
let i := s.eAssignment.size - bidx - 1
if s.eAssignment[i]!.isNone then
modify fun s => { s with eAssignment := s.eAssignment.set! i (some e) }
def assignExpr (bidx : Nat) (e : Expr) : UnifyM Bool := do
let s get
let i := s.eAssignment.size - bidx - 1
@@ -167,12 +242,39 @@ partial def process (p : Expr) (e : Expr) : UnifyM Bool := do
| .letE .. => unreachable!
where
processApp (p : Expr) (e : Expr) : UnifyM Bool := do
-- **TODO**: Skip instance arguments, and process later
-- **TODO**: Skip proof arguments
let f := p.getAppFn
let .const declName _ := f | processAppDefault p e
let some info := ( read).pattern.fnInfos.find? declName | process.processAppDefault p e
let numArgs := p.getAppNumArgs
processAppWithInfo p e (numArgs - 1) info
processAppWithInfo (p : Expr) (e : Expr) (i : Nat) (info : FunPatternInfo) : UnifyM Bool := do
let .app fp ap := p | process p e
let .app fe ae := e | return false
unless ( processApp fp fe) do return false
unless ( processAppWithInfo fp fe (i - 1) info) do return false
if h : i < info.argsInfo.size then
let argInfo := info.argsInfo[i]
if argInfo.isInstance then
if let .bvar bidx := ap then
assignExprIfUnassigned bidx ae
else
pushInstPending ap ae
return true
else if argInfo.isProof then
if let .bvar bidx := ap then
assignExprIfUnassigned bidx ae
return true
else
process ap ae
else
process ap ae
processAppDefault (p : Expr) (e : Expr) : UnifyM Bool := do
let .app fp ap := p | process p e
let .app fe ae := e | return false
unless ( processAppDefault fp fe) do return false
process ap ae
processConst (declName : Name) (us : List Level) (e : Expr) : UnifyM Bool := do
let .const declName' us' := e | return false
unless declName == declName' do return false
@@ -183,7 +285,7 @@ where
def noPending : UnifyM Bool := do
let s get
return s.ePending.isEmpty && s.uPending.isEmpty
return s.ePending.isEmpty && s.uPending.isEmpty && s.iPending.isEmpty
def mkPreResult : UnifyM Unit := do
let us ( get).uAssignment.toList.mapM fun
@@ -232,9 +334,37 @@ def main (p : Pattern) (e : Expr) (unify : Bool) : SymM (Option (MatchUnifyResul
unless ( processPending) do return none
return some ( mkResult)
/--
Attempts to match expression `e` against pattern `p` using purely syntactic matching.
Returns `some result` if matching succeeds, where `result` contains:
- `us`: Level assignments for the pattern's universe variables
- `args`: Expression assignments for the pattern's bound variables
Matching fails if:
- The term contains metavariables (use `unify?` instead)
- Structural mismatch after reducible unfolding
Instance arguments are deferred for later synthesis. Proof arguments are
skipped via proof irrelevance.
-/
public def Pattern.match? (p : Pattern) (e : Expr) : SymM (Option (MatchUnifyResult)) :=
main p e (unify := false)
/--
Attempts to unify expression `e` against pattern `p`, allowing metavariables in `e`.
Returns `some result` if unification succeeds, where `result` contains:
- `us`: Level assignments for the pattern's universe variables
- `args`: Expression assignments for the pattern's bound variables
Unlike `match?`, this handles terms containing metavariables by deferring
constraints to Phase 2 unification. Use this when matching against goal
expressions that may contain unsolved metavariables.
Instance arguments are deferred for later synthesis. Proof arguments are
skipped via proof irrelevance.
-/
public def Pattern.unify? (p : Pattern) (e : Expr) : SymM (Option (MatchUnifyResult)) :=
main p e (unify := true)