Compare commits

...

6 Commits

Author SHA1 Message Date
Leonardo de Moura
7596132c7f chore: update example 2026-01-21 08:49:46 -08:00
Leonardo de Moura
f4e6147941 chore: 2026-01-21 08:44:14 -08:00
Leonardo de Moura
d1d77001ec feat: add ApplyResult 2026-01-21 08:28:10 -08:00
Leonardo de Moura
6d05bcb5ea perf: use shareCommon to ensure we don't get lost in DAGs 2026-01-21 08:16:28 -08:00
Leonardo de Moura
593eae6887 feat: add simpGoal 2026-01-21 08:07:50 -08:00
Leonardo de Moura
5756d74f45 checkpoint 2026-01-21 08:07:50 -08:00
12 changed files with 181 additions and 53 deletions

View File

@@ -96,6 +96,10 @@ def mkValue (expr : Expr) (pattern : Pattern) (result : MatchUnifyResult) : Expr
else
mkAppN (expr.instantiateLevelParams pattern.levelParams result.us) result.args
public inductive ApplyResult where
| notApplicable
| goals (mvarId : List MVarId)
/--
Applies a backward rule to a goal, returning new subgoals.
@@ -103,27 +107,23 @@ Applies a backward rule to a goal, returning new subgoals.
2. Assigns the goal metavariable to the theorem application
3. Returns new goals for unassigned arguments (per `resultPos`)
Returns `none` if unification fails.
Returns `.notApplicable` if unification fails.
-/
public def BackwardRule.apply? (mvarId : MVarId) (rule : BackwardRule) : SymM (Option (List MVarId)) := mvarId.withContext do
public def BackwardRule.apply (mvarId : MVarId) (rule : BackwardRule) : SymM ApplyResult := mvarId.withContext do
let decl mvarId.getDecl
if let some result rule.pattern.unify? decl.type then
mvarId.assign (mkValue rule.expr rule.pattern result)
return some <| rule.resultPos.map fun i =>
return .goals <| rule.resultPos.map fun i =>
result.args[i]!.mvarId!
else
return none
return .notApplicable
/--
Similar to `BackwardRule.apply?`, but throws an error if unification fails.
Similar to `BackwardRule.apply', but throws an error if unification fails.
-/
public def BackwardRule.apply (mvarId : MVarId) (rule : BackwardRule) : SymM (List MVarId) := mvarId.withContext do
let decl mvarId.getDecl
if let some result rule.pattern.unify? decl.type then
mvarId.assign (mkValue rule.expr rule.pattern result)
return rule.resultPos.map fun i =>
result.args[i]!.mvarId!
else
throwError "rule is not applicable to goal{mvarId}rule:{indentExpr rule.expr}"
public def BackwardRule.apply' (mvarId : MVarId) (rule : BackwardRule) : SymM (List MVarId) := do
let .goals mvarIds rule.apply mvarId
| throwError "rule is not applicable to goal{mvarId}rule:{indentExpr rule.expr}"
return mvarIds
end Lean.Meta.Sym

View File

@@ -21,3 +21,4 @@ public import Lean.Meta.Sym.Simp.Debug
public import Lean.Meta.Sym.Simp.EvalGround
public import Lean.Meta.Sym.Simp.Discharger
public import Lean.Meta.Sym.Simp.ControlFlow
public import Lean.Meta.Sym.Simp.Goal

View File

@@ -9,6 +9,7 @@ public import Lean.Meta.Sym.Simp.SimpM
public import Lean.Meta.Sym.Simp.Discharger
import Lean.Meta.Sym.Simp.Theorems
import Lean.Meta.Sym.Simp.Rewrite
import Lean.Meta.Sym.Simp.Goal
import Lean.Meta.Sym.Util
import Lean.Meta.Tactic.Util
import Lean.Meta.AppBuilder
@@ -27,24 +28,9 @@ public def mkSimprocFor (declNames : Array Name) (d : Discharger := dischargeNon
public def mkMethods (declNames : Array Name) : MetaM Methods := do
return { post := ( mkSimprocFor declNames) }
public def simpWith (k : Expr SymM Result) (mvarId : MVarId) : MetaM (Option MVarId) := SymM.run do
let mvarId preprocessMVar mvarId
let decl mvarId.getDecl
let target := decl.type
match ( k target) with
| .rfl _ => throwError "`Sym.simp` made no progress "
| .step target' h _ =>
let mvarNew mkFreshExprSyntheticOpaqueMVar target' decl.userName
let h mkAppM ``Eq.mpr #[h, mvarNew]
mvarId.assign h
if target'.isTrue then
mvarNew.mvarId!.assign (mkConst ``True.intro)
return none
else
return some mvarNew.mvarId!
public def simpGoal (declNames : Array Name) (mvarId : MVarId) : MetaM (Option MVarId) := SymM.run do mvarId.withContext do
public def simpGoalUsing (declNames : Array Name) (mvarId : MVarId) : MetaM (Option MVarId) := SymM.run do
let methods mkMethods declNames
simpWith (simp · methods) mvarId
let mvarId preprocessMVar mvarId
( simpGoal mvarId methods).toOption
end Lean.Meta.Sym

View File

@@ -81,7 +81,7 @@ public def simpForall (e : Expr) : SimpM Result := do
else if ( isProp e) then
let n := getForallTelescopeSize e.bindingBody! 1
forallBoundedTelescope e n fun xs b => withoutModifyingCacheIfNotWellBehaved do
main xs ( share b)
main xs ( shareCommon b)
else
return .rfl
where
@@ -90,7 +90,7 @@ where
| .rfl _ => return .rfl
| .step b' h _ =>
let h mkLambdaFVars xs h
let e' share ( mkForallFVars xs b')
let e' shareCommon ( mkForallFVars xs b')
-- **Note**: consider caching the forall-congr theorems
let hcongr mkForallCongrFor xs
return .step e' (mkApp3 hcongr ( mkLambdaFVars xs b) ( mkLambdaFVars xs b') h)

View File

@@ -0,0 +1,69 @@
/-
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Meta.Sym.Simp.SimpM
import Lean.Meta.Tactic.Util
import Lean.Meta.AppBuilder
namespace Lean.Meta.Sym
/-!
# Goal simplification
Applies `Sym.simp` to a goal's target type, producing a simplified goal or closing it if
the result is `True`.
-/
/-- Result of simplifying a goal with `Sym.simp`. -/
public inductive SimpGoalResult where
/-- No simplification was possible. -/
| noProgress
/-- The goal was closed (simplified to `True`). -/
| closed
/-- The goal was simplified to a new goal. -/
| goal (mvarId : MVarId)
/--
Converts a `SimpGoalResult` to an optional goal.
Returns `none` if closed, `some mvarId` if simplified, or throws an error if no progress.
-/
public def SimpGoalResult.toOption : SimpGoalResult CoreM (Option MVarId)
| .noProgress => throwError "`Sym.simp` made no progress "
| .closed => return none
| .goal mvarId => return some mvarId
/--
Simplifies the target of `mvarId` using `Sym.simp`.
Returns `.closed` if the target simplifies to `True`, `.simp mvarId'` if simplified
to a new goal, or `.noProgress` if no simplification occurred.
This function assumed the input goal is a valid `Sym` goal (e.g., expressions are maximally shared).
-/
public def simpGoal (mvarId : MVarId) (methods : Simp.Methods := {}) (config : Simp.Config := {})
: SymM SimpGoalResult := mvarId.withContext do
let decl mvarId.getDecl
let target := decl.type
match ( simp target methods config) with
| .rfl _ => return .noProgress
| .step target' h _ =>
let mvarNew mkFreshExprSyntheticOpaqueMVar target' decl.userName
let h mkAppM ``Eq.mpr #[h, mvarNew]
mvarId.assign h
if target'.isTrue then
mvarNew.mvarId!.assign (mkConst ``True.intro)
return .closed
else
return .goal mvarNew.mvarId!
/--
Similar to `simpGoal`, but returns `.goal mvarId` if no progress was made.
-/
public def trySimpGoal (mvarId : MVarId) (methods : Simp.Methods := {}) (config : Simp.Config := {})
: SymM SimpGoalResult := do
match ( simpGoal mvarId methods config) with
| .noProgress => return .goal mvarId
| r => return r
end Lean.Meta.Sym

