Compare commits

...

1 Commits

Author SHA1 Message Date
Leonardo de Moura
d82844af61 fix: fold raw Nat literals at dsimp
closes #2916

Remark: this PR also renames `Expr.natLit?` ==> `Expr.rawNatLit?`.
Motivation: consistent naming convention: `Expr.isRawNatLit`.
2024-03-06 10:09:06 -08:00
6 changed files with 115 additions and 20 deletions

View File

@@ -924,7 +924,7 @@ def isRawNatLit : Expr → Bool
| lit (Literal.natVal _) => true
| _ => false
def natLit? : Expr Option Nat
def rawNatLit? : Expr Option Nat
| lit (Literal.natVal v) => v
| _ => none

View File

@@ -33,7 +33,7 @@ partial def reduce (e : Expr) (explicitOnly skipTypes skipProofs := true) : Meta
else
args args.modifyM i visit
if f.isConstOf ``Nat.succ && args.size == 1 && args[0]!.isRawNatLit then
return mkRawNatLit (args[0]!.natLit?.get! + 1)
return mkRawNatLit (args[0]!.rawNatLit?.get! + 1)
else
return mkAppN f args
| Expr.lam .. => lambdaTelescope e fun xs b => do mkLambdaFVars xs ( visit b)

View File

@@ -35,6 +35,21 @@ def Config.updateArith (c : Config) : CoreM Config := do
def isOfNatNatLit (e : Expr) : Bool :=
e.isAppOfArity ``OfNat.ofNat 3 && e.appFn!.appArg!.isRawNatLit
/--
If `e` is a raw Nat literal and `OfNat.ofNat` is not in the list of declarations to unfold,
return an `OfNat.ofNat`-application.
-/
def foldRawNatLit (e : Expr) : SimpM Expr := do
match e.rawNatLit? with
| some n =>
/- If `OfNat.ofNat` is marked to be unfolded, we do not pack orphan nat literals as `OfNat.ofNat` applications
to avoid non-termination. See issue #788. -/
if ( readThe Simp.Context).isDeclToUnfold ``OfNat.ofNat then
return e
else
return toExpr n
| none => return e
private def reduceProjFn? (e : Expr) : SimpM (Option Expr) := do
matchConst e.getAppFn (fun _ => pure none) fun cinfo _ => do
match ( getProjectionFnInfo? cinfo.name) with
@@ -179,7 +194,7 @@ private def reduceStep (e : Expr) : SimpM Expr := do
trace[Meta.Tactic.simp.rewrite] "unfold {mkConst e.getAppFn.constName!}, {e} ==> {e'}"
recordSimpTheorem (.decl e.getAppFn.constName!)
return e'
| none => return e
| none => foldRawNatLit e
private partial def reduce (e : Expr) : SimpM Expr := withIncRecDepth do
let e' reduceStep e
@@ -233,17 +248,6 @@ def withNewLemmas {α} (xs : Array Expr) (f : SimpM α) : SimpM α := do
else
f
def simpLit (e : Expr) : SimpM Result := do
match e.natLit? with
| some n =>
/- If `OfNat.ofNat` is marked to be unfolded, we do not pack orphan nat literals as `OfNat.ofNat` applications
to avoid non-termination. See issue #788. -/
if ( readThe Simp.Context).isDeclToUnfold ``OfNat.ofNat then
return { expr := e }
else
return { expr := ( mkNumeral (mkConst ``Nat) n) }
| none => return { expr := e }
def simpProj (e : Expr) : SimpM Result := do
match ( reduceProj? e) with
| some e => return { expr := e }
@@ -406,13 +410,27 @@ private def dsimpReduce : DSimproc := fun e => do
eNew reduceFVar ( getConfig) ( getSimpTheorems) eNew
if eNew != e then return .visit eNew else return .done e
/--
Auliliary `dsimproc` for not visiting `OfNat.ofNat` application subterms.
This is the `dsimp` equivalent of the approach used at `visitApp`.
Recall that we fold orphan raw Nat literals.
-/
private def doNotVisitOfNat : DSimproc := fun e => do
if isOfNatNatLit e then
if ( readThe Simp.Context).isDeclToUnfold ``OfNat.ofNat then
return .continue e
else
return .done e
else
return .continue e
@[export lean_dsimp]
private partial def dsimpImpl (e : Expr) : SimpM Expr := do
let cfg getConfig
unless cfg.dsimp do
return e
let m getMethods
let pre := m.dpre
let pre := m.dpre >> doNotVisitOfNat
let post := m.dpost >> dsimpReduce
transform (usedLetOnly := cfg.zeta) e (pre := pre) (post := post)
@@ -533,7 +551,7 @@ def congr (e : Expr) : SimpM Result := do
def simpApp (e : Expr) : SimpM Result := do
if isOfNatNatLit e then
-- Recall that we expand "orphan" kernel nat literals `n` into `ofNat n`
-- Recall that we expand "orphan" kernel Nat literals `n` into `OfNat.ofNat n`
return { expr := e }
else
congr e
@@ -549,7 +567,7 @@ def simpStep (e : Expr) : SimpM Result := do
| .const .. => simpConst e
| .bvar .. => unreachable!
| .sort .. => return { expr := e }
| .lit .. => simpLit e
| .lit .. => return { expr := e }
| .mvar .. => return { expr := ( instantiateMVars e) }
| .fvar .. => return { expr := ( reduceFVar ( getConfig) ( getSimpTheorems) e) }

