Compare commits

...

3 Commits

Author SHA1 Message Date
Kim Morrison
b90fc5b7b5 semireducible 2024-09-23 12:21:02 +10:00
Kim Morrison
eeae29c126 fix test 2024-09-23 11:53:15 +10:00
Kim Morrison
2106a013c3 first try 2024-09-23 11:51:06 +10:00
5 changed files with 114 additions and 48 deletions

View File

@@ -162,19 +162,16 @@ instance : Inhabited (Array α) where
@[simp] def isEmpty (a : Array α) : Bool :=
a.size = 0
-- TODO(Leo): cleanup
@[specialize]
def isEqvAux (a b : Array α) (hsz : a.size = b.size) (p : α α Bool) (i : Nat) : Bool :=
if h : i < a.size then
have : i < b.size := hsz h
p a[i] b[i] && isEqvAux a b hsz p (i+1)
else
true
decreasing_by simp_wf; decreasing_trivial_pre_omega
def isEqvAux (a b : Array α) (hsz : a.size = b.size) (p : α α Bool) :
(i : Nat) (_ : i a.size), Bool
| 0, _ => true
| i+1, h =>
p a[i] (b[i]'(hsz h)) && isEqvAux a b hsz p i (Nat.le_trans (Nat.le_add_right i 1) h)
@[inline] def isEqv (a b : Array α) (p : α α Bool) : Bool :=
if h : a.size = b.size then
isEqvAux a b h p 0
isEqvAux a b h p a.size (Nat.le_refl a.size)
else
false
@@ -188,9 +185,10 @@ ofFn f = #[f 0, f 1, ... , f(n - 1)]
``` -/
def ofFn {n} (f : Fin n α) : Array α := go 0 (mkEmpty n) where
/-- Auxiliary for `ofFn`. `ofFn.go f i acc = acc ++ #[f i, ..., f(n - 1)]` -/
@[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
go (i : Nat) (acc : Array α) : Array α :=
if h : i < n then go (i+1) (acc.push (f i, h)) else acc
decreasing_by simp_wf; decreasing_trivial_pre_omega
decreasing_by simp_wf; decreasing_trivial_pre_omega
/-- The array `#[0, 1, ..., n - 1]`. -/
def range (n : Nat) : Array Nat :=
@@ -389,11 +387,12 @@ unsafe def mapMUnsafe {α : Type u} {β : Type v} {m : Type v → Type w} [Monad
def mapM {α : Type u} {β : Type v} {m : Type v Type w} [Monad m] (f : α m β) (as : Array α) : m (Array β) :=
-- Note: we cannot use `foldlM` here for the reference implementation because this calls
-- `bind` and `pure` too many times. (We are not assuming `m` is a `LawfulMonad`)
let rec map (i : Nat) (r : Array β) : m (Array β) := do
if hlt : i < as.size then
map (i+1) (r.push ( f as[i]))
else
pure r
let rec @[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
map (i : Nat) (r : Array β) : m (Array β) := do
if hlt : i < as.size then
map (i+1) (r.push ( f as[i]))
else
pure r
decreasing_by simp_wf; decreasing_trivial_pre_omega
map 0 (mkEmpty as.size)
@@ -457,7 +456,8 @@ unsafe def anyMUnsafe {α : Type u} {m : Type → Type w} [Monad m] (p : α
@[implemented_by anyMUnsafe]
def anyM {α : Type u} {m : Type Type w} [Monad m] (p : α m Bool) (as : Array α) (start := 0) (stop := as.size) : m Bool :=
let any (stop : Nat) (h : stop as.size) :=
let rec loop (j : Nat) : m Bool := do
let rec @[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
loop (j : Nat) : m Bool := do
if hlt : j < stop then
have : j < as.size := Nat.lt_of_lt_of_le hlt h
if ( p as[j]) then
@@ -547,7 +547,8 @@ def findRev? {α : Type} (as : Array α) (p : α → Bool) : Option α :=
@[inline]
def findIdx? {α : Type u} (as : Array α) (p : α Bool) : Option Nat :=
let rec loop (j : Nat) :=
let rec @[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
loop (j : Nat) :=
if h : j < as.size then
if p as[j] then some j else loop (j + 1)
else none
@@ -557,6 +558,7 @@ def findIdx? {α : Type u} (as : Array α) (p : α → Bool) : Option Nat :=
def getIdx? [BEq α] (a : Array α) (v : α) : Option Nat :=
a.findIdx? fun a => a == v
@[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
def indexOfAux [BEq α] (a : Array α) (v : α) (i : Nat) : Option (Fin a.size) :=
if h : i < a.size then
let idx : Fin a.size := i, h;
@@ -678,6 +680,7 @@ where
else
as
@[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
def popWhile (p : α Bool) (as : Array α) : Array α :=
if h : as.size > 0 then
if p (as.get as.size - 1, Nat.sub_lt h (by decide)) then
@@ -689,7 +692,8 @@ def popWhile (p : α → Bool) (as : Array α) : Array α :=
decreasing_by simp_wf; decreasing_trivial_pre_omega
def takeWhile (p : α Bool) (as : Array α) : Array α :=
let rec go (i : Nat) (r : Array α) : Array α :=
let rec @[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
go (i : Nat) (r : Array α) : Array α :=
if h : i < as.size then
let a := as.get i, h
if p a then
@@ -705,6 +709,7 @@ def takeWhile (p : α → Bool) (as : Array α) : Array α :=
This function takes worst case O(n) time because
it has to backshift all elements at positions greater than `i`.-/
@[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
def feraseIdx (a : Array α) (i : Fin a.size) : Array α :=
if h : i.val + 1 < a.size then
let a' := a.swap i.val + 1, h i
@@ -739,7 +744,8 @@ def erase [BEq α] (as : Array α) (a : α) : Array α :=
/-- Insert element `a` at position `i`. -/
@[inline] def insertAt (as : Array α) (i : Fin (as.size + 1)) (a : α) : Array α :=
let rec loop (as : Array α) (j : Fin as.size) :=
let rec @[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
loop (as : Array α) (j : Fin as.size) :=
if i.1 < j then
let j' := j-1, Nat.lt_of_le_of_lt (Nat.pred_le _) j.2
let as := as.swap j' j
@@ -757,6 +763,7 @@ def insertAt! (as : Array α) (i : Nat) (a : α) : Array α :=
insertAt as i, Nat.lt_succ_of_le h a
else panic! "invalid index"
@[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
def isPrefixOfAux [BEq α] (as bs : Array α) (hle : as.size bs.size) (i : Nat) : Bool :=
if h : i < as.size then
let a := as[i]
@@ -778,7 +785,8 @@ def isPrefixOf [BEq α] (as bs : Array α) : Bool :=
else
false
@[specialize] def zipWithAux (f : α β γ) (as : Array α) (bs : Array β) (i : Nat) (cs : Array γ) : Array γ :=
@[semireducible, specialize] -- This is otherwise irreducible because it uses well-founded recursion.
def zipWithAux (f : α β γ) (as : Array α) (bs : Array β) (i : Nat) (cs : Array γ) : Array γ :=
if h : i < as.size then
let a := as[i]
if h : i < bs.size then
@@ -814,6 +822,7 @@ private def allDiffAuxAux [BEq α] (as : Array α) (a : α) : forall (i : Nat),
have : i < as.size := Nat.lt_trans (Nat.lt_succ_self _) h;
a != as[i] && allDiffAuxAux as a i this
@[semireducible] -- This is otherwise irreducible because it uses well-founded recursion.
private def allDiffAux [BEq α] (as : Array α) (i : Nat) : Bool :=
if h : i < as.size then
allDiffAuxAux as as[i] i h && allDiffAux as (i+1)

View File

@@ -9,37 +9,37 @@ import Init.ByCases
namespace Array
theorem eq_of_isEqvAux [DecidableEq α] (a b : Array α) (hsz : a.size = b.size) (i : Nat) (hi : i a.size) (heqv : Array.isEqvAux a b hsz (fun x y => x = y) i) (j : Nat) (low : i j) (high : j < a.size) : a[j] = b[j]'(hsz high) := by
by_cases h : i < a.size
· unfold Array.isEqvAux at heqv
simp [h] at heqv
have hind := eq_of_isEqvAux a b hsz (i+1) (Nat.succ_le_of_lt h) heqv.2
by_cases heq : i = j
· subst heq; exact heqv.1
· exact hind j (Nat.succ_le_of_lt (Nat.lt_of_le_of_ne low heq)) high
· have heq : i = a.size := Nat.le_antisymm hi (Nat.ge_of_not_lt h)
subst heq
exact absurd (Nat.lt_of_lt_of_le high low) (Nat.lt_irrefl j)
termination_by a.size - i
decreasing_by decreasing_trivial_pre_omega
theorem eq_of_isEqvAux
[DecidableEq α] (a b : Array α) (hsz : a.size = b.size) (i : Nat) (hi : i a.size)
(heqv : Array.isEqvAux a b hsz (fun x y => x = y) i hi)
(j : Nat) (hj : j < i) : a[j]'(Nat.lt_of_lt_of_le hj hi) = b[j]'(Nat.lt_of_lt_of_le hj (hsz hi)) := by
induction i with
| zero => contradiction
| succ i ih =>
simp only [Array.isEqvAux, Bool.and_eq_true, decide_eq_true_eq] at heqv
by_cases hj' : j < i
next =>
exact ih _ heqv.right hj'
next =>
replace hj' : j = i := Nat.eq_of_le_of_lt_succ (Nat.not_lt.mp hj') hj
subst hj'
exact heqv.left
theorem eq_of_isEqv [DecidableEq α] (a b : Array α) : Array.isEqv a b (fun x y => x = y) a = b := by
simp [Array.isEqv]
split
next hsz =>
intro h
have aux := eq_of_isEqvAux a b hsz 0 (Nat.zero_le ..) h
exact ext a b hsz fun i h _ => aux i (Nat.zero_le ..) _
have aux := eq_of_isEqvAux a b hsz a.size (Nat.le_refl ..) h
exact ext a b hsz fun i h _ => aux i h
next => intro; contradiction
theorem isEqvAux_self [DecidableEq α] (a : Array α) (i : Nat) : Array.isEqvAux a a rfl (fun x y => x = y) i = true := by
unfold Array.isEqvAux
split
next h => simp [h, isEqvAux_self a (i+1)]
next h => simp [h]
termination_by a.size - i
decreasing_by decreasing_trivial_pre_omega
theorem isEqvAux_self [DecidableEq α] (a : Array α) (i : Nat) (h : i a.size) :
Array.isEqvAux a a rfl (fun x y => x = y) i h = true := by
induction i with
| zero => simp [Array.isEqvAux]
| succ i ih =>
simp_all only [isEqvAux, decide_True, Bool.and_self]
theorem isEqv_self [DecidableEq α] (a : Array α) : Array.isEqv a a (fun x y => x = y) = true := by
simp [isEqv, isEqvAux_self]

View File

@@ -156,9 +156,11 @@ private def processVar (idStx : Syntax) : M Syntax := do
modify fun s => { s with vars := s.vars.push idStx, found := s.found.insert id }
return idStx
private def samePatternsVariables (startingAt : Nat) (s₁ s₂ : State) : Bool :=
if h : s₁.vars.size = s₂.vars.size then
Array.isEqvAux s₁.vars s.vars h (.==.) startingAt
private def samePatternsVariables (startingAt : Nat) (s₁ s₂ : State) : Bool := Id.run do
if h : s₁.vars.size = s₂.vars.size then
for h₂ : i in [startingAt:s.vars.size] do
if s₁.vars[i] != s₂.vars[i]'(by obtain _, y := h₂; simp_all) then return false
true
else
false

View File

@@ -0,0 +1,45 @@
/-!
Because `Array.isEqvAux` was defined by well-founded recursion, this used to fail with
```
tactic 'decide' failed for proposition
#[0, 1] = #[0, 1]
since its 'Decidable' instance
#[0, 1].instDecidableEq #[0, 1]
did not reduce to 'isTrue' or 'isFalse'.
After unfolding the instances 'instDecidableEqNat', 'Array.instDecidableEq' and 'Nat.decEq', reduction got stuck at the 'Decidable' instance
match h : #[0, 1].isEqv #[0, 1] fun a b => decide (a = b) with
| true => isTrue ⋯
| false => isFalse ⋯
```
-/
example : #[0, 1] = #[0, 1] := by decide
/-!
There are other `Array` functions that use well-founded recursion,
which we've marked as `@[semireducible]`. We test that `decide` can unfold them here.
-/
example : Array.ofFn (id : Fin 2 Fin 2) = #[0, 1] := by decide
example : #[0, 1].map (· + 1) = #[1, 2] := by decide
example : #[0, 1].any (· % 2 = 0) := by decide
example : #[0, 1].findIdx? (· % 2 = 0) = some 0 := by decide
example : #[0, 1, 2].popWhile (· % 2 = 0) = #[0, 1] := by decide
example : #[0, 1, 2].takeWhile (· % 2 = 0) = #[0] := by decide
example : #[0, 1, 2].feraseIdx 1, by decide = #[0, 2] := by decide
example : #[0, 1, 2].insertAt 1, by decide 3 = #[0, 3, 1, 2] := by decide
example : #[0, 1, 2].isPrefixOf #[0, 1, 2, 3] = true := by decide
example : #[0, 1, 2].zipWith #[3, 4, 5] (· + ·) = #[3, 5, 7] := by decide
example : #[0, 1, 2].allDiff = true := by decide

View File

@@ -1,7 +1,17 @@
theorem eq_of_isEqvAux [DecidableEq α] (a b : Array α) (hsz : a.size = b.size) (i : Nat) (hi : i a.size) (heqv : Array.isEqvAux a b hsz (fun x y => x = y) i) : (j : Nat) (hl : i j) (hj : j < a.size), a.get j, hj = b.get j, hsz hj := by
@[specialize]
def isEqvAux (a b : Array α) (hsz : a.size = b.size) (p : α α Bool) (i : Nat) : Bool :=
if h : i < a.size then
have : i < b.size := hsz h
p a[i] b[i] && isEqvAux a b hsz p (i+1)
else
true
termination_by a.size - i
decreasing_by simp_wf; decreasing_trivial_pre_omega
theorem eq_of_isEqvAux [DecidableEq α] (a b : Array α) (hsz : a.size = b.size) (i : Nat) (hi : i a.size) (heqv : isEqvAux a b hsz (fun x y => x = y) i) : (j : Nat) (hl : i j) (hj : j < a.size), a.get j, hj = b.get j, hsz hj := by
intro j low high
by_cases h : i < a.size
· unfold Array.isEqvAux at heqv
· unfold isEqvAux at heqv
simp [h] at heqv
have hind := eq_of_isEqvAux a b hsz (i+1) (Nat.succ_le_of_lt h) heqv.2
by_cases heq : i = j