Compare commits

...

3 Commits

Author SHA1 Message Date
Kim Morrison
dbde3c560d merge master 2024-10-22 12:01:12 +11:00
Kim Morrison
0fa905c277 feat: related Array.forIn and List.forIn 2024-10-22 12:00:00 +11:00
Kim Morrison
71122696a1 feat: rename Array.shrink to take, and relate to List.take (#5796) 2024-10-21 23:35:32 +00:00
14 changed files with 134 additions and 39 deletions

View File

@@ -241,12 +241,15 @@ def swapAt! (a : Array α) (i : Nat) (v : α) : α × Array α :=
have : Inhabited (α × Array α) := (v, a)
panic! ("index " ++ toString i ++ " out of bounds")
def shrink (a : Array α) (n : Nat) : Array α :=
/-- `take a n` returns the first `n` elements of `a`. -/
def take (a : Array α) (n : Nat) : Array α :=
let rec loop
| 0, a => a
| n+1, a => loop n a.pop
loop (a.size - n) a
@[deprecated take (since := "2024-10-22")] abbrev shrink := @take
@[inline]
unsafe def modifyMUnsafe [Monad m] (a : Array α) (i : Nat) (f : α m α) : m (Array α) := do
if h : i < a.size then

View File

@@ -101,6 +101,25 @@ We prefer to pull `List.toArray` outwards.
@[simp] theorem back_toArray [Inhabited α] (l : List α) : l.toArray.back = l.getLast! := by
simp only [back, size_toArray, Array.get!_eq_getElem!, getElem!_toArray, getLast!_eq_getElem!]
@[simp] theorem forIn_loop_toArray [Monad m] (l : List α) (f : α β m (ForInStep β)) (i : Nat)
(h : i l.length) (b : β) :
Array.forIn.loop l.toArray f i h b = (l.drop (l.length - i)).forIn b f := by
induction i generalizing l b with
| zero => simp [Array.forIn.loop]
| succ i ih =>
simp only [Array.forIn.loop, size_toArray, getElem_toArray, ih, forIn_eq_forIn]
rw [Nat.sub_add_eq, List.drop_sub_one (by omega), List.getElem?_eq_getElem (by omega)]
simp only [Option.toList_some, singleton_append, forIn_cons]
have t : l.length - 1 - i = l.length - i - 1 := by omega
simp only [t]
congr
@[simp] theorem forIn_toArray [Monad m] (l : List α) (b : β) (f : α β m (ForInStep β)) :
forIn l.toArray b f = forIn l b f := by
change l.toArray.forIn b f = l.forIn b f
rw [Array.forIn, forIn_loop_toArray]
simp
theorem foldrM_toArray [Monad m] (f : α β m β) (init : β) (l : List α) :
l.toArray.foldrM f init = l.foldrM f init := by
rw [foldrM_eq_reverse_foldlM_toList]
@@ -671,6 +690,40 @@ theorem getElem_range {n : Nat} {x : Nat} (h : x < (Array.range n).size) : (Arra
true_and, Nat.not_lt] at h
rw [List.getElem?_eq_none_iff.2 _, List.getElem?_eq_none_iff.2 (a.toList.length_reverse _)]
/-! ### take -/
@[simp] theorem size_take_loop (a : Array α) (n : Nat) : (take.loop n a).size = a.size - n := by
induction n generalizing a with
| zero => simp [take.loop]
| succ n ih =>
simp [take.loop, ih]
omega
@[simp] theorem getElem_take_loop (a : Array α) (n : Nat) (i : Nat) (h : i < (take.loop n a).size) :
(take.loop n a)[i] = a[i]'(by simp at h; omega) := by
induction n generalizing a i with
| zero => simp [take.loop]
| succ n ih =>
simp [take.loop, ih]
@[simp] theorem size_take (a : Array α) (n : Nat) : (a.take n).size = min n a.size := by
simp [take]
omega
@[simp] theorem getElem_take (a : Array α) (n : Nat) (i : Nat) (h : i < (a.take n).size) :
(a.take n)[i] = a[i]'(by simp at h; omega) := by
simp [take]
@[simp] theorem toList_take (a : Array α) (n : Nat) : (a.take n).toList = a.toList.take n := by
apply List.ext_getElem <;> simp
/-! ### forIn -/
@[simp] theorem forIn_toList [Monad m] (as : Array α) (b : β) (f : α β m (ForInStep β)) :
forIn as.toList b f = forIn as b f := by
cases as
simp
/-! ### foldl / foldr -/
@[simp] theorem foldlM_loop_empty [Monad m] (f : β α m β) (init : β) (i j : Nat) :
@@ -1339,6 +1392,10 @@ Our goal is to have `simp` "pull `List.toArray` outwards" as much as possible.
apply ext'
simp
@[simp] theorem take_toArray (l : List α) (n : Nat) : l.toArray.take n = (l.take n).toArray := by
apply ext'
simp
@[simp] theorem mapM_toArray [Monad m] [LawfulMonad m] (f : α m β) (l : List α) :
l.toArray.mapM f = List.toArray <$> l.mapM f := by
simp only [ mapM'_eq_mapM, mapM_eq_foldlM]

View File

@@ -187,6 +187,9 @@ theorem take_add (l : List α) (m n : Nat) : l.take (m + n) = l.take m ++ (l.dro
· apply length_take_le
· apply Nat.le_add_right
theorem take_one {l : List α} : l.take 1 = l.head?.toList := by
induction l <;> simp
theorem dropLast_take {n : Nat} {l : List α} (h : n < l.length) :
(l.take n).dropLast = l.take (n - 1) := by
simp only [dropLast_eq_take, length_take, Nat.le_of_lt h, Nat.min_eq_left, take_take, sub_le]
@@ -282,14 +285,14 @@ theorem mem_drop_iff_getElem {l : List α} {a : α} :
· rintro i, hm, rfl
refine i, by simp; omega, by rw [getElem_drop]
theorem head?_drop (l : List α) (n : Nat) :
@[simp] theorem head?_drop (l : List α) (n : Nat) :
(l.drop n).head? = l[n]? := by
rw [head?_eq_getElem?, getElem?_drop, Nat.add_zero]
theorem head_drop {l : List α} {n : Nat} (h : l.drop n []) :
@[simp] theorem head_drop {l : List α} {n : Nat} (h : l.drop n []) :
(l.drop n).head h = l[n]'(by simp_all) := by
have w : n < l.length := length_lt_of_drop_ne_nil h
simpa [getElem?_eq_getElem, h, w, head_eq_iff_head?_eq_some] using head?_drop l n
simp [getElem?_eq_getElem, h, w, head_eq_iff_head?_eq_some]
theorem getLast?_drop {l : List α} : (l.drop n).getLast? = if l.length n then none else l.getLast? := by
rw [getLast?_eq_getElem?, getElem?_drop]
@@ -300,7 +303,7 @@ theorem getLast?_drop {l : List α} : (l.drop n).getLast? = if l.length ≤ n th
congr
omega
theorem getLast_drop {l : List α} (h : l.drop n []) :
@[simp] theorem getLast_drop {l : List α} (h : l.drop n []) :
(l.drop n).getLast h = l.getLast (ne_nil_of_length_pos (by simp at h; omega)) := by
simp only [ne_eq, drop_eq_nil_iff] at h
apply Option.some_inj.1
@@ -449,6 +452,26 @@ theorem reverse_drop {l : List α} {n : Nat} :
rw [w, take_zero, drop_of_length_le, reverse_nil]
omega
theorem take_add_one {l : List α} {n : Nat} :
l.take (n + 1) = l.take n ++ l[n]?.toList := by
simp [take_add, take_one]
theorem drop_eq_getElem?_toList_append {l : List α} {n : Nat} :
l.drop n = l[n]?.toList ++ l.drop (n + 1) := by
induction l generalizing n with
| nil => simp
| cons hd tl ih =>
cases n
· simp
· simp only [drop_succ_cons, getElem?_cons_succ]
rw [ih]
theorem drop_sub_one {l : List α} {n : Nat} (h : 0 < n) :
l.drop (n - 1) = l[n - 1]?.toList ++ l.drop n := by
rw [drop_eq_getElem?_toList_append]
congr
omega
/-! ### findIdx -/
theorem false_of_mem_take_findIdx {xs : List α} {p : α Bool} (h : x xs.take (xs.findIdx p)) :

View File

@@ -46,7 +46,7 @@ partial def withCheckpoint (x : PullM Code) : PullM Code := do
else
return c
let (c, keep) := go toPullSizeSaved ( read).included |>.run #[]
modify fun s => { s with toPull := s.toPull.shrink toPullSizeSaved ++ keep }
modify fun s => { s with toPull := s.toPull.take toPullSizeSaved ++ keep }
return c
def attachToPull (c : Code) : PullM Code := do

View File

@@ -182,7 +182,7 @@ partial def moduleIdent (runtimeOnly : Bool) : Parser := fun input s =>
let s := p input s
match s.error? with
| none => many p input s
| some _ => { pos, error? := none, imports := s.imports.shrink size }
| some _ => { pos, error? := none, imports := s.imports.take size }
@[inline] partial def preludeOpt (k : String) : Parser :=
keywordCore k (fun _ s => s.pushModule `Init false) (fun _ s => s)

View File

@@ -36,8 +36,8 @@ abbrev Assignment.get? (a : Assignment) (x : Var) : Option Rat :=
abbrev Assignment.push (a : Assignment) (v : Rat) : Assignment :=
{ a with val := a.val.push v }
abbrev Assignment.shrink (a : Assignment) (newSize : Nat) : Assignment :=
{ a with val := a.val.shrink newSize }
abbrev Assignment.take (a : Assignment) (newSize : Nat) : Assignment :=
{ a with val := a.val.take newSize }
structure Poly where
val : Array (Int × Var)
@@ -242,7 +242,7 @@ def resolve (s : State) (cl : Cnstr) (cu : Cnstr) : Sum Result State :=
let maxVarIdx := c.lhs.getMaxVar.id
match s with -- Hack: we avoid { s with ... } to make sure we get a destructive update
| { lowers, uppers, int, assignment, } =>
let assignment := assignment.shrink maxVarIdx
let assignment := assignment.take maxVarIdx
if c.lhs.getMaxVarCoeff < 0 then
let lowers := lowers.modify maxVarIdx (·.push c)
Sum.inr { lowers, uppers, int, assignment }

View File

@@ -84,7 +84,7 @@ private def mkNullaryCtor (type : Expr) (nparams : Nat) : MetaM (Option Expr) :=
let .const d lvls := type.getAppFn
| return none
let (some ctor) getFirstCtor d | pure none
return mkAppN (mkConst ctor lvls) (type.getAppArgs.shrink nparams)
return mkAppN (mkConst ctor lvls) (type.getAppArgs.take nparams)
private def getRecRuleFor (recVal : RecursorVal) (major : Expr) : Option RecursorRule :=
match major.getAppFn with
@@ -152,7 +152,7 @@ private def toCtorWhenStructure (inductName : Name) (major : Expr) : MetaM Expr
else
let some ctorName getFirstCtor d | pure major
let ctorInfo getConstInfoCtor ctorName
let params := majorType.getAppArgs.shrink ctorInfo.numParams
let params := majorType.getAppArgs.take ctorInfo.numParams
let mut result := mkAppN (mkConst ctorName us) params
for i in [:ctorInfo.numFields] do
result := mkApp result ( mkProjFn ctorInfo us params i major)

View File

@@ -1305,7 +1305,7 @@ namespace ParserState
def keepTop (s : SyntaxStack) (startStackSize : Nat) : SyntaxStack :=
let node := s.back
s.shrink startStackSize |>.push node
s.take startStackSize |>.push node
def keepNewError (s : ParserState) (oldStackSize : Nat) : ParserState :=
match s with
@@ -1314,13 +1314,13 @@ def keepNewError (s : ParserState) (oldStackSize : Nat) : ParserState :=
def keepPrevError (s : ParserState) (oldStackSize : Nat) (oldStopPos : String.Pos) (oldError : Option Error) (oldLhsPrec : Nat) : ParserState :=
match s with
| ⟨stack, _, _, cache, _, errs⟩ =>
⟨stack.shrink oldStackSize, oldLhsPrec, oldStopPos, cache, oldError, errs⟩
⟨stack.take oldStackSize, oldLhsPrec, oldStopPos, cache, oldError, errs⟩
def mergeErrors (s : ParserState) (oldStackSize : Nat) (oldError : Error) : ParserState :=
match s with
| ⟨stack, lhsPrec, pos, cache, some err, errs⟩ =>
let newError := if oldError == err then err else oldError.merge err
⟨stack.shrink oldStackSize, lhsPrec, pos, cache, some newError, errs⟩
⟨stack.take oldStackSize, lhsPrec, pos, cache, some newError, errs⟩
| other => other
def keepLatest (s : ParserState) (startStackSize : Nat) : ParserState :=
@@ -1363,7 +1363,7 @@ def runLongestMatchParser (left? : Option Syntax) (startLhsPrec : Nat) (p : Pars
s -- success or error with the expected number of nodes
else if s.hasError then
-- error with an unexpected number of nodes.
s.shrinkStack startSize |>.pushSyntax Syntax.missing
s.takeStack startSize |>.pushSyntax Syntax.missing
else
-- parser succeeded with incorrect number of nodes
invalidLongestMatchParser s

View File

@@ -158,8 +158,10 @@ def size (stack : SyntaxStack) : Nat :=
def isEmpty (stack : SyntaxStack) : Bool :=
stack.size == 0
def shrink (stack : SyntaxStack) (n : Nat) : SyntaxStack :=
{ stack with raw := stack.raw.shrink (stack.drop + n) }
def take (stack : SyntaxStack) (n : Nat) : SyntaxStack :=
{ stack with raw := stack.raw.take (stack.drop + n) }
@[deprecated take (since := "2024-10-22")] abbrev shrink := @take
def push (stack : SyntaxStack) (a : Syntax) : SyntaxStack :=
{ stack with raw := stack.raw.push a }
@@ -212,7 +214,7 @@ def stackSize (s : ParserState) : Nat :=
s.stxStack.size
def restore (s : ParserState) (iniStackSz : Nat) (iniPos : String.Pos) : ParserState :=
{ s with stxStack := s.stxStack.shrink iniStackSz, errorMsg := none, pos := iniPos }
{ s with stxStack := s.stxStack.take iniStackSz, errorMsg := none, pos := iniPos }
def setPos (s : ParserState) (pos : String.Pos) : ParserState :=
{ s with pos := pos }
@@ -226,8 +228,10 @@ def pushSyntax (s : ParserState) (n : Syntax) : ParserState :=
def popSyntax (s : ParserState) : ParserState :=
{ s with stxStack := s.stxStack.pop }
def shrinkStack (s : ParserState) (iniStackSz : Nat) : ParserState :=
{ s with stxStack := s.stxStack.shrink iniStackSz }
def takeStack (s : ParserState) (iniStackSz : Nat) : ParserState :=
{ s with stxStack := s.stxStack.take iniStackSz }
@[deprecated takeStack (since := "2024-10-22")] abbrev shrinkStack := @takeStack
def next (s : ParserState) (input : String) (pos : String.Pos) : ParserState :=
{ s with pos := input.next pos }
@@ -250,7 +254,7 @@ def mkNode (s : ParserState) (k : SyntaxNodeKind) (iniStackSz : Nat) : ParserSta
stack, lhsPrec, pos, cache, err, recovered
else
let newNode := Syntax.node SourceInfo.none k (stack.extract iniStackSz stack.size)
let stack := stack.shrink iniStackSz
let stack := stack.take iniStackSz
let stack := stack.push newNode
stack, lhsPrec, pos, cache, err, recovered
@@ -258,7 +262,7 @@ def mkTrailingNode (s : ParserState) (k : SyntaxNodeKind) (iniStackSz : Nat) : P
match s with
| stack, lhsPrec, pos, cache, err, errs =>
let newNode := Syntax.node SourceInfo.none k (stack.extract (iniStackSz - 1) stack.size)
let stack := stack.shrink (iniStackSz - 1)
let stack := stack.take (iniStackSz - 1)
let stack := stack.push newNode
stack, lhsPrec, pos, cache, err, errs
@@ -283,7 +287,7 @@ def mkEOIError (s : ParserState) (expected : List String := []) : ParserState :=
def mkErrorsAt (s : ParserState) (ex : List String) (pos : String.Pos) (initStackSz? : Option Nat := none) : ParserState := Id.run do
let mut s := s.setPos pos
if let some sz := initStackSz? then
s := s.shrinkStack sz
s := s.takeStack sz
s := s.setError { expected := ex }
s.pushSyntax .missing

View File

@@ -398,7 +398,7 @@ mutual
let fType replaceLPsWithVars ( inferType f)
let (mvars, bInfos, resultType) forallMetaBoundedTelescope fType args.size
let rest := args.extract mvars.size args.size
let args := args.shrink mvars.size
let args := args.take mvars.size
-- Unify with the expected type
if ( read).knowsType then tryUnify ( inferType (mkAppN f args)) resultType

View File

@@ -144,7 +144,7 @@ def fold (fn : Array Format → Format) (x : FormatterM Unit) : FormatterM Unit
x
let stack getStack
let f := fn $ stack.extract sp stack.size
setStack $ (stack.shrink sp).push f
setStack $ (stack.take sp).push f
/-- Execute `x` and concatenate generated Format objects. -/
def concat (x : FormatterM Unit) : FormatterM Unit := do

View File

@@ -292,7 +292,7 @@ instance : Append Log := ⟨Log.append⟩
/-- Removes log entries after `pos` (inclusive). -/
@[inline] def dropFrom (log : Log) (pos : Log.Pos) : Log :=
.mk <| log.entries.shrink pos.val
.mk <| log.entries.take pos.val
/-- Takes log entries before `pos` (exclusive). -/
@[inline] def takeFrom (log : Log) (pos : Log.Pos) : Log :=

View File

@@ -5,3 +5,11 @@
#check_simp #[1,2,3,4,5][2]! ~> 3
#check_simp #[1,2,3,4,5][7]! ~> (default : Nat)
#check_simp (#[] : Array Nat)[0]! ~> (default : Nat)
attribute [local simp] Id.run in
#check_simp
(Id.run do
let mut s := 0
for i in [1,2,3,4].toArray do
s := s + i
pure s) ~> 10

View File

@@ -64,7 +64,7 @@ d.errorMsg != none
d.stxStack.size
def ParserData.restore (d : ParserData) (iniStackSz : Nat) (iniPos : Nat) : ParserData :=
{ stxStack := d.stxStack.shrink iniStackSz, errorMsg := none, pos := iniPos, .. d}
{ stxStack := d.stxStack.take iniStackSz, errorMsg := none, pos := iniPos, .. d}
def ParserData.setPos (d : ParserData) (pos : Nat) : ParserData :=
{ pos := pos, .. d }
@@ -75,8 +75,8 @@ def ParserData.setCache (d : ParserData) (cache : ParserCache) : ParserData :=
def ParserData.pushSyntax (d : ParserData) (n : Syntax) : ParserData :=
{ stxStack := d.stxStack.push n, .. d }
def ParserData.shrinkStack (d : ParserData) (iniStackSz : Nat) : ParserData :=
{ stxStack := d.stxStack.shrink iniStackSz, .. d }
def ParserData.takeStack (d : ParserData) (iniStackSz : Nat) : ParserData :=
{ stxStack := d.stxStack.take iniStackSz, .. d }
def ParserData.next (d : ParserData) (s : String) (pos : Nat) : ParserData :=
{ pos := s.next pos, .. d }
@@ -114,7 +114,7 @@ match d with
d
else
let newNode := Syntax.node k (stack.extract iniStackSz stack.size) [] in
let stack := stack.shrink iniStackSz in
let stack := stack.take iniStackSz in
let stack := stack.push newNode in
stack, pos, cache, err
@@ -144,7 +144,7 @@ match d with
let iniSz := d.stackSize in
let iniPos := d.pos in
match p s d with
| stack, _, cache, some msg := stack.shrink iniSz, iniPos, cache, some msg
| stack, _, cache, some msg := stack.take iniSz, iniPos, cache, some msg
| other := other
@[noinline] def noFirstTokenInfo (info : ParserInfo) : ParserInfo :=
@@ -516,15 +516,15 @@ partial def identFnAux (startPos : Nat) (tk : Option TokenConfig) : Name → Par
def ParserData.keepNewError (d : ParserData) (oldStackSize : Nat) : ParserData :=
match d with
| ⟨stack, pos, cache, err⟩ := ⟨stack.shrink oldStackSize, pos, cache, err⟩
| ⟨stack, pos, cache, err⟩ := ⟨stack.take oldStackSize, pos, cache, err⟩
def ParserData.keepPrevError (d : ParserData) (oldStackSize : Nat) (oldStopPos : String.Pos) (oldError : Option String) : ParserData :=
match d with
| ⟨stack, _, cache, _⟩ := ⟨stack.shrink oldStackSize, oldStopPos, cache, oldError⟩
| ⟨stack, _, cache, _⟩ := ⟨stack.take oldStackSize, oldStopPos, cache, oldError⟩
def ParserData.mergeErrors (d : ParserData) (oldStackSize : Nat) (oldError : String) : ParserData :=
match d with
| ⟨stack, pos, cache, some err⟩ := ⟨stack.shrink oldStackSize, pos, cache, some (err ++ "; " ++ oldError)⟩
| ⟨stack, pos, cache, some err⟩ := ⟨stack.take oldStackSize, pos, cache, some (err ++ "; " ++ oldError)⟩
| other := other
def ParserData.mkLongestNodeAlt (d : ParserData) (startSize : Nat) : ParserData :=
@@ -535,14 +535,14 @@ match d with
else
-- parser created more than one node, combine them into a single node
let node := Syntax.node nullKind (stack.extract startSize stack.size) [] in
let stack := stack.shrink startSize in
let stack := stack.take startSize in
⟨stack.push node, pos, cache, none⟩
def ParserData.keepLatest (d : ParserData) (startStackSize : Nat) : ParserData :=
match d with
| ⟨stack, pos, cache, _⟩ :=
let node := stack.back in
let stack := stack.shrink startStackSize in
let stack := stack.take startStackSize in
let stack := stack.push node in
⟨stack, pos, cache, none⟩
@@ -591,7 +591,7 @@ def longestMatchFn₂ (p q : ParserFn) : ParserFn :=
let startSize := d.stackSize in
let startPos := d.pos in
let d := p s d in
let d := if d.hasError then d.shrinkStack startSize else d.mkLongestNodeAlt startSize in
let d := if d.hasError then d.takeStack startSize else d.mkLongestNodeAlt startSize in
let d := longestMatchStep startSize startPos q s d in
longestMatchMkResult startSize d
@@ -603,7 +603,7 @@ def longestMatchFn : List ParserFn → ParserFn
let startPos := d.pos in
let d := p s d in
if d.hasError then
let d := d.shrinkStack startSize in
let d := d.takeStack startSize in
longestMatchFnAux startSize startPos ps s d
else
let d := d.mkLongestNodeAlt startSize in