View File

@@ -48,14 +48,14 @@ def mkFunextFor (xs : Array Expr) (β : Expr) : MetaM Expr := do
public def simpLambda (e : Expr) : SimpM Result := do
lambdaTelescope e fun xs b => withoutModifyingCacheIfNotWellBehaved do
main xs ( share b)
main xs ( shareCommon b)
where
main (xs : Array Expr) (b : Expr) : SimpM Result := do
match ( simp b) with
| .rfl _ => return .rfl
| .step b' h _ =>
let h mkLambdaFVars xs h
let e' share ( mkLambdaFVars xs b')
let e' shareCommon ( mkLambdaFVars xs b')
let funext getFunext xs b
return .step e' (mkApp3 funext e e' h)

View File

@@ -235,3 +235,74 @@ def runBenchUsingMeta : MetaM Unit := do
solveUsingMeta n
#eval runBenchUsingMeta
-- goal_80: 1467.414291 ms, kernel: 120.162250 ms
/-!
`SymM` Solution
-/
theorem exists_eq_True (a : α) : ( x, x = a) = True := by
simp
open Sym
def mkMethods (declNames : Array Name) : MetaM Sym.Simp.Methods := do
let rewrite Sym.mkSimprocFor declNames
return {
post := Sym.Simp.evalGround.andThen rewrite
}
elab "sym_simp" "[" declNames:ident,* "]" : tactic => do
let rewrite Sym.mkSimprocFor ( declNames.getElems.mapM fun s => realizeGlobalConstNoOverload s.raw) Sym.Simp.dischargeSimpSelf
let methods : Sym.Simp.Methods := {
pre := Sym.Simp.simpControl
post := Sym.Simp.evalGround.andThen rewrite
}
liftMetaTactic1 fun mvarId => Sym.SymM.run do
let mvarId Sym.preprocessMVar mvarId
( Sym.simpGoal mvarId methods).toOption
example (l : PartialMap String Word) : ((l.put "b" x).put "a" y).get "b" = x := by
sym_simp [PartialMap.get_put_diff, PartialMap.get_put]
partial def solve (mvarId : MVarId) : SymM Unit := do
let exec_cpsRule mkBackwardRuleFromDecl ``Exec.seq_cps
let inputRule mkBackwardRuleFromDecl ``Exec.input
let skipRule mkBackwardRuleFromDecl ``Exec.skip
let setRule mkBackwardRuleFromDecl ``Exec.set
let rflRule mkBackwardRuleFromDecl ``Eq.refl
let unfoldMethods mkMethods #[``generated_cmd.eq_1, ``repeated_cmds.eq_1, ``repeated_cmds.eq_2]
let evalMethods mkMethods #[``Expr.eval.eq_1, ``Expr.eval.eq_2, ``Expr.eval.eq_3]
let simpMethods mkMethods #[``PartialMap.get_put_diff, ``PartialMap.get_put, ``PartialMap.put_put, ``Binop.interp_add,
``Binop.interp_sub, ``Word.add_sub_cancel, ``Option.some.injEq, ``not_false_eq_true, ``ne_eq]
let finalSimpMethods mkMethods #[``List.cons.injEq, ``IOEvent.IN.injEq, ``and_true, ``true_and, ``PartialMap.put_put, ``PartialMap.get_put,
``Option.some.injEq, ``and_self, ``exists_eq_True]
-- Initialize
let mvarId preprocessMVar mvarId
let (_, mvarId) Sym.introN mvarId 2
let .goal mvarId Sym.simpGoal mvarId unfoldMethods | failure
let .goals [mvarId] exec_cpsRule.apply mvarId | failure
let .goals [mvarId] inputRule.apply mvarId | failure
let (_, mvarId) Sym.introN mvarId 1
-- Loop
let rec loop (mvarId : MVarId) : SymM MVarId := do
-- mvarId.withContext do logInfo m!"{← mvarId.getType}"
let .goals [mvarId] exec_cpsRule.apply mvarId | return mvarId
let .goals [mvarId', mvarId, _] setRule.apply mvarId | failure
let .goal mvarId' Sym.simpGoal mvarId' evalMethods | failure
let .goal mvarId' Sym.simpGoal mvarId' simpMethods | failure
let .goals [] rflRule.apply mvarId' | failure
loop mvarId
let mvarId loop mvarId
let .goals [mvarId] skipRule.apply mvarId | failure
let .goal mvarId Sym.simpGoal mvarId finalSimpMethods { maxSteps := 100000 } | failure
logInfo mvarId -- **TODO**: get_put theorem is not behaving correctly
mvarId.admit
return
def solveUsingSym (n : Nat) (check := true) : MetaM Unit := do
driver n check fun mvarId => SymM.run do solve mvarId
set_option maxRecDepth 100000
#eval solveUsingSym 4

View File

@@ -13,9 +13,9 @@ def test1 : SymM Unit := do
let e shareCommon ( getConstInfo ``ex).value!
let some r₁ pEx.match? e | throwError "failed"
logInfo <| mkAppN (mkConst ``Exists.intro r₁.us) r₁.args
let some r₂ pAnd.match? ( Sym.inferType r₁.args[3]!) | throwError "failed"
let some r₂ pAnd.match? ( Sym.inferType r₁.args[3]!) | failure
logInfo <| mkAppN (mkConst ``And.intro r₂.us) r₂.args
let some r₃ pEq.unify? ( Sym.inferType r₂.args[3]!) | throwError "failed"
let some r₃ pEq.unify? ( Sym.inferType r₂.args[3]!) | failure
logInfo <| mkAppN (mkConst ``Eq.refl r₃.us) r₃.args
/--
@@ -36,10 +36,10 @@ def test2 : SymM Unit := do
let rulePax mkBackwardRuleFromDecl ``pax
let mvar mkFreshExprMVar ( getConstInfo ``ex).value!
let mvarId preprocessMVar mvar.mvarId!
let [mvarId, _] ruleEx.apply mvarId | throwError "Failed"
let [mvarId₁, mvarId₂] ruleAnd.apply mvarId | throwError "Failed"
let [] rulePax.apply mvarId₁ | throwError "Failed"
let [] ruleRefl.apply mvarId₂ | throwError "Failed"
let .goals [mvarId, _] ruleEx.apply mvarId | failure
let .goals [mvarId₁, mvarId₂] ruleAnd.apply mvarId | failure
let .goals [] rulePax.apply mvarId₁ | failure
let .goals [] ruleRefl.apply mvarId₂ | failure
logInfo mvar
/--
@@ -62,7 +62,7 @@ def test3 : SymM Unit := do
let mvar mkFreshExprMVar target
let mvarId preprocessMVar mvar.mvarId!
let rule mkBackwardRuleFromDecl ``pFoo
let [] rule.apply mvarId | throwError "failed"
let .goals [] rule.apply mvarId | failure
logInfo mvar
/-- info: pFoo (3 + y) -/
@@ -78,7 +78,7 @@ def test4 : SymM Unit := do
let target := mkApp (mkConst ``p) (mkApp2 (mkConst ``foo) x m1)
let target shareCommon target
let p mkPatternFromDecl ``pFoo
let some r p.match? target | throwError "failed"
let some r p.match? target | failure
logInfo <| mkAppN (mkConst ``pFoo r.us) r.args
/-- info: pFoo (3 + y) -/

View File

@@ -7,7 +7,7 @@ set_option warn.sorry false
elab "sym_simp" "[" declNames:ident,* "]" : tactic => do
let declNames declNames.getElems.mapM fun s => realizeGlobalConstNoOverload s.raw
liftMetaTactic1 <| Sym.simpGoal declNames
liftMetaTactic1 <| Sym.simpGoalUsing declNames
theorem heq_self : (x x) = True := by simp
theorem forall_true {α : Sort u} : ( _ : α, True) = True := by simp
@@ -115,7 +115,7 @@ example (as : Array (Nat → Nat)) (i : Nat) (_ : i < as.size) (h : as[i] a = b)
/--
trace: c a : Nat
g : Nat → Nat
h : ite (c > 0) a = g
h : ite (0 < c) a = g
⊢ ite (0 < c) a = g
-/
#guard_msgs in

View File

@@ -3,7 +3,9 @@ open Lean Meta Elab Tactic
elab "sym_simp" : tactic => do
let methods : Sym.Simp.Methods := { post := Sym.Simp.evalGround }
liftMetaTactic1 <| Sym.simpWith (Sym.simp · methods)
liftMetaTactic1 fun mvarId => Sym.SymM.run do
let mvarId Sym.preprocessMVar mvarId
( Sym.simpGoal mvarId methods).toOption
-- Basic arithmetic: Nat
example : 2 + 3 = 5 := by sym_simp

View File

@@ -7,7 +7,9 @@ elab "sym_simp" "[" declNames:ident,* "]" : tactic => do
pre := Sym.Simp.simpControl
post := Sym.Simp.evalGround.andThen rewrite
}
liftMetaTactic1 <| Sym.simpWith (Sym.simp · methods)
liftMetaTactic1 fun mvarId => Sym.SymM.run do
let mvarId Sym.preprocessMVar mvarId
( Sym.simpGoal mvarId methods).toOption
example : (1-1) + x*1 + (2-1)*0 = x := by
sym_simp [Nat.add_zero, Nat.zero_add, Nat.mul_one]

View File

@@ -3,7 +3,7 @@ import Lean.Meta.Sym
open Lean Meta Sym
def profileM {α : Type} (k : MetaM α) (msg : String := "experiment") : MetaM α :=
profileitM Exception msg ({ : Options }.set `profiler true |>.setNat `profiler.threshold 0) k
profileitM Exception msg (Options.empty.set `profiler true |>.set `profiler.threshold 0) k
def genTerm (n : Nat) : Expr := Id.run do
let mut e := mkConst ``True
@@ -33,11 +33,8 @@ def tryIntros? (goals : List MVarId) : SymM (Option (List MVarId)) := do
def tryApply? (rule : BackwardRule) (goals : List MVarId) : SymM (Option (List MVarId)) := do
let goal :: goals := goals | return none
try
let goals' rule.apply goal
return some (goals' ++ goals)
catch _ =>
return none
let .goals goals' rule.apply goal | return none
return some (goals' ++ goals)
def tryApplyAny? (rules : List BackwardRule) (goals : List MVarId) : SymM (Option (List MVarId)) := do
match rules with