mirror of
https://github.com/leanprover/lean4.git
synced 2026-04-06 12:14:07 +00:00
Compare commits
12 Commits
sofia/open
...
sym-arith-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f12d008bb1 | ||
|
|
30315a59d4 | ||
|
|
9d078f64bc | ||
|
|
083b393294 | ||
|
|
9ea2b7b533 | ||
|
|
e17b0347c8 | ||
|
|
0a6c7eef66 | ||
|
|
46046b47a8 | ||
|
|
069e676532 | ||
|
|
94bf1d34d1 | ||
|
|
681769fb6d | ||
|
|
dad6fe832d |
@@ -25,6 +25,7 @@ public import Lean.Meta.Sym.Simp
|
||||
public import Lean.Meta.Sym.Util
|
||||
public import Lean.Meta.Sym.Eta
|
||||
public import Lean.Meta.Sym.Canon
|
||||
public import Lean.Meta.Sym.Arith
|
||||
public import Lean.Meta.Sym.Grind
|
||||
public import Lean.Meta.Sym.SynthInstance
|
||||
|
||||
|
||||
20
src/Lean/Meta/Sym/Arith.lean
Normal file
20
src/Lean/Meta/Sym/Arith.lean
Normal file
@@ -0,0 +1,20 @@
|
||||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.Arith.Types
|
||||
public import Lean.Meta.Sym.Arith.EvalNum
|
||||
public import Lean.Meta.Sym.Arith.Classify
|
||||
public import Lean.Meta.Sym.Arith.MonadCanon
|
||||
public import Lean.Meta.Sym.Arith.MonadRing
|
||||
public import Lean.Meta.Sym.Arith.MonadSemiring
|
||||
public import Lean.Meta.Sym.Arith.MonadVar
|
||||
public import Lean.Meta.Sym.Arith.Functions
|
||||
public import Lean.Meta.Sym.Arith.Reify
|
||||
public import Lean.Meta.Sym.Arith.DenoteExpr
|
||||
public import Lean.Meta.Sym.Arith.ToExpr
|
||||
public import Lean.Meta.Sym.Arith.VarRename
|
||||
public import Lean.Meta.Sym.Arith.Poly
|
||||
143
src/Lean/Meta/Sym/Arith/Classify.lean
Normal file
143
src/Lean/Meta/Sym/Arith/Classify.lean
Normal file
@@ -0,0 +1,143 @@
|
||||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.Arith.EvalNum
|
||||
import Lean.Meta.Sym.SynthInstance
|
||||
import Lean.Meta.Sym.Canon
|
||||
import Lean.Meta.DecLevel
|
||||
import Init.Grind.Ring
|
||||
public section
|
||||
|
||||
namespace Lean.Meta.Sym.Arith
|
||||
|
||||
/-!
|
||||
# Algebraic structure classification
|
||||
|
||||
Detects the strongest algebraic structure available for a type and caches
|
||||
the classification in `Arith.State.typeClassify`. The detection order is:
|
||||
|
||||
1. `Grind.CommRing` (includes `Field` check)
|
||||
2. `Grind.Ring` (non-commutative)
|
||||
3. `Grind.CommSemiring` (via `OfSemiring.Q` envelope)
|
||||
4. `Grind.Semiring` (non-commutative)
|
||||
|
||||
Results (including failures) are cached in a single `PHashMap ExprPtr ClassifyResult`
|
||||
to avoid repeated synthesis attempts.
|
||||
-/
|
||||
|
||||
private def getIsCharInst? (u : Level) (type : Expr) (semiringInst : Expr) : SymM (Option (Expr × Nat)) := do
|
||||
withNewMCtxDepth do
|
||||
let n ← mkFreshExprMVar (mkConst ``Nat)
|
||||
let charType := mkApp3 (mkConst ``Grind.IsCharP [u]) type semiringInst n
|
||||
let some charInst ← Sym.synthInstance? charType | return none
|
||||
let n ← instantiateMVars n
|
||||
let some n ← evalNat? n | return none
|
||||
return some (charInst, n)
|
||||
|
||||
private def getNoZeroDivInst? (u : Level) (type : Expr) : SymM (Option Expr) := do
|
||||
let natModuleType := mkApp (mkConst ``Grind.NatModule [u]) type
|
||||
let some natModuleInst ← Sym.synthInstance? natModuleType | return none
|
||||
let noZeroDivType := mkApp2 (mkConst ``Grind.NoNatZeroDivisors [u]) type natModuleInst
|
||||
Sym.synthInstance? noZeroDivType
|
||||
|
||||
/-- Try to classify `type` as a `CommRing`. Returns the ring id on success. -/
|
||||
private def tryCommRing? (type : Expr) : SymM (Option Nat) := do
|
||||
let u ← getDecLevel type
|
||||
let commRing := mkApp (mkConst ``Grind.CommRing [u]) type
|
||||
let some commRingInst ← Sym.synthInstance? commRing | return none
|
||||
let ringInst := mkApp2 (mkConst ``Grind.CommRing.toRing [u]) type commRingInst
|
||||
let semiringInst := mkApp2 (mkConst ``Grind.Ring.toSemiring [u]) type ringInst
|
||||
let commSemiringInst := mkApp2 (mkConst ``Grind.CommRing.toCommSemiring [u]) type semiringInst
|
||||
let charInst? ← getIsCharInst? u type semiringInst
|
||||
let noZeroDivInst? ← getNoZeroDivInst? u type
|
||||
let fieldInst? ← Sym.synthInstance? <| mkApp (mkConst ``Grind.Field [u]) type
|
||||
let semiringId? := none
|
||||
let id := (← getArithState).rings.size
|
||||
let ring : CommRing := {
|
||||
id, semiringId?, type, u, semiringInst, ringInst, commSemiringInst,
|
||||
commRingInst, charInst?, noZeroDivInst?, fieldInst?,
|
||||
}
|
||||
modifyArithState fun s => { s with rings := s.rings.push ring }
|
||||
return some id
|
||||
|
||||
/-- Try to classify `type` as a non-commutative `Ring`. -/
|
||||
private def tryNonCommRing? (type : Expr) : SymM (Option Nat) := do
|
||||
let u ← getDecLevel type
|
||||
let ring := mkApp (mkConst ``Grind.Ring [u]) type
|
||||
let some ringInst ← Sym.synthInstance? ring | return none
|
||||
let semiringInst := mkApp2 (mkConst ``Grind.Ring.toSemiring [u]) type ringInst
|
||||
let charInst? ← getIsCharInst? u type semiringInst
|
||||
let id := (← getArithState).ncRings.size
|
||||
let ring : Ring := {
|
||||
id, type, u, semiringInst, ringInst, charInst?
|
||||
}
|
||||
modifyArithState fun s => { s with ncRings := s.ncRings.push ring }
|
||||
return some id
|
||||
|
||||
/-- Helper function for `tryCommSemiring? -/
|
||||
private def tryCacheAndCommRing? (type : Expr) : SymM (Option Nat) := do
|
||||
if let some result := (← getArithState).typeClassify.find? { expr := type } then
|
||||
let .commRing id := result | return none
|
||||
return id
|
||||
let id? ← tryCommRing? type
|
||||
let result := match id? with
|
||||
| none => .none
|
||||
| some id => .commRing id
|
||||
modifyArithState fun s => { s with typeClassify := s.typeClassify.insert { expr := type } result }
|
||||
return id?
|
||||
|
||||
/-- Try to classify `type` as a `CommSemiring`. Creates the `OfSemiring.Q` envelope ring. -/
|
||||
private def tryCommSemiring? (type : Expr) : SymM (Option Nat) := do
|
||||
let u ← getDecLevel type
|
||||
let commSemiring := mkApp (mkConst ``Grind.CommSemiring [u]) type
|
||||
let some commSemiringInst ← Sym.synthInstance? commSemiring | return none
|
||||
let semiringInst := mkApp2 (mkConst ``Grind.CommSemiring.toSemiring [u]) type commSemiringInst
|
||||
let q ← shareCommon (← Sym.canon (mkApp2 (mkConst ``Grind.Ring.OfSemiring.Q [u]) type semiringInst))
|
||||
-- The envelope `Q` type must be classifiable as a CommRing.
|
||||
let some ringId ← tryCacheAndCommRing? q
|
||||
| reportIssue! "unexpected failure initializing ring{indentExpr q}"; return none
|
||||
let id := (← getArithState).semirings.size
|
||||
let semiring : CommSemiring := {
|
||||
id, type, ringId, u, semiringInst, commSemiringInst
|
||||
}
|
||||
modifyArithState fun s => { s with semirings := s.semirings.push semiring }
|
||||
-- Link the envelope ring back to this semiring
|
||||
modifyArithState fun s =>
|
||||
let rings := s.rings.modify ringId fun r => { r with semiringId? := some id }
|
||||
{ s with rings }
|
||||
return some id
|
||||
|
||||
/-- Try to classify `type` as a non-commutative `Semiring`. -/
|
||||
private def tryNonCommSemiring? (type : Expr) : SymM (Option Nat) := do
|
||||
let u ← getDecLevel type
|
||||
let semiring := mkApp (mkConst ``Grind.Semiring [u]) type
|
||||
let some semiringInst ← Sym.synthInstance? semiring | return none
|
||||
let id := (← getArithState).ncSemirings.size
|
||||
let semiring : Semiring := { id, type, u, semiringInst }
|
||||
modifyArithState fun s => { s with ncSemirings := s.ncSemirings.push semiring }
|
||||
return some id
|
||||
|
||||
/--
|
||||
Classify the algebraic structure of `type`, trying the strongest first:
|
||||
CommRing > Ring > CommSemiring > Semiring.
|
||||
Results are cached in `Arith.State.typeClassify`.
|
||||
-/
|
||||
def classify? (type : Expr) : SymM ClassifyResult := do
|
||||
if let some result := (← getArithState).typeClassify.find? { expr := type } then
|
||||
return result
|
||||
let result ← go
|
||||
modifyArithState fun s => { s with typeClassify := s.typeClassify.insert { expr := type } result }
|
||||
return result
|
||||
where
|
||||
go : SymM ClassifyResult := do
|
||||
if let some id ← tryCommRing? type then return .commRing id
|
||||
if let some id ← tryNonCommRing? type then return .nonCommRing id
|
||||
if let some id ← tryCommSemiring? type then return .commSemiring id
|
||||
if let some id ← tryNonCommSemiring? type then return .nonCommSemiring id
|
||||
return .none
|
||||
|
||||
end Lean.Meta.Sym.Arith
|
||||
93
src/Lean/Meta/Sym/Arith/DenoteExpr.lean
Normal file
93
src/Lean/Meta/Sym/Arith/DenoteExpr.lean
Normal file
@@ -0,0 +1,93 @@
|
||||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.Arith.Functions
|
||||
public import Lean.Meta.Sym.Arith.MonadVar
|
||||
public section
|
||||
namespace Lean.Meta.Sym.Arith
|
||||
|
||||
/-!
|
||||
# Denotation of reified expressions
|
||||
|
||||
Converts reified `RingExpr`, `Poly`, `Mon`, `Power` back into Lean `Expr`s using
|
||||
the ring's cached operator functions and variable array.
|
||||
-/
|
||||
|
||||
variable [Monad m] [MonadError m] [MonadLiftT MetaM m] [MonadCanon m] [MonadRing m]
|
||||
|
||||
/-- Convert an integer to a numeral expression in the ring. Negative values use `getNegFn`. -/
|
||||
def denoteNum (k : Int) : m Expr := do
|
||||
let ring ← getRing
|
||||
let n := mkRawNatLit k.natAbs
|
||||
let ofNatInst ← if let some inst ← MonadCanon.synthInstance? (mkApp2 (mkConst ``OfNat [ring.u]) ring.type n) then
|
||||
pure inst
|
||||
else
|
||||
pure <| mkApp3 (mkConst ``Grind.Semiring.ofNat [ring.u]) ring.type ring.semiringInst n
|
||||
let e := mkApp3 (mkConst ``OfNat.ofNat [ring.u]) ring.type n ofNatInst
|
||||
if k < 0 then
|
||||
return mkApp (← getNegFn) e
|
||||
else
|
||||
return e
|
||||
|
||||
/-- Denote a `Power` (variable raised to a power). -/
|
||||
def denotePower [MonadGetVar m] (pw : Power) : m Expr := do
|
||||
let x ← getVar pw.x
|
||||
if pw.k == 1 then
|
||||
return x
|
||||
else
|
||||
return mkApp2 (← getPowFn) x (toExpr pw.k)
|
||||
|
||||
/-- Denote a `Mon` (product of powers). -/
|
||||
def denoteMon [MonadGetVar m] (mn : Mon) : m Expr := do
|
||||
match mn with
|
||||
| .unit => denoteNum 1
|
||||
| .mult pw mn => go mn (← denotePower pw)
|
||||
where
|
||||
go (mn : Mon) (acc : Expr) : m Expr := do
|
||||
match mn with
|
||||
| .unit => return acc
|
||||
| .mult pw mn => go mn (mkApp2 (← getMulFn) acc (← denotePower pw))
|
||||
|
||||
/-- Denote a `Poly` (sum of coefficient × monomial terms). -/
|
||||
def denotePoly [MonadGetVar m] (p : Poly) : m Expr := do
|
||||
match p with
|
||||
| .num k => denoteNum k
|
||||
| .add k mn p => go p (← denoteTerm k mn)
|
||||
where
|
||||
denoteTerm (k : Int) (mn : Mon) : m Expr := do
|
||||
if k == 1 then
|
||||
denoteMon mn
|
||||
else
|
||||
return mkApp2 (← getMulFn) (← denoteNum k) (← denoteMon mn)
|
||||
|
||||
go (p : Poly) (acc : Expr) : m Expr := do
|
||||
match p with
|
||||
| .num 0 => return acc
|
||||
| .num k => return mkApp2 (← getAddFn) acc (← denoteNum k)
|
||||
| .add k mn p => go p (mkApp2 (← getAddFn) acc (← denoteTerm k mn))
|
||||
|
||||
/-- Denote a `RingExpr` using a variable lookup function. -/
|
||||
@[specialize]
|
||||
private def denoteRingExprCore (getVarExpr : Nat → Expr) (e : RingExpr) : m Expr := do
|
||||
go e
|
||||
where
|
||||
go : RingExpr → m Expr
|
||||
| .num k => denoteNum k
|
||||
| .natCast k => return mkApp (← getNatCastFn) (mkNatLit k)
|
||||
| .intCast k => return mkApp (← getIntCastFn) (mkIntLit k)
|
||||
| .var x => return getVarExpr x
|
||||
| .add a b => return mkApp2 (← getAddFn) (← go a) (← go b)
|
||||
| .sub a b => return mkApp2 (← getSubFn) (← go a) (← go b)
|
||||
| .mul a b => return mkApp2 (← getMulFn) (← go a) (← go b)
|
||||
| .pow a k => return mkApp2 (← getPowFn) (← go a) (toExpr k)
|
||||
| .neg a => return mkApp (← getNegFn) (← go a)
|
||||
|
||||
/-- Denote a `RingExpr` using an explicit variable array. -/
|
||||
def denoteRingExpr (vars : Array Expr) (e : RingExpr) : m Expr := do
|
||||
denoteRingExprCore (fun x => vars[x]!) e
|
||||
|
||||
end Lean.Meta.Sym.Arith
|
||||
90
src/Lean/Meta/Sym/Arith/EvalNum.lean
Normal file
90
src/Lean/Meta/Sym/Arith/EvalNum.lean
Normal file
@@ -0,0 +1,90 @@
|
||||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.Arith.Types
|
||||
import Lean.Meta.Sym.LitValues
|
||||
import Lean.Meta.IntInstTesters
|
||||
import Lean.Meta.NatInstTesters
|
||||
public section
|
||||
|
||||
namespace Lean.Meta.Sym.Arith
|
||||
|
||||
/-!
|
||||
Functions for evaluating simple `Nat` and `Int` expressions that appear in type classes
|
||||
(e.g., `ToInt` and `IsCharP`). Using `whnf` for this purpose is too expensive and can
|
||||
exhaust the stack. We considered `evalExpr` as an alternative, but it introduces
|
||||
considerable overhead. We may use `evalExpr` as a fallback in the future.
|
||||
-/
|
||||
|
||||
def checkExp (k : Nat) : OptionT SymM Unit := do
|
||||
let exp ← getExpThreshold
|
||||
if k > exp then
|
||||
reportIssue! "exponent {k} exceeds threshold for exponentiation `(exp := {exp})`"
|
||||
failure
|
||||
|
||||
/-
|
||||
**Note**: It is safe to use (the more efficient) structural instance tests here because
|
||||
`Sym.Canon` has already run.
|
||||
-/
|
||||
open Structural in
|
||||
mutual
|
||||
private partial def evalNatCore (e : Expr) : OptionT SymM Nat := do
|
||||
match_expr e with
|
||||
| Nat.zero => return 0
|
||||
| Nat.succ a => return (← evalNatCore a) + 1
|
||||
| Int.toNat a => return (← evalIntCore a).toNat
|
||||
| Int.natAbs a => return (← evalIntCore a).natAbs
|
||||
| HAdd.hAdd _ _ _ inst a b => guard (← isInstHAddNat inst); return (← evalNatCore a) + (← evalNatCore b)
|
||||
| HMul.hMul _ _ _ inst a b => guard (← isInstHMulNat inst); return (← evalNatCore a) * (← evalNatCore b)
|
||||
| HSub.hSub _ _ _ inst a b => guard (← isInstHSubNat inst); return (← evalNatCore a) - (← evalNatCore b)
|
||||
| HDiv.hDiv _ _ _ inst a b => guard (← isInstHDivNat inst); return (← evalNatCore a) / (← evalNatCore b)
|
||||
| HMod.hMod _ _ _ inst a b => guard (← isInstHModNat inst); return (← evalNatCore a) % (← evalNatCore b)
|
||||
| OfNat.ofNat _ _ _ =>
|
||||
let some n := Sym.getNatValue? e |>.run | failure
|
||||
return n
|
||||
| HPow.hPow _ _ _ inst a k =>
|
||||
guard (← isInstHPowNat inst)
|
||||
let k ← evalNatCore k
|
||||
checkExp k
|
||||
let a ← evalNatCore a
|
||||
return a ^ k
|
||||
| _ => failure
|
||||
|
||||
private partial def evalIntCore (e : Expr) : OptionT SymM Int := do
|
||||
match_expr e with
|
||||
| Neg.neg _ i a => guard (← isInstNegInt i); return - (← evalIntCore a)
|
||||
| HAdd.hAdd _ _ _ i a b => guard (← isInstHAddInt i); return (← evalIntCore a) + (← evalIntCore b)
|
||||
| HSub.hSub _ _ _ i a b => guard (← isInstHSubInt i); return (← evalIntCore a) - (← evalIntCore b)
|
||||
| HMul.hMul _ _ _ i a b => guard (← isInstHMulInt i); return (← evalIntCore a) * (← evalIntCore b)
|
||||
| HDiv.hDiv _ _ _ i a b => guard (← isInstHDivInt i); return (← evalIntCore a) / (← evalIntCore b)
|
||||
| HMod.hMod _ _ _ i a b => guard (← isInstHModInt i); return (← evalIntCore a) % (← evalIntCore b)
|
||||
| HPow.hPow _ _ _ i a k =>
|
||||
guard (← isInstHPowInt i)
|
||||
let a ← evalIntCore a
|
||||
let k ← evalNatCore k
|
||||
checkExp k
|
||||
return a ^ k
|
||||
| OfNat.ofNat _ _ _ =>
|
||||
let some n := Sym.getIntValue? e |>.run | failure
|
||||
return n
|
||||
| NatCast.natCast _ i a =>
|
||||
let_expr instNatCastInt ← i | failure
|
||||
return (← evalNatCore a)
|
||||
| Nat.cast _ i a =>
|
||||
let_expr instNatCastInt ← i | failure
|
||||
return (← evalNatCore a)
|
||||
| _ => failure
|
||||
|
||||
end
|
||||
|
||||
def evalNat? (e : Expr) : SymM (Option Nat) :=
|
||||
evalNatCore e |>.run
|
||||
|
||||
def evalInt? (e : Expr) : SymM (Option Int) :=
|
||||
evalIntCore e |>.run
|
||||
|
||||
end Lean.Meta.Sym.Arith
|
||||
171
src/Lean/Meta/Sym/Arith/Functions.lean
Normal file
171
src/Lean/Meta/Sym/Arith/Functions.lean
Normal file
@@ -0,0 +1,171 @@
|
||||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.Arith.MonadRing
|
||||
public import Lean.Meta.Sym.Arith.MonadSemiring
|
||||
public section
|
||||
namespace Lean.Meta.Sym.Arith
|
||||
|
||||
/-!
|
||||
# Cached function expressions for arithmetic operators
|
||||
|
||||
Synthesizes and caches the canonical Lean expressions for arithmetic operators
|
||||
(`+`, `*`, `-`, `^`, `intCast`, `natCast`, etc.). These cached expressions are used
|
||||
during reification to validate instances via pointer equality (`isSameExpr`).
|
||||
|
||||
Each getter checks the cache field first. On a miss, it synthesizes the instance,
|
||||
verifies it against the expected instance from the ring structure using `isDefEqI`,
|
||||
canonicalizes the result via `canonExpr`, and stores it.
|
||||
-/
|
||||
|
||||
variable [MonadLiftT MetaM m] [MonadError m] [Monad m] [MonadCanon m]
|
||||
|
||||
private def checkInst (declName : Name) (inst inst' : Expr) : MetaM Unit := do
|
||||
unless (← withReducibleAndInstances <| isDefEq inst inst') do
|
||||
throwError "error while initializing arithmetic operators:\ninstance for `{declName}` {indentExpr inst}\nis not definitionally equal to the expected one {indentExpr inst'}\nwhen only reducible definitions and instances are reduced"
|
||||
|
||||
private def mkUnaryFn (type : Expr) (u : Level) (instDeclName : Name) (declName : Name) (expectedInst : Expr) : m Expr := do
|
||||
let inst ← MonadCanon.synthInstance <| mkApp (mkConst instDeclName [u]) type
|
||||
checkInst declName inst expectedInst
|
||||
canonExpr <| mkApp2 (mkConst declName [u]) type inst
|
||||
|
||||
private def mkBinHomoFn (type : Expr) (u : Level) (instDeclName : Name) (declName : Name) (expectedInst : Expr) : m Expr := do
|
||||
let inst ← MonadCanon.synthInstance <| mkApp3 (mkConst instDeclName [u, u, u]) type type type
|
||||
checkInst declName inst expectedInst
|
||||
canonExpr <| mkApp4 (mkConst declName [u, u, u]) type type type inst
|
||||
|
||||
private def mkPowFn (u : Level) (type : Expr) (semiringInst : Expr) : m Expr := do
|
||||
let inst ← MonadCanon.synthInstance <| mkApp3 (mkConst ``HPow [u, 0, u]) type Nat.mkType type
|
||||
let inst' := mkApp2 (mkConst ``Grind.Semiring.npow [u]) type semiringInst
|
||||
checkInst ``HPow.hPow inst inst'
|
||||
canonExpr <| mkApp4 (mkConst ``HPow.hPow [u, 0, u]) type Nat.mkType type inst
|
||||
|
||||
private def mkNatCastFn (u : Level) (type : Expr) (semiringInst : Expr) : m Expr := do
|
||||
let inst' := mkApp2 (mkConst ``Grind.Semiring.natCast [u]) type semiringInst
|
||||
let instType := mkApp (mkConst ``NatCast [u]) type
|
||||
-- Note: `Semiring.natCast` is not a global instance, so `NatCast α` may not be available.
|
||||
-- When present, verify defeq; otherwise fall back to the semiring field.
|
||||
let inst ← match (← MonadCanon.synthInstance? instType) with
|
||||
| none => pure inst'
|
||||
| some inst => checkInst ``NatCast.natCast inst inst'; pure inst
|
||||
canonExpr <| mkApp2 (mkConst ``NatCast.natCast [u]) type inst
|
||||
|
||||
section RingFns
|
||||
variable [MonadRing m]
|
||||
|
||||
def getAddFn : m Expr := do
|
||||
let ring ← getRing
|
||||
if let some addFn := ring.addFn? then return addFn
|
||||
let expectedInst := mkApp2 (mkConst ``instHAdd [ring.u]) ring.type <| mkApp2 (mkConst ``Grind.Semiring.toAdd [ring.u]) ring.type ring.semiringInst
|
||||
let addFn ← mkBinHomoFn ring.type ring.u ``HAdd ``HAdd.hAdd expectedInst
|
||||
modifyRing fun s => { s with addFn? := some addFn }
|
||||
return addFn
|
||||
|
||||
def getMulFn : m Expr := do
|
||||
let ring ← getRing
|
||||
if let some mulFn := ring.mulFn? then return mulFn
|
||||
let expectedInst := mkApp2 (mkConst ``instHMul [ring.u]) ring.type <| mkApp2 (mkConst ``Grind.Semiring.toMul [ring.u]) ring.type ring.semiringInst
|
||||
let mulFn ← mkBinHomoFn ring.type ring.u ``HMul ``HMul.hMul expectedInst
|
||||
modifyRing fun s => { s with mulFn? := some mulFn }
|
||||
return mulFn
|
||||
|
||||
def getSubFn : m Expr := do
|
||||
let ring ← getRing
|
||||
if let some subFn := ring.subFn? then return subFn
|
||||
let expectedInst := mkApp2 (mkConst ``instHSub [ring.u]) ring.type <| mkApp2 (mkConst ``Grind.Ring.toSub [ring.u]) ring.type ring.ringInst
|
||||
let subFn ← mkBinHomoFn ring.type ring.u ``HSub ``HSub.hSub expectedInst
|
||||
modifyRing fun s => { s with subFn? := some subFn }
|
||||
return subFn
|
||||
|
||||
def getNegFn : m Expr := do
|
||||
let ring ← getRing
|
||||
if let some negFn := ring.negFn? then return negFn
|
||||
let expectedInst := mkApp2 (mkConst ``Grind.Ring.toNeg [ring.u]) ring.type ring.ringInst
|
||||
let negFn ← mkUnaryFn ring.type ring.u ``Neg ``Neg.neg expectedInst
|
||||
modifyRing fun s => { s with negFn? := some negFn }
|
||||
return negFn
|
||||
|
||||
def getPowFn : m Expr := do
|
||||
let ring ← getRing
|
||||
if let some powFn := ring.powFn? then return powFn
|
||||
let powFn ← mkPowFn ring.u ring.type ring.semiringInst
|
||||
modifyRing fun s => { s with powFn? := some powFn }
|
||||
return powFn
|
||||
|
||||
def getIntCastFn : m Expr := do
|
||||
let ring ← getRing
|
||||
if let some intCastFn := ring.intCastFn? then return intCastFn
|
||||
let inst' := mkApp2 (mkConst ``Grind.Ring.intCast [ring.u]) ring.type ring.ringInst
|
||||
let instType := mkApp (mkConst ``IntCast [ring.u]) ring.type
|
||||
-- Note: `Ring.intCast` is not a global instance. Same pattern as `mkNatCastFn`.
|
||||
let inst ← match (← MonadCanon.synthInstance? instType) with
|
||||
| none => pure inst'
|
||||
| some inst => checkInst ``Int.cast inst inst'; pure inst
|
||||
let intCastFn ← canonExpr <| mkApp2 (mkConst ``IntCast.intCast [ring.u]) ring.type inst
|
||||
modifyRing fun s => { s with intCastFn? := some intCastFn }
|
||||
return intCastFn
|
||||
|
||||
def getNatCastFn : m Expr := do
|
||||
let ring ← getRing
|
||||
if let some natCastFn := ring.natCastFn? then return natCastFn
|
||||
let natCastFn ← mkNatCastFn ring.u ring.type ring.semiringInst
|
||||
modifyRing fun s => { s with natCastFn? := some natCastFn }
|
||||
return natCastFn
|
||||
|
||||
end RingFns
|
||||
|
||||
section CommRingFns
|
||||
variable [MonadCommRing m]
|
||||
|
||||
def getInvFn : m Expr := do
|
||||
let ring ← getCommRing
|
||||
let some fieldInst := ring.fieldInst?
|
||||
| throwError "internal error: type is not a field{indentExpr ring.type}"
|
||||
if let some invFn := ring.invFn? then return invFn
|
||||
let expectedInst := mkApp2 (mkConst ``Grind.Field.toInv [ring.u]) ring.type fieldInst
|
||||
let invFn ← mkUnaryFn ring.type ring.u ``Inv ``Inv.inv expectedInst
|
||||
modifyCommRing fun s => { s with invFn? := some invFn }
|
||||
return invFn
|
||||
|
||||
end CommRingFns
|
||||
|
||||
section SemiringFns
|
||||
variable [MonadSemiring m]
|
||||
|
||||
def getAddFn' : m Expr := do
|
||||
let sr ← getSemiring
|
||||
if let some addFn := sr.addFn? then return addFn
|
||||
let expectedInst := mkApp2 (mkConst ``instHAdd [sr.u]) sr.type <| mkApp2 (mkConst ``Grind.Semiring.toAdd [sr.u]) sr.type sr.semiringInst
|
||||
let addFn ← mkBinHomoFn sr.type sr.u ``HAdd ``HAdd.hAdd expectedInst
|
||||
modifySemiring fun s => { s with addFn? := some addFn }
|
||||
return addFn
|
||||
|
||||
def getMulFn' : m Expr := do
|
||||
let sr ← getSemiring
|
||||
if let some mulFn := sr.mulFn? then return mulFn
|
||||
let expectedInst := mkApp2 (mkConst ``instHMul [sr.u]) sr.type <| mkApp2 (mkConst ``Grind.Semiring.toMul [sr.u]) sr.type sr.semiringInst
|
||||
let mulFn ← mkBinHomoFn sr.type sr.u ``HMul ``HMul.hMul expectedInst
|
||||
modifySemiring fun s => { s with mulFn? := some mulFn }
|
||||
return mulFn
|
||||
|
||||
def getPowFn' : m Expr := do
|
||||
let sr ← getSemiring
|
||||
if let some powFn := sr.powFn? then return powFn
|
||||
let powFn ← mkPowFn sr.u sr.type sr.semiringInst
|
||||
modifySemiring fun s => { s with powFn? := some powFn }
|
||||
return powFn
|
||||
|
||||
def getNatCastFn' : m Expr := do
|
||||
let sr ← getSemiring
|
||||
if let some natCastFn := sr.natCastFn? then return natCastFn
|
||||
let natCastFn ← mkNatCastFn sr.u sr.type sr.semiringInst
|
||||
modifySemiring fun s => { s with natCastFn? := some natCastFn }
|
||||
return natCastFn
|
||||
|
||||
end SemiringFns
|
||||
|
||||
end Lean.Meta.Sym.Arith
|
||||
@@ -1,24 +1,23 @@
|
||||
/-
|
||||
Copyright (c) 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Types
|
||||
public import Lean.Meta.Sym.Arith.Types
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
namespace Lean.Meta.Sym.Arith
|
||||
|
||||
class MonadCanon (m : Type → Type) where
|
||||
/--
|
||||
Helper function for removing dependency on `GoalM`.
|
||||
In `RingM` and `SemiringM`, this is just `sharedCommon (← canon e)`
|
||||
When printing counterexamples, we are at `MetaM`, and this is just the identity function.
|
||||
Canonicalize an expression (types, instances, support arguments).
|
||||
In `SymM`, this is `Sym.canon`. In `PP.M` (diagnostics), this is the identity.
|
||||
-/
|
||||
canonExpr : Expr → m Expr
|
||||
/--
|
||||
Helper function for removing dependency on `GoalM`. During search we
|
||||
want to track the instances synthesized by `grind`, and this is `Grind.synthInstance`.
|
||||
Synthesize an instance, returning `none` on failure.
|
||||
In `SymM`, this is `Sym.synthInstance?`. In `PP.M`, this is `Meta.synthInstance?`.
|
||||
-/
|
||||
synthInstance? : Expr → m (Option Expr)
|
||||
|
||||
@@ -31,7 +30,7 @@ instance (m n) [MonadLift m n] [MonadCanon m] : MonadCanon n where
|
||||
|
||||
def MonadCanon.synthInstance [Monad m] [MonadError m] [MonadCanon m] (type : Expr) : m Expr := do
|
||||
let some inst ← synthInstance? type
|
||||
| throwError "`grind` failed to find instance{indentExpr type}"
|
||||
| throwError "failed to find instance{indentExpr type}"
|
||||
return inst
|
||||
|
||||
end Lean.Meta.Grind.Arith.CommRing
|
||||
end Lean.Meta.Sym.Arith
|
||||
39
src/Lean/Meta/Sym/Arith/MonadRing.lean
Normal file
39
src/Lean/Meta/Sym/Arith/MonadRing.lean
Normal file
@@ -0,0 +1,39 @@
|
||||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.Arith.MonadCanon
|
||||
public section
|
||||
namespace Lean.Meta.Sym.Arith
|
||||
|
||||
class MonadRing (m : Type → Type) where
|
||||
getRing : m Ring
|
||||
modifyRing : (Ring → Ring) → m Unit
|
||||
|
||||
export MonadRing (getRing modifyRing)
|
||||
|
||||
@[always_inline]
|
||||
instance (m n) [MonadLift m n] [MonadRing m] : MonadRing n where
|
||||
getRing := liftM (getRing : m Ring)
|
||||
modifyRing f := liftM (modifyRing f : m Unit)
|
||||
|
||||
class MonadCommRing (m : Type → Type) where
|
||||
getCommRing : m CommRing
|
||||
modifyCommRing : (CommRing → CommRing) → m Unit
|
||||
|
||||
export MonadCommRing (getCommRing modifyCommRing)
|
||||
|
||||
@[always_inline]
|
||||
instance (m n) [MonadLift m n] [MonadCommRing m] : MonadCommRing n where
|
||||
getCommRing := liftM (getCommRing : m CommRing)
|
||||
modifyCommRing f := liftM (modifyCommRing f : m Unit)
|
||||
|
||||
@[always_inline]
|
||||
instance (m) [Monad m] [MonadCommRing m] : MonadRing m where
|
||||
getRing := return (← getCommRing).toRing
|
||||
modifyRing f := modifyCommRing fun s => { s with toRing := f s.toRing }
|
||||
|
||||
end Lean.Meta.Sym.Arith
|
||||
39
src/Lean/Meta/Sym/Arith/MonadSemiring.lean
Normal file
39
src/Lean/Meta/Sym/Arith/MonadSemiring.lean
Normal file
@@ -0,0 +1,39 @@
|
||||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.Arith.MonadCanon
|
||||
public section
|
||||
namespace Lean.Meta.Sym.Arith
|
||||
|
||||
class MonadSemiring (m : Type → Type) where
|
||||
getSemiring : m Semiring
|
||||
modifySemiring : (Semiring → Semiring) → m Unit
|
||||
|
||||
export MonadSemiring (getSemiring modifySemiring)
|
||||
|
||||
@[always_inline]
|
||||
instance (m n) [MonadLift m n] [MonadSemiring m] : MonadSemiring n where
|
||||
getSemiring := liftM (getSemiring : m Semiring)
|
||||
modifySemiring f := liftM (modifySemiring f : m Unit)
|
||||
|
||||
class MonadCommSemiring (m : Type → Type) where
|
||||
getCommSemiring : m CommSemiring
|
||||
modifyCommSemiring : (CommSemiring → CommSemiring) → m Unit
|
||||
|
||||
export MonadCommSemiring (getCommSemiring modifyCommSemiring)
|
||||
|
||||
@[always_inline]
|
||||
instance (m n) [MonadLift m n] [MonadCommSemiring m] : MonadCommSemiring n where
|
||||
getCommSemiring := liftM (getCommSemiring : m CommSemiring)
|
||||
modifyCommSemiring f := liftM (modifyCommSemiring f : m Unit)
|
||||
|
||||
@[always_inline]
|
||||
instance (m) [Monad m] [MonadCommSemiring m] : MonadSemiring m where
|
||||
getSemiring := return (← getCommSemiring).toSemiring
|
||||
modifySemiring f := modifyCommSemiring fun s => { s with toSemiring := f s.toSemiring }
|
||||
|
||||
end Lean.Meta.Sym.Arith
|
||||
32
src/Lean/Meta/Sym/Arith/MonadVar.lean
Normal file
32
src/Lean/Meta/Sym/Arith/MonadVar.lean
Normal file
@@ -0,0 +1,32 @@
|
||||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.Arith.Types
|
||||
public section
|
||||
namespace Lean.Meta.Sym.Arith
|
||||
|
||||
/-- Read a variable's Lean expression by index. Used by `DenoteExpr` and diagnostics (PP). -/
|
||||
class MonadGetVar (m : Type → Type) where
|
||||
getVar : Var → m Expr
|
||||
|
||||
export MonadGetVar (getVar)
|
||||
|
||||
@[always_inline]
|
||||
instance (m n) [MonadLift m n] [MonadGetVar m] : MonadGetVar n where
|
||||
getVar x := liftM (getVar x : m Expr)
|
||||
|
||||
/-- Create or lookup a variable for a Lean expression. Used by reification. -/
|
||||
class MonadMkVar (m : Type → Type) where
|
||||
mkVar : Expr → m Var
|
||||
|
||||
export MonadMkVar (mkVar)
|
||||
|
||||
@[always_inline]
|
||||
instance (m n) [MonadLift m n] [MonadMkVar m] : MonadMkVar n where
|
||||
mkVar e := liftM (mkVar e : m Var)
|
||||
|
||||
end Lean.Meta.Sym.Arith
|
||||
205
src/Lean/Meta/Sym/Arith/Reify.lean
Normal file
205
src/Lean/Meta/Sym/Arith/Reify.lean
Normal file
@@ -0,0 +1,205 @@
|
||||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.Arith.Functions
|
||||
public import Lean.Meta.Sym.Arith.MonadVar
|
||||
public import Lean.Meta.Sym.LitValues
|
||||
public section
|
||||
namespace Lean.Meta.Sym.Arith
|
||||
open Sym.Arith (MonadCanon)
|
||||
|
||||
/-!
|
||||
# Reification of arithmetic expressions
|
||||
|
||||
Converts Lean expressions into `CommRing.Expr` (ring) or `CommSemiring.Expr`
|
||||
(semiring) for reflection-based normalization.
|
||||
|
||||
Instance validation uses pointer equality (`isSameExpr`) against cached function
|
||||
expressions from `Functions.lean`.
|
||||
|
||||
## Differences from grind's `Reify.lean`
|
||||
|
||||
- Uses `MonadMkVar` for variable creation instead of grind's `internalize` + `mkVarCore`
|
||||
- Uses `Sym.getNatValue?`/`Sym.getIntValue?` (pure) instead of `MetaM` versions
|
||||
- No `MonadSetTermId` — term-to-ring-id tracking is grind-specific
|
||||
-/
|
||||
|
||||
section RingReify
|
||||
|
||||
variable [MonadLiftT SymM m] [MonadLiftT MetaM m] [MonadError m] [Monad m] [MonadCanon m] [MonadRing m] [MonadMkVar m]
|
||||
|
||||
def isAddInst (inst : Expr) : m Bool :=
|
||||
return isSameExpr (← getAddFn).appArg! inst
|
||||
def isMulInst (inst : Expr) : m Bool :=
|
||||
return isSameExpr (← getMulFn).appArg! inst
|
||||
def isSubInst (inst : Expr) : m Bool :=
|
||||
return isSameExpr (← getSubFn).appArg! inst
|
||||
def isNegInst (inst : Expr) : m Bool :=
|
||||
return isSameExpr (← getNegFn).appArg! inst
|
||||
def isPowInst (inst : Expr) : m Bool :=
|
||||
return isSameExpr (← getPowFn).appArg! inst
|
||||
def isIntCastInst (inst : Expr) : m Bool :=
|
||||
return isSameExpr (← getIntCastFn).appArg! inst
|
||||
def isNatCastInst (inst : Expr) : m Bool :=
|
||||
return isSameExpr (← getNatCastFn).appArg! inst
|
||||
|
||||
private def reportRingAppIssue [MonadLiftT SymM m] (e : Expr) : m Unit := do
|
||||
reportIssue! "ring term with unexpected instance{indentExpr e}"
|
||||
|
||||
/--
|
||||
Converts a Lean expression `e` into a `RingExpr`.
|
||||
|
||||
If `skipVar` is `true`, returns `none` if `e` is not an interpreted ring term
|
||||
(used for equalities/disequalities). If `false`, treats non-interpreted terms
|
||||
as variables (used for inequalities).
|
||||
-/
|
||||
partial def reifyRing? (e : Expr) (skipVar : Bool := true) : m (Option RingExpr) := do
|
||||
let toVar (e : Expr) : m RingExpr := do
|
||||
return .var (← mkVar e)
|
||||
let asVar (e : Expr) : m RingExpr := do
|
||||
reportRingAppIssue e
|
||||
return .var (← mkVar e)
|
||||
let rec go (e : Expr) : m RingExpr := do
|
||||
match_expr e with
|
||||
| HAdd.hAdd _ _ _ i a b =>
|
||||
if (← isAddInst i) then return .add (← go a) (← go b) else asVar e
|
||||
| HMul.hMul _ _ _ i a b =>
|
||||
if (← isMulInst i) then return .mul (← go a) (← go b) else asVar e
|
||||
| HSub.hSub _ _ _ i a b =>
|
||||
if (← isSubInst i) then return .sub (← go a) (← go b) else asVar e
|
||||
| HPow.hPow _ _ _ i a b =>
|
||||
let some k := Sym.getNatValue? b |>.run | toVar e
|
||||
if (← isPowInst i) then return .pow (← go a) k else asVar e
|
||||
| Neg.neg _ i a =>
|
||||
if (← isNegInst i) then return .neg (← go a) else asVar e
|
||||
| IntCast.intCast _ i a =>
|
||||
if (← isIntCastInst i) then
|
||||
let some k := Sym.getIntValue? a |>.run | toVar e
|
||||
return .intCast k
|
||||
else
|
||||
asVar e
|
||||
| NatCast.natCast _ i a =>
|
||||
if (← isNatCastInst i) then
|
||||
let some k := Sym.getNatValue? a |>.run | toVar e
|
||||
return .natCast k
|
||||
else
|
||||
asVar e
|
||||
| OfNat.ofNat _ n _ =>
|
||||
/-
|
||||
**Note**: We extract `n` directly as a raw nat literal. The grind version uses `MetaM`'s
|
||||
`getNatValue?` which handles multiple encodings (raw literals, nested `OfNat`, etc.).
|
||||
In `SymM`, we assume terms have been canonicalized by `Sym.canon` before reification,
|
||||
so `OfNat.ofNat _ n _` always has a raw nat literal at position 1.
|
||||
-/
|
||||
let .lit (.natVal k) := n | toVar e
|
||||
return .num k
|
||||
| BitVec.ofNat _ n =>
|
||||
let .lit (.natVal k) := n | toVar e
|
||||
return .num k
|
||||
| _ => toVar e
|
||||
let toTopVar (e : Expr) : m (Option RingExpr) := do
|
||||
if skipVar then
|
||||
return none
|
||||
else
|
||||
return some (← toVar e)
|
||||
let asTopVar (e : Expr) : m (Option RingExpr) := do
|
||||
reportRingAppIssue e
|
||||
toTopVar e
|
||||
match_expr e with
|
||||
| HAdd.hAdd _ _ _ i a b =>
|
||||
if (← isAddInst i) then return some (.add (← go a) (← go b)) else asTopVar e
|
||||
| HMul.hMul _ _ _ i a b =>
|
||||
if (← isMulInst i) then return some (.mul (← go a) (← go b)) else asTopVar e
|
||||
| HSub.hSub _ _ _ i a b =>
|
||||
if (← isSubInst i) then return some (.sub (← go a) (← go b)) else asTopVar e
|
||||
| HPow.hPow _ _ _ i a b =>
|
||||
let some k := Sym.getNatValue? b |>.run | asTopVar e
|
||||
if (← isPowInst i) then return some (.pow (← go a) k) else asTopVar e
|
||||
| Neg.neg _ i a =>
|
||||
if (← isNegInst i) then return some (.neg (← go a)) else asTopVar e
|
||||
| IntCast.intCast _ i a =>
|
||||
if (← isIntCastInst i) then
|
||||
let some k := Sym.getIntValue? a |>.run | toTopVar e
|
||||
return some (.intCast k)
|
||||
else
|
||||
asTopVar e
|
||||
| NatCast.natCast _ i a =>
|
||||
if (← isNatCastInst i) then
|
||||
let some k := Sym.getNatValue? a |>.run | toTopVar e
|
||||
return some (.natCast k)
|
||||
else
|
||||
asTopVar e
|
||||
| OfNat.ofNat _ n _ =>
|
||||
let .lit (.natVal k) := n | asTopVar e
|
||||
return some (.num k)
|
||||
| _ => toTopVar e
|
||||
|
||||
end RingReify
|
||||
|
||||
section SemiringReify
|
||||
|
||||
variable [MonadLiftT SymM m] [MonadLiftT MetaM m] [MonadError m] [Monad m] [MonadCanon m] [MonadSemiring m] [MonadMkVar m]
|
||||
|
||||
private def reportSemiringAppIssue [MonadLiftT SymM m] (e : Expr) : m Unit := do
|
||||
reportIssue! "semiring term with unexpected instance{indentExpr e}"
|
||||
|
||||
/--
|
||||
Converts a Lean expression `e` into a `SemiringExpr`.
|
||||
Only recognizes `add`, `mul`, `pow`, `natCast`, and numerals (no `sub`, `neg`, `intCast`).
|
||||
-/
|
||||
partial def reifySemiring? (e : Expr) : m (Option SemiringExpr) := do
|
||||
let toVar (e : Expr) : m SemiringExpr := do
|
||||
return .var (← mkVar e)
|
||||
let asVar (e : Expr) : m SemiringExpr := do
|
||||
reportSemiringAppIssue e
|
||||
return .var (← mkVar e)
|
||||
let rec go (e : Expr) : m SemiringExpr := do
|
||||
match_expr e with
|
||||
| HAdd.hAdd _ _ _ i a b =>
|
||||
if isSameExpr (← getAddFn').appArg! i then return .add (← go a) (← go b) else asVar e
|
||||
| HMul.hMul _ _ _ i a b =>
|
||||
if isSameExpr (← getMulFn').appArg! i then return .mul (← go a) (← go b) else asVar e
|
||||
| HPow.hPow _ _ _ i a b =>
|
||||
let some k := Sym.getNatValue? b |>.run | toVar e
|
||||
if isSameExpr (← getPowFn').appArg! i then return .pow (← go a) k else asVar e
|
||||
| NatCast.natCast _ i a =>
|
||||
if isSameExpr (← getNatCastFn').appArg! i then
|
||||
let some k := Sym.getNatValue? a |>.run | toVar e
|
||||
return .num k
|
||||
else
|
||||
asVar e
|
||||
| OfNat.ofNat _ n _ =>
|
||||
let .lit (.natVal k) := n | toVar e
|
||||
return .num k
|
||||
| _ => toVar e
|
||||
let toTopVar (e : Expr) : m (Option SemiringExpr) := do
|
||||
return some (← toVar e)
|
||||
let asTopVar (e : Expr) : m (Option SemiringExpr) := do
|
||||
reportSemiringAppIssue e
|
||||
toTopVar e
|
||||
match_expr e with
|
||||
| HAdd.hAdd _ _ _ i a b =>
|
||||
if isSameExpr (← getAddFn').appArg! i then return some (.add (← go a) (← go b)) else asTopVar e
|
||||
| HMul.hMul _ _ _ i a b =>
|
||||
if isSameExpr (← getMulFn').appArg! i then return some (.mul (← go a) (← go b)) else asTopVar e
|
||||
| HPow.hPow _ _ _ i a b =>
|
||||
let some k := Sym.getNatValue? b |>.run | return none
|
||||
if isSameExpr (← getPowFn').appArg! i then return some (.pow (← go a) k) else asTopVar e
|
||||
| NatCast.natCast _ i a =>
|
||||
if isSameExpr (← getNatCastFn').appArg! i then
|
||||
let some k := Sym.getNatValue? a |>.run | toTopVar e
|
||||
return some (.num k)
|
||||
else
|
||||
asTopVar e
|
||||
| OfNat.ofNat _ n _ =>
|
||||
let .lit (.natVal k) := n | asTopVar e
|
||||
return some (.num k)
|
||||
| _ => toTopVar e
|
||||
|
||||
end SemiringReify
|
||||
|
||||
end Lean.Meta.Sym.Arith
|
||||
@@ -8,7 +8,7 @@ prelude
|
||||
public import Init.Grind.Ring.CommSemiringAdapter
|
||||
public import Lean.ToExpr
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
namespace Lean.Meta.Sym.Arith
|
||||
open Grind.CommRing
|
||||
/-!
|
||||
`ToExpr` instances for `CommRing.Poly` types.
|
||||
@@ -57,4 +57,4 @@ instance : ToExpr CommRing.Expr where
|
||||
toExpr := ofRingExpr
|
||||
toTypeExpr := mkConst ``CommRing.Expr
|
||||
|
||||
end Lean.Meta.Grind.Arith.CommRing
|
||||
end Lean.Meta.Sym.Arith
|
||||
137
src/Lean/Meta/Sym/Arith/Types.lean
Normal file
137
src/Lean/Meta/Sym/Arith/Types.lean
Normal file
@@ -0,0 +1,137 @@
|
||||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Init.Grind.Ring.CommSemiringAdapter
|
||||
public import Lean.Meta.Sym.SymM
|
||||
public section
|
||||
|
||||
namespace Lean.Meta.Sym.Arith
|
||||
export Lean.Grind.CommRing (Var Power Mon Poly)
|
||||
abbrev RingExpr := Grind.CommRing.Expr
|
||||
/-
|
||||
**Note**: recall that we use ring expressions to represent semiring expressions,
|
||||
and ignore non-applicable constructors.
|
||||
-/
|
||||
abbrev SemiringExpr := Grind.CommRing.Expr
|
||||
|
||||
/-- Classification state for a type with a `Semiring` instance. -/
|
||||
structure Semiring where
|
||||
id : Nat
|
||||
type : Expr
|
||||
/-- Cached `getDecLevel type` -/
|
||||
u : Level
|
||||
/-- `Semiring` instance for `type` -/
|
||||
semiringInst : Expr
|
||||
addFn? : Option Expr := none
|
||||
mulFn? : Option Expr := none
|
||||
powFn? : Option Expr := none
|
||||
natCastFn? : Option Expr := none
|
||||
deriving Inhabited
|
||||
|
||||
/-- Classification state for a type with a `Ring` instance. -/
|
||||
structure Ring where
|
||||
id : Nat
|
||||
type : Expr
|
||||
/-- Cached `getDecLevel type` -/
|
||||
u : Level
|
||||
/-- `Ring` instance for `type` -/
|
||||
ringInst : Expr
|
||||
/-- `Semiring` instance for `type` -/
|
||||
semiringInst : Expr
|
||||
/-- `IsCharP` instance for `type` if available. -/
|
||||
charInst? : Option (Expr × Nat)
|
||||
addFn? : Option Expr := none
|
||||
mulFn? : Option Expr := none
|
||||
subFn? : Option Expr := none
|
||||
negFn? : Option Expr := none
|
||||
powFn? : Option Expr := none
|
||||
intCastFn? : Option Expr := none
|
||||
natCastFn? : Option Expr := none
|
||||
one? : Option Expr := none
|
||||
deriving Inhabited
|
||||
|
||||
/-- Classification state for a type with a `CommRing` instance. -/
|
||||
structure CommRing extends Ring where
|
||||
/-- Inverse function if `fieldInst?` is `some inst` -/
|
||||
invFn? : Option Expr := none
|
||||
/--
|
||||
If this is a `OfSemiring.Q α` ring, this field contains the
|
||||
`semiringId` for `α`.
|
||||
-/
|
||||
semiringId? : Option Nat
|
||||
/-- `CommSemiring` instance for `type` -/
|
||||
commSemiringInst : Expr
|
||||
/-- `CommRing` instance for `type` -/
|
||||
commRingInst : Expr
|
||||
/-- `NoNatZeroDivisors` instance for `type` if available. -/
|
||||
noZeroDivInst? : Option Expr
|
||||
/-- `Field` instance for `type` if available. -/
|
||||
fieldInst? : Option Expr
|
||||
deriving Inhabited
|
||||
|
||||
/--
|
||||
Classification state for a type with a `CommSemiring` instance.
|
||||
Recall that `CommSemiring` types are normalized using the `OfSemiring.Q` envelope.
|
||||
-/
|
||||
structure CommSemiring extends Semiring where
|
||||
/-- Id of the envelope ring `OfSemiring.Q type` -/
|
||||
ringId : Nat
|
||||
/-- `CommSemiring` instance for `type` -/
|
||||
commSemiringInst : Expr
|
||||
/-- `AddRightCancel` instance for `type` if available. -/
|
||||
addRightCancelInst? : Option (Option Expr) := none
|
||||
toQFn? : Option Expr := none
|
||||
deriving Inhabited
|
||||
|
||||
/-- Result of classifying a type's algebraic structure. -/
|
||||
inductive ClassifyResult where
|
||||
| commRing (id : Nat)
|
||||
| nonCommRing (id : Nat)
|
||||
| commSemiring (id : Nat)
|
||||
| nonCommSemiring (id : Nat)
|
||||
| /-- No algebraic structure found. -/ none
|
||||
deriving Inhabited
|
||||
|
||||
/-- Arith type classification state, stored as a `SymExtension`. -/
|
||||
structure State where
|
||||
/-- Exponent threshold for `HPow` evaluation. -/
|
||||
exp : Nat := 8
|
||||
/-- Commutative rings. -/
|
||||
rings : Array CommRing := {}
|
||||
/-- Commutative semirings. -/
|
||||
semirings : Array CommSemiring := {}
|
||||
/-- Non-commutative rings. -/
|
||||
ncRings : Array Ring := {}
|
||||
/-- Non-commutative semirings. -/
|
||||
ncSemirings : Array Semiring := {}
|
||||
/-- Mapping from types to their classification result. Caches failures as `.none`. -/
|
||||
typeClassify : PHashMap ExprPtr ClassifyResult := {}
|
||||
deriving Inhabited
|
||||
|
||||
builtin_initialize arithExt : SymExtension State ← registerSymExtension (return {})
|
||||
|
||||
def getArithState : SymM State :=
|
||||
arithExt.getState
|
||||
|
||||
@[inline] def modifyArithState (f : State → State) : SymM Unit :=
|
||||
arithExt.modifyState f
|
||||
|
||||
/-- Get the exponent threshold. -/
|
||||
def getExpThreshold : SymM Nat :=
|
||||
return (← getArithState).exp
|
||||
|
||||
/-- Set the exponent threshold. -/
|
||||
def setExpThreshold (exp : Nat) : SymM Unit :=
|
||||
modifyArithState fun s => { s with exp }
|
||||
|
||||
/-- Run `k` with a temporary exponent threshold. -/
|
||||
def withExpThreshold (exp : Nat) (k : SymM α) : SymM α := do
|
||||
let oldExp := (← getArithState).exp
|
||||
setExpThreshold exp
|
||||
try k finally setExpThreshold oldExp
|
||||
|
||||
end Lean.Meta.Sym.Arith
|
||||
@@ -5,11 +5,9 @@ Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Poly
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Types
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingId
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Internalize
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.ToExpr
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingM
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.SemiringM
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.NonCommRingM
|
||||
@@ -21,8 +19,6 @@ public import Lean.Meta.Tactic.Grind.Arith.CommRing.Proof
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.DenoteExpr
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Inv
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.PP
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.VarRename
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.MonadCanon
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.MonadRing
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.MonadSemiring
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Action
|
||||
|
||||
@@ -8,6 +8,7 @@ prelude
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Functions
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
open Sym.Arith
|
||||
/-!
|
||||
Helper functions for converting reified terms back into their denotations.
|
||||
-/
|
||||
|
||||
@@ -8,6 +8,7 @@ prelude
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.MonadRing
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
open Sym.Arith
|
||||
variable [MonadLiftT MetaM m] [MonadError m] [Monad m] [MonadCanon m]
|
||||
|
||||
section
|
||||
|
||||
@@ -6,7 +6,7 @@ Authors: Leonardo de Moura
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingM
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.Poly
|
||||
import Lean.Meta.Sym.Arith.Poly
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
|
||||
|
||||
@@ -5,7 +5,8 @@ Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.MonadCanon
|
||||
public import Lean.Meta.Sym.Arith.MonadCanon
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Types
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
|
||||
|
||||
@@ -5,7 +5,8 @@ Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.MonadCanon
|
||||
public import Lean.Meta.Sym.Arith.MonadCanon
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Types
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ prelude
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingM
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
|
||||
open Sym.Arith
|
||||
structure NonCommRingM.Context where
|
||||
ringId : Nat
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ prelude
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.SemiringM
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
open Sym.Arith (MonadCanon)
|
||||
|
||||
structure NonCommSemiringM.Context where
|
||||
semiringId : Nat
|
||||
|
||||
@@ -10,6 +10,7 @@ import Lean.Meta.Tactic.Grind.Arith.CommRing.DenoteExpr
|
||||
import Init.Omega
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
open Sym.Arith
|
||||
|
||||
private abbrev M := StateT CommRing MetaM
|
||||
|
||||
|
||||
@@ -12,12 +12,14 @@ import Lean.Data.RArray
|
||||
import Lean.Meta.Tactic.Grind.Diseq
|
||||
import Lean.Meta.Tactic.Grind.ProofUtil
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.DenoteExpr
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.ToExpr
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.VarRename
|
||||
import Lean.Meta.Sym.Arith.ToExpr
|
||||
import Lean.Meta.Sym.Arith.VarRename
|
||||
import Init.Data.Nat.Order
|
||||
import Init.Data.Order.Lemmas
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
open Sym.Arith (MonadCanon)
|
||||
|
||||
/--
|
||||
Returns a context of type `RArray α` containing the variables `vars` where
|
||||
`α` is the type of the ring.
|
||||
|
||||
@@ -9,6 +9,7 @@ public import Lean.Meta.Tactic.Grind.Arith.CommRing.NonCommRingM
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.NonCommSemiringM
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
open Sym.Arith (MonadCanon)
|
||||
|
||||
variable [MonadLiftT MetaM m] [MonadError m] [Monad m] [MonadCanon m] [MonadRing m]
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ public import Lean.Meta.Tactic.Grind.SynthInstance
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.MonadRing
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
open Sym.Arith
|
||||
|
||||
def checkMaxSteps : GoalM Bool := do
|
||||
return (← get').steps >= (← getConfig).ringSteps
|
||||
|
||||
@@ -6,7 +6,7 @@ Authors: Leonardo de Moura
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingM
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Poly
|
||||
public import Lean.Meta.Sym.Arith.Poly
|
||||
import Lean.Meta.Tactic.Grind.Arith.EvalNum
|
||||
import Init.Data.Nat.Linear
|
||||
public section
|
||||
|
||||
@@ -11,6 +11,7 @@ import Lean.Meta.Tactic.Grind.Arith.CommRing.DenoteExpr
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Functions
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
open Sym.Arith
|
||||
|
||||
structure SemiringM.Context where
|
||||
semiringId : Nat
|
||||
|
||||
@@ -7,7 +7,7 @@ module
|
||||
prelude
|
||||
public import Init.Grind.Ring.CommSemiringAdapter
|
||||
public import Lean.Meta.Tactic.Grind.Types
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.Poly
|
||||
import Lean.Meta.Sym.Arith.Poly
|
||||
public section
|
||||
|
||||
namespace Lean.Meta.Grind.Arith.CommRing
|
||||
|
||||
@@ -14,8 +14,8 @@ import Lean.Meta.Tactic.Grind.Arith.Cutsat.CommRing
|
||||
import Lean.Meta.Tactic.Grind.Arith.Cutsat.Util
|
||||
import Lean.Meta.Tactic.Grind.Arith.Cutsat.Nat
|
||||
import Lean.Meta.Tactic.Grind.Arith.Cutsat.VarRename
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.VarRename
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.ToExpr
|
||||
import Lean.Meta.Sym.Arith.VarRename
|
||||
import Lean.Meta.Sym.Arith.ToExpr
|
||||
import Init.Data.Nat.Order
|
||||
import Init.Data.Order.Lemmas
|
||||
public section
|
||||
|
||||
@@ -9,6 +9,7 @@ public import Lean.Meta.Tactic.Grind.Arith.Linear.Types
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.RingM
|
||||
public section
|
||||
namespace Lean.Meta.Grind.Arith.Linear
|
||||
open Sym.Arith (MonadCanon)
|
||||
|
||||
def get' : GoalM State := do
|
||||
linearExt.getState
|
||||
|
||||
@@ -11,8 +11,8 @@ import Lean.Data.RArray
|
||||
import Lean.Meta.Tactic.Grind.Arith.Linear.ToExpr
|
||||
import Lean.Meta.Tactic.Grind.Diseq
|
||||
import Lean.Meta.Tactic.Grind.ProofUtil
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.VarRename
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.ToExpr
|
||||
import Lean.Meta.Sym.Arith.VarRename
|
||||
import Lean.Meta.Sym.Arith.ToExpr
|
||||
import Lean.Meta.Tactic.Grind.Arith.Linear.VarRename
|
||||
public import Lean.Meta.Tactic.Grind.Arith.Linear.DenoteExpr
|
||||
public import Lean.Meta.Tactic.Grind.Arith.Linear.OfNatModule
|
||||
|
||||
@@ -97,6 +97,8 @@ def mkCnstrNorm0 (s : Struct) (ringInst : Expr) (kind : CnstrKind) (lhs rhs : Ex
|
||||
| .le => mkLeNorm0 s ringInst lhs rhs
|
||||
| .lt => mkLtNorm0 s ringInst lhs rhs
|
||||
|
||||
open Sym.Arith (MonadCanon)
|
||||
|
||||
/--
|
||||
Returns `rel lhs (rhs + 0)`
|
||||
-/
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
module
|
||||
public import Init.Grind.Ring.CommSolver
|
||||
public import Lean.Meta.Tactic.Grind.Arith.CommRing.Poly
|
||||
public import Lean.Meta.Sym.Arith.Poly
|
||||
open Lean.Grind.CommRing
|
||||
|
||||
def w : Var := 0
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
module
|
||||
import Lean.Meta.Tactic.Grind.Arith.CommRing.Poly
|
||||
import Lean.Meta.Sym.Arith.Poly
|
||||
open Lean.Grind.CommRing
|
||||
|
||||
def w : Expr := .var 0
|
||||
|
||||
138
tests/elab/sym_arith_classify.lean
Normal file
138
tests/elab/sym_arith_classify.lean
Normal file
@@ -0,0 +1,138 @@
|
||||
import Lean
|
||||
|
||||
/-!
|
||||
# Tests for `Sym.Arith.Classify`, `Sym.Arith.EvalNum`, and `Sym.Arith.Functions`
|
||||
-/
|
||||
|
||||
open Lean Meta Sym Arith
|
||||
|
||||
/-- Extract the value of a definition by name. -/
|
||||
def getDefValue (n : Name) : MetaM Expr := do
|
||||
let some (.defnInfo info) := (← getEnv).find? n
|
||||
| throwError "expected definition: {n}"
|
||||
return info.value
|
||||
|
||||
/-! ## Classification tests -/
|
||||
|
||||
deriving instance Repr for ClassifyResult
|
||||
|
||||
/-- info: Lean.Meta.Sym.Arith.ClassifyResult.commRing 0 -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{repr (← classify? (mkConst ``Int))}"
|
||||
|
||||
/-- info: Lean.Meta.Sym.Arith.ClassifyResult.commSemiring 0 -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{repr (← classify? (mkConst ``Nat))}"
|
||||
|
||||
/-- info: Lean.Meta.Sym.Arith.ClassifyResult.none -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{repr (← classify? (mkConst ``Bool))}"
|
||||
|
||||
-- Classifying the same type twice should return cached result with same id
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
let .commRing id1 ← classify? (mkConst ``Int) | unreachable!
|
||||
let .commRing id2 ← classify? (mkConst ``Int) | unreachable!
|
||||
logInfo m!"{id1 == id2}"
|
||||
|
||||
/--
|
||||
info: Lean.Meta.Sym.Arith.ClassifyResult.commRing 0
|
||||
---
|
||||
info: Lean.Meta.Sym.Arith.ClassifyResult.commSemiring 0
|
||||
---
|
||||
info: Lean.Meta.Sym.Arith.ClassifyResult.commRing 2
|
||||
---
|
||||
info: Lean.Meta.Sym.Arith.ClassifyResult.commRing 1
|
||||
-/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
let int ← shareCommon (mkConst ``Int)
|
||||
let nat ← shareCommon (mkConst ``Nat)
|
||||
let rat ← shareCommon (mkConst ``Rat)
|
||||
logInfo m!"{repr (← classify? int)}"
|
||||
logInfo m!"{repr (← classify? nat)}"
|
||||
logInfo m!"{repr (← classify? rat)}"
|
||||
let inst ← Sym.synthInstance (mkApp (mkConst ``Grind.Semiring [0]) nat)
|
||||
let ofSemiring ← shareCommon (← Sym.canon <| mkApp2 (mkConst ``Grind.Ring.OfSemiring.Q [0]) nat inst)
|
||||
logInfo m!"{repr (← classify? ofSemiring)}"
|
||||
|
||||
/-! ## EvalNum tests -/
|
||||
|
||||
def natZero : Nat := 0
|
||||
def natSucc3 : Nat := Nat.succ (Nat.succ (Nat.succ 0))
|
||||
def natSeven : Nat := 7
|
||||
def natAdd : Nat := 2 + 3
|
||||
def natMul : Nat := 2 * 3
|
||||
def natPow : Nat := 2 ^ 3
|
||||
def natBigPow : Nat := 2 ^ 100
|
||||
def natPow10 : Nat := 2 ^ 10
|
||||
|
||||
/-- info: some (0) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{← evalNat? (← getDefValue ``natZero)}"
|
||||
|
||||
/-- info: some (3) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{← evalNat? (← getDefValue ``natSucc3)}"
|
||||
|
||||
/-- info: some (7) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{← evalNat? (← getDefValue ``natSeven)}"
|
||||
|
||||
/-- info: some (5) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{← evalNat? (← getDefValue ``natAdd)}"
|
||||
|
||||
/-- info: some (6) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{← evalNat? (← getDefValue ``natMul)}"
|
||||
|
||||
/-- info: some (8) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{← evalNat? (← getDefValue ``natPow)}"
|
||||
|
||||
/-! ## Exp threshold tests -/
|
||||
|
||||
-- 2 ^ 100 should fail with default exp threshold (8)
|
||||
/-- info: none -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{← evalNat? (← getDefValue ``natBigPow)}"
|
||||
|
||||
-- 2 ^ 10 succeeds with exp threshold raised to 20
|
||||
/-- info: some (1024) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
withExpThreshold 20 do
|
||||
logInfo m!"{← evalNat? (← getDefValue ``natPow10)}"
|
||||
|
||||
/-! ## Int EvalNum tests -/
|
||||
|
||||
def intNeg : Int := -5
|
||||
def intAdd : Int := 3 + (-2)
|
||||
def intMul : Int := (-3) * 4
|
||||
|
||||
/-- info: some (-5) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{← evalInt? (← getDefValue ``intNeg)}"
|
||||
|
||||
/-- info: some (1) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{← evalInt? (← getDefValue ``intAdd)}"
|
||||
|
||||
/-- info: some (-12) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do
|
||||
logInfo m!"{← evalInt? (← getDefValue ``intMul)}"
|
||||
172
tests/elab/sym_arith_reify.lean
Normal file
172
tests/elab/sym_arith_reify.lean
Normal file
@@ -0,0 +1,172 @@
|
||||
import Lean
|
||||
|
||||
/-!
|
||||
# Tests for `Sym.Arith.Reify`
|
||||
-/
|
||||
|
||||
open Lean Meta Sym Arith
|
||||
|
||||
/-- Extract the value of a definition by name. -/
|
||||
def getDefValue (n : Name) : MetaM Expr := do
|
||||
let some (.defnInfo info) := (← getEnv).find? n
|
||||
| throwError "expected definition: {n}"
|
||||
return info.value
|
||||
|
||||
/-!
|
||||
## Setup: a simple monad for testing reification
|
||||
-/
|
||||
|
||||
structure TestState where
|
||||
ring : CommRing
|
||||
vars : Array Expr := {}
|
||||
varMap : PHashMap ExprPtr Var := {}
|
||||
|
||||
abbrev TestM := StateRefT TestState SymM
|
||||
|
||||
instance : MonadCanon TestM where
|
||||
canonExpr e := Sym.canon e
|
||||
synthInstance? e := Sym.synthInstance? e
|
||||
|
||||
instance : MonadCommRing TestM where
|
||||
getCommRing := return (← get).ring
|
||||
modifyCommRing f := modify fun s => { s with ring := f s.ring }
|
||||
|
||||
instance : MonadMkVar TestM where
|
||||
mkVar e := do
|
||||
if let some v := (← get).varMap.find? { expr := e } then
|
||||
return v
|
||||
let v := (← get).vars.size
|
||||
modify fun s => { s with
|
||||
vars := s.vars.push e
|
||||
varMap := s.varMap.insert { expr := e } v
|
||||
}
|
||||
return v
|
||||
|
||||
instance : MonadGetVar TestM where
|
||||
getVar x := return (← get).vars[x]!
|
||||
|
||||
/-- Run a `TestM` on `Int`'s `CommRing`, canonicalizing `e` first. -/
|
||||
def reifyIntExpr (n : Name) (skipVar := true) : TestM (Option RingExpr) := do
|
||||
let e ← canonExpr (← getDefValue n)
|
||||
reifyRing? e (skipVar := skipVar)
|
||||
|
||||
def runTestOnInt (x : TestM α) : SymM α := do
|
||||
let .commRing id ← classify? (mkConst ``Int) | throwError "Int is not a CommRing"
|
||||
let ring := (← getArithState).rings[id]!
|
||||
x |>.run' { ring }
|
||||
|
||||
/-! ## Reify ring tests on Int -/
|
||||
|
||||
deriving instance Repr for Lean.Grind.CommRing.Expr
|
||||
|
||||
def intAdd : Int := 2 + 3
|
||||
def intMulAdd : Int := 2 * 3 + 1
|
||||
def intNeg : Int := -5
|
||||
def intPow : Int := 2 ^ 3
|
||||
def intSub : Int := 7 - 2
|
||||
|
||||
/-- info: some (Lean.Grind.CommRing.Expr.add (Lean.Grind.CommRing.Expr.num 2) (Lean.Grind.CommRing.Expr.num 3)) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do runTestOnInt do
|
||||
logInfo m!"{repr (← reifyIntExpr ``intAdd)}"
|
||||
|
||||
/--
|
||||
info: some (Lean.Grind.CommRing.Expr.add
|
||||
(Lean.Grind.CommRing.Expr.mul (Lean.Grind.CommRing.Expr.num 2) (Lean.Grind.CommRing.Expr.num 3))
|
||||
(Lean.Grind.CommRing.Expr.num 1))
|
||||
-/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do runTestOnInt do
|
||||
logInfo m!"{repr (← reifyIntExpr ``intMulAdd)}"
|
||||
|
||||
/-- info: some (Lean.Grind.CommRing.Expr.neg (Lean.Grind.CommRing.Expr.num 5)) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do runTestOnInt do
|
||||
logInfo m!"{repr (← reifyIntExpr ``intNeg)}"
|
||||
|
||||
/-- info: some (Lean.Grind.CommRing.Expr.pow (Lean.Grind.CommRing.Expr.num 2) 3) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do runTestOnInt do
|
||||
logInfo m!"{repr (← reifyIntExpr ``intPow)}"
|
||||
|
||||
/--
|
||||
info: some (Lean.Grind.CommRing.Expr.sub (Lean.Grind.CommRing.Expr.num 7) (Lean.Grind.CommRing.Expr.num 2))
|
||||
-/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do runTestOnInt do
|
||||
logInfo m!"{repr (← reifyIntExpr ``intSub)}"
|
||||
|
||||
-- skipVar test: a non-arithmetic term returns none with skipVar=true
|
||||
/-- info: none -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do runTestOnInt do
|
||||
let a ← mkFreshExprMVar (mkConst ``Int)
|
||||
logInfo m!"{repr (← reifyRing? a)}"
|
||||
|
||||
-- skipVar=false: a non-arithmetic term becomes a variable
|
||||
/-- info: some (Lean.Grind.CommRing.Expr.var 0) -/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do runTestOnInt do
|
||||
let a ← mkFreshExprMVar (mkConst ``Int)
|
||||
logInfo m!"{repr (← reifyRing? a (skipVar := false))}"
|
||||
|
||||
opaque a : Int
|
||||
opaque b : Int
|
||||
opaque c : Int
|
||||
def e := (a + b*2) - (c*a + a*(3*b + c))
|
||||
|
||||
/--
|
||||
info: some (Lean.Grind.CommRing.Expr.sub
|
||||
(Lean.Grind.CommRing.Expr.add
|
||||
(Lean.Grind.CommRing.Expr.var 0)
|
||||
(Lean.Grind.CommRing.Expr.mul (Lean.Grind.CommRing.Expr.var 1) (Lean.Grind.CommRing.Expr.num 2)))
|
||||
(Lean.Grind.CommRing.Expr.add
|
||||
(Lean.Grind.CommRing.Expr.mul (Lean.Grind.CommRing.Expr.var 2) (Lean.Grind.CommRing.Expr.var 0))
|
||||
(Lean.Grind.CommRing.Expr.mul
|
||||
(Lean.Grind.CommRing.Expr.var 0)
|
||||
(Lean.Grind.CommRing.Expr.add
|
||||
(Lean.Grind.CommRing.Expr.mul (Lean.Grind.CommRing.Expr.num 3) (Lean.Grind.CommRing.Expr.var 1))
|
||||
(Lean.Grind.CommRing.Expr.var 2)))))
|
||||
---
|
||||
info: #[a, b, c]
|
||||
-/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do runTestOnInt do
|
||||
logInfo m!"{repr (← reifyIntExpr ``e)}"
|
||||
logInfo (← get).vars
|
||||
|
||||
/-! ## Roundtrip tests: reify then denote -/
|
||||
|
||||
/-- Reify an expression, denote it back, and check they're definitionally equal. -/
|
||||
def roundtrip (n : Name) : TestM Unit := do
|
||||
let orig ← canonExpr (← getDefValue n)
|
||||
let some re ← reifyRing? orig (skipVar := false) | throwError "reify failed"
|
||||
let vars := (← get).vars
|
||||
let denoted ← denoteRingExpr vars re
|
||||
let denoted ← canonExpr denoted
|
||||
unless (← isDefEq orig denoted) do
|
||||
logInfo m!"MISMATCH for {n}:\n orig: {orig}\n denoted: {denoted}"
|
||||
return
|
||||
logInfo m!"roundtrip OK: {n}: {denoted}"
|
||||
|
||||
/--
|
||||
info: roundtrip OK: intAdd: 2 + 3
|
||||
---
|
||||
info: roundtrip OK: intMulAdd: 2 * 3 + 1
|
||||
---
|
||||
info: roundtrip OK: intNeg: -5
|
||||
---
|
||||
info: roundtrip OK: intPow: 2 ^ 3
|
||||
---
|
||||
info: roundtrip OK: intSub: 7 - 2
|
||||
---
|
||||
info: roundtrip OK: e: a + b * 2 - (c * a + a * (3 * b + c))
|
||||
-/
|
||||
#guard_msgs in
|
||||
run_meta SymM.run do runTestOnInt do
|
||||
roundtrip ``intAdd
|
||||
roundtrip ``intMulAdd
|
||||
roundtrip ``intNeg
|
||||
roundtrip ``intPow
|
||||
roundtrip ``intSub
|
||||
roundtrip ``e
|
||||
Reference in New Issue
Block a user