Compare commits

...

2 Commits

Author SHA1 Message Date
Leonardo de Moura
563f5448dc feat: add backward chaining rule application to Sym
This PR adds `BackwardRule` for efficient goal transformation via
backward chaining in `SymM`.

`BackwardRule` stores a theorem expression, precomputed pattern for
fast unification, and argument indices that become new subgoals. The
subgoal ordering lists non-dependent goals first to match the behavior
of `MetaM.apply`.

`BackwardRule.apply` unifies the goal type with the rule's pattern,
assigns the goal metavariable to the theorem application, and returns
new subgoals for unassigned arguments.
2025-12-29 16:02:27 -08:00
Leonardo de Moura
3e411844b6 chore: rename 2025-12-29 14:50:32 -08:00
5 changed files with 149 additions and 9 deletions

View File

@@ -18,6 +18,7 @@ public import Lean.Meta.Sym.InstantiateMVarsS
public import Lean.Meta.Sym.ProofInstInfo
public import Lean.Meta.Sym.AbstractS
public import Lean.Meta.Sym.Pattern
public import Lean.Meta.Sym.Apply
/-!
# Symbolic simulation support.

View File

@@ -0,0 +1,118 @@
/-
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Sym.Pattern
import Lean.Util.CollectFVars
namespace Lean.Meta.Sym
/--
A rule for backward chaining (goal transformation).
Given a goal `... ⊢ T`, applying a `BackwardRule` derived from a theorem `∀ xs, P`
will unify `T` with `P`, assign the goal to the theorem application,
and return new goals for the unassigned arguments in `xs`.
-/
public structure BackwardRule where
/-- The theorem used to create the rule. It is often of the form `Expr.const declName`. -/
expr : Expr
/-- Precomputed pattern for efficient unification. -/
pattern : Pattern
/--
Indices of arguments that become new subgoals, ordered with
non-dependent goals first. -/
resultPos : List Nat
/--
Computes which argument positions become new subgoals after applying a backward rule.
Arguments are excluded from `resultPos` if:
- They appear in the conclusion (will be determined by unification)
- They are instance arguments (will be synthesized)
The result is ordered with non-dependent arguments first, then dependent ones.
This ordering is the same one used for the `MetaM` `apply` tactic.
It improves the user experience: non-dependent goals can be solved in
any order, while dependent goals are often resolved by solving the non-dependent ones first.
Example: `Exists.intro` produces two subgoal `?h : ?p ?w` and `?w : ?α`. The goal `?h` appears
first because solving it often solves `?w`.
-/
def mkResultPos (pattern : Pattern) : List Nat := Id.run do
let auxPrefix := `_sym_pre
-- Initialize "found" mask with arguments that can be synthesized by type class resolution.
let mut found := pattern.isInstance
let numArgs := pattern.varTypes.size
let auxVars := pattern.varTypes.mapIdx fun i _ => mkFVar .num auxPrefix i
-- Collect arguments that occur in the pattern
for fvarId in collectFVars {} (pattern.pattern.instantiateRev auxVars) |>.fvarIds do
let .num pre idx := fvarId.name | pure ()
if pre == auxPrefix then
found := found.set! idx true
let argTypes := pattern.varTypes.mapIdx fun i type => type.instantiateRevRange 0 i auxVars
-- Collect dependent and non-dependent arguments that become new goals
-- An argument is considered dependent only if there is another goal whose type depends on it.
let mut deps := #[]
let mut nonDeps := #[]
for i in *...numArgs do
unless found[i]! do
let auxVar := auxVars[i]!
let mut isDep := false
for j in (i+1)...numArgs do
unless found[j]! do
let argType := argTypes[j]!
if argType.containsFVar auxVar.fvarId! then
isDep := true
break
if isDep then
deps := deps.push i
else
nonDeps := nonDeps.push i
return (nonDeps ++ deps).toList
/--
Creates a `BackwardRule` from a declaration name.
The `num?` parameter optionally limits how many arguments are included in the pattern
(useful for partially applying theorems).
-/
public def mkBackwardRuleFromDecl (declName : Name) (num? : Option Nat := none) : MetaM BackwardRule := do
let pattern mkPatternFromDecl declName num?
let resultPos := mkResultPos pattern
return { expr := mkConst declName, pattern, resultPos }
/--
Creates a value to assign to input goal metavariable using unification result.
Handles both constant expressions (common case, avoids `instantiateLevelParams`)
and general expressions.
-/
def mkValue (expr : Expr) (pattern : Pattern) (result : MatchUnifyResult) : Expr :=
if let .const declName [] := expr then
mkAppN (mkConst declName result.us) result.args
else
mkAppN (expr.instantiateLevelParams pattern.levelParams result.us) result.args
/--
Applies a backward rule to a goal, returning new subgoals.
1. Unifies the goal type with the rule's pattern
2. Assigns the goal metavariable to the theorem application
3. Returns new goals for unassigned arguments (per `resultPos`)
Throws an error if unification fails.
-/
public def BackwardRule.apply (goal : Goal) (rule : BackwardRule) : SymM (List Goal) := goal.withContext do
let type goal.mvarId.getType
if let some result rule.pattern.unify? type then
goal.mvarId.assign (mkValue rule.expr rule.pattern result)
return rule.resultPos.map fun i =>
let mvarId := result.args[i]!.mvarId!
{ goal with mvarId }
else
throwError "rule is not applicable to goal{goal.mvarId}\nrule:{indentExpr rule.expr}"
end Lean.Meta.Sym

View File

@@ -81,7 +81,7 @@ Universe level parameters are replaced with fresh unification variables (prefixe
If `num?` is `some n`, at most `n` leading quantifiers are stripped.
If `num?` is `none`, all leading quantifiers are stripped.
-/
public def mkPatternFromTheorem (declName : Name) (num? : Option Nat := none) : MetaM Pattern := do
public def mkPatternFromDecl (declName : Name) (num? : Option Nat := none) : MetaM Pattern := do
let info getConstInfo declName
let levelParams := info.levelParams.mapIdx fun i _ => Name.num uvarPrefix i
let us := levelParams.map mkLevelParam

View File

@@ -3,13 +3,13 @@ open Lean Meta Sym Grind
set_option grind.debug true
opaque p : Nat Prop
opaque q : Nat Nat Prop
axiom pax : p x
def ex := x : Nat, p x x = .zero
def test : SymM Unit := do
let pEx mkPatternFromTheorem ``Exists.intro
let pAnd mkPatternFromTheorem ``And.intro
let pEq mkPatternFromTheorem ``Eq.refl
def test1 : SymM Unit := do
let pEx mkPatternFromDecl ``Exists.intro
let pAnd mkPatternFromDecl ``And.intro
let pEq mkPatternFromDecl ``Eq.refl
let e shareCommon ( getConstInfo ``ex).value!
let some r₁ pEx.match? e | throwError "failed"
logInfo <| mkAppN (mkConst ``Exists.intro r₁.us) r₁.args
@@ -27,4 +27,25 @@ info: @Eq.refl Nat Nat.zero
-/
#guard_msgs in
set_option pp.explicit true in
#eval SymM.run' test
#eval SymM.run' test1
def test2 : SymM Unit := do
let ruleEx mkBackwardRuleFromDecl ``Exists.intro
let ruleAnd mkBackwardRuleFromDecl ``And.intro
let ruleRefl mkBackwardRuleFromDecl ``Eq.refl
let rulePax mkBackwardRuleFromDecl ``pax
let mvar mkFreshExprMVar ( getConstInfo ``ex).value!
let goal Sym.mkGoal mvar.mvarId!
let [goal, _] ruleEx.apply goal | throwError "Failed"
let [goal₁, goal₂] ruleAnd.apply goal | throwError "Failed"
let [] rulePax.apply goal₁ | throwError "Failed"
let [] ruleRefl.apply goal₂ | throwError "Failed"
logInfo mvar
/--
info: @Exists.intro Nat (fun x => And (p x) (@Eq Nat x Nat.zero)) Nat.zero
(@And.intro (p Nat.zero) (@Eq Nat Nat.zero Nat.zero) (@pax Nat.zero) (@Eq.refl Nat Nat.zero))
-/
#guard_msgs in
set_option pp.explicit true in
#eval SymM.run' test2

View File

@@ -8,7 +8,7 @@ opaque b : Int
def ex₁ := p (a + 1) b
def test₁ : SymM Unit := do
let pEx mkPatternFromTheorem ``pax
let pEx mkPatternFromDecl ``pax
let e shareCommon ( getConstInfo ``ex₁).value!
let some r₁ pEx.match? e | throwError "failed"
let h := mkAppN (mkConst ``pax r₁.us) r₁.args
@@ -34,7 +34,7 @@ def ex₂ := ∀ x, q x 0 ∧ q (f (f x)) (f x + f (f 1))
def test₂ : SymM Unit := do
/- We use `some 5` because we want the pattern to be `(∀ x, ?P x ∧ ?Q x)`-/
let p mkPatternFromTheorem ``mk_forall_and (some 5)
let p mkPatternFromDecl ``mk_forall_and (some 5)
let e shareCommon ( getConstInfo ``ex₂).value!
logInfo p.pattern
logInfo e