77
tests/lean/run/2916.lean Normal file
View File

@@ -0,0 +1,77 @@
set_option pp.coercions false -- Show `OfNat.ofNat` when present for clarity
/--
warning: declaration uses 'sorry'
---
info: x : Nat
⊢ OfNat.ofNat 2 = x
-/
#guard_msgs in
example : nat_lit 2 = x := by
simp only
trace_state
sorry
/--
warning: declaration uses 'sorry'
---
info: x : Nat
⊢ OfNat.ofNat 2 = x
-/
#guard_msgs in
example : nat_lit 2 = x := by
dsimp only -- dsimp made no progress
trace_state
sorry
/--
warning: declaration uses 'sorry'
---
info: α : Nat → Type
f : (n : Nat) → α n
x : α (OfNat.ofNat 2)
⊢ f (OfNat.ofNat 2) = x
-/
#guard_msgs in
example (α : Nat Type) (f : (n : Nat) α n) (x : α 2) : f (nat_lit 2) = x := by
simp only
trace_state
sorry
/--
info: x : Nat
f : Nat → Nat
h : f (OfNat.ofNat 2) = x
⊢ f (OfNat.ofNat 2) = x
---
info: x : Nat
f : Nat → Nat
h : f (OfNat.ofNat 2) = x
⊢ f 2 = x
-/
#guard_msgs in
example (f : Nat Nat) (h : f 2 = x) : f 2 = x := by
trace_state
simp [OfNat.ofNat]
trace_state
assumption
/--
warning: declaration uses 'sorry'
---
info: α : Nat → Type
f : (n : Nat) → α n
x : α (OfNat.ofNat 2)
⊢ f (OfNat.ofNat 2) = x
---
info: α : Nat → Type
f : (n : Nat) → α n
x : α (OfNat.ofNat 2)
⊢ f 2 = x
-/
#guard_msgs in
example (α : Nat Type) (f : (n : Nat) α n) (x : α 2) : f 2 = x := by
trace_state
simp [OfNat.ofNat]
trace_state
sorry

View File

@@ -100,8 +100,8 @@ def extractXY : Lean.Expr → Lean.MetaM Coords
let sizeArgs := Lean.Expr.getAppArgs e'
let x Lean.Meta.whnf sizeArgs[0]!
let y Lean.Meta.whnf sizeArgs[1]!
let numCols := (Lean.Expr.natLit? x).get!
let numRows := (Lean.Expr.natLit? y).get!
let numCols := (Lean.Expr.rawNatLit? x).get!
let numRows := (Lean.Expr.rawNatLit? y).get!
return Coords.mk numCols numRows
partial def extractWallList : Lean.Expr Lean.MetaM (List Coords)

View File

@@ -676,7 +676,7 @@ check t;
(match t.arrayLit? with
| some (_, xs) => do
checkM $ pure $ xs.length == 2;
(match (xs.get! 0).natLit?, (xs.get! 1).natLit? with
(match (xs.get! 0).rawNatLit?, (xs.get! 1).rawNatLit? with
| some 1, some 2 => pure ()
| _, _ => throwError "nat lits expected")
| none => throwError "array lit expected")