feat: infer termination arguments like xs.size - i (#3666)

A common pattern for recursive functions is
```
def countUp (n i acc : Nat) : Nat :=
  if i < n then
    countUp n (i+1) (acc + i)
  else
    acc
```
where we increase a value `i` until it hits an upper bound. This is
particularly common with array processing functions:
```
$ git grep 'termination_by.*size.*-' src/|wc -l
26
```

GuessLex now recognizes this pattern. The general approach is:

For every recursive call, check whether the context contains hypotheses
of the form `e₁ < e₂` (or similar comparisons), and then consider
`e₂ - e₁` as a termination argument.
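
To make this concrete for the `countUp` example above: the hypothesis in
scope at the recursive call is `i < n`, so the candidate is `n - i`. A
sketch of the same function with that measure written out by hand, which
is roughly what one had to write before this change:
```
-- Same function as above, with the measure that GuessLex is expected to
-- find spelled out explicitly.
def countUp (n i acc : Nat) : Nat :=
  if i < n then
    -- The context here contains `i < n`, i.e. `e₁ = i` and `e₂ = n`,
    -- so the candidate measure is `e₂ - e₁ = n - i`.
    countUp n (i+1) (acc + i)
  else
    acc
termination_by n - i
```
The resulting proof obligation is `n - (i+1) < n - i` under the
assumption `i < n`, which is plain `Nat` arithmetic.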

Currently, this only fires when `e₁` and `e₂` depend only on the
function's parameters, and not on local let-bindings or variables bound
in local pattern matches.

Duplicates are removed.

In the table showing the termination argument failures, long termination
arguments are now given a number and abbreviated as e.g. `#4` in the
table headers.
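
The abbreviation scheme can be sketched in isolation as follows. This is
a simplified stand-in, not the code from this commit: `abbrevHeader` is
a hypothetical name, it works on plain strings in `StateM` rather than
on `Measure` in `MetaM`, and it assumes the same length threshold of 5
that the diff uses.
```
-- State: (number of abbreviations handed out so far, footer text).
def abbrevHeader (s : String) : StateM (Nat × String) String := do
  if s.length > 5 then
    let (i, footer) ← get
    let i := i + 1
    -- Record the full text in the footer, show only `#i` in the header.
    set (i, footer ++ s!"#{i}: {s}\n")
    pure s!"#{i}"
  else
    -- Short headers are shown as they are.
    pure s

#eval (["i", "Array.size xs - i", "i + i"].mapM abbrevHeader).run (0, "")
```
Under these assumptions `i` and `i + i` are short enough to be kept,
while `Array.size xs - i` is replaced by a numbered reference whose full
text is collected in the footer.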

More examples are in the test file; here are some highlights:
```
def distinct (xs : Array Nat) : Bool :=
  let rec loop (i j : Nat) : Bool :=
    if _ : i < xs.size then
      if _ : j < i then
        if xs[j] = xs[i] then
          false
        else
          loop i (j+1)
      else
        loop (i+1) 0
    else
      true
  loop 0 0
```
infers
```
termination_by (Array.size xs - i, i - j)
```
and the `weird` function, where `i` goes up or down,
```
def weird (xs : Array Nat) (i : Nat) : Bool :=
  if _ : i < xs.size then
    if _ : 0 < i then
      if xs[i] = 42 then
        weird xs.pop (i - 1)
      else
        weird xs (i+1)
    else
      weird xs (i+1)
  else
    true
decreasing_by all_goals simp_wf; omega
```
infers
```
termination_by (Array.size xs - i, i)
```
but unfortunately needs `decreasing_by` pending the “big
decreasing_tactic refactor” that I expect we’ll want to do at some point.
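
For comparison, the lexicographic measure inferred for `distinct` above
can be written out by hand. The following is a sketch with the inner
loop turned into a top-level function; the name `distinctFrom` is
hypothetical and not part of this commit:
```
def distinctFrom (xs : Array Nat) (i j : Nat) : Bool :=
  if _ : i < xs.size then
    if _ : j < i then
      if xs[j] = xs[i] then
        false
      else
        -- first component unchanged, second decreases: `i - (j+1) < i - j`
        distinctFrom xs i (j+1)
    else
      -- first component decreases: `xs.size - (i+1) < xs.size - i`
      distinctFrom xs (i+1) 0
  else
    true
termination_by (xs.size - i, i - j)
```
The first component comes from the hypothesis `i < xs.size`, the second
from `j < i`; each recursive call decreases exactly one of them while
leaving the components before it unchanged.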
Joachim Breitner
2024-03-16 13:27:35 +01:00
committed by GitHub
parent f0ff01ae28
commit 4c57da4b0f
6 changed files with 343 additions and 14 deletions


@@ -41,6 +41,19 @@ v4.8.0 (development in progress)
(x x : Nat) : motive x x
```
* The termination checker now recognizes more recursion patterns without an
explicit `termination_by`. In particular the idiom of counting up to an upper
bound, as in
```
def Array.sum (arr : Array Nat) (i acc : Nat) : Nat :=
  if _ : i < arr.size then
    Array.sum arr (i+1) (acc + arr[i])
  else
    acc
```
is recognized without having to say `termination_by arr.size - i`.
Breaking changes:
* Automatically generated equational theorems are now named using suffix `.eq_<idx>` instead of `._eq_<idx>`, and `.def` instead of `._unfold`. Example:


@@ -21,7 +21,8 @@ import Lean.Data.Array
/-!
This module finds lexicographic termination arguments for well-founded recursion.
-Starting with basic measures (`sizeOf xᵢ` for all parameters `xᵢ`) it tries all combinations
+Starting with basic measures (`sizeOf xᵢ` for all parameters `xᵢ`), and complex measures
+(e.g. `e₂ - e₁` if `e₁ < e₂` is found in the context of a recursive call) it tries all combinations
until it finds one where all proof obligations go through with the given tactic (`decreasing_by`),
if given, or the default `decreasing_tactic`.
@@ -59,6 +60,10 @@ The following optimizations are applied to make this feasible:
The logic here is based on “Finding Lexicographic Orders for Termination Proofs in Isabelle/HOL”
by Lukas Bulwahn, Alexander Krauss, and Tobias Nipkow, 10.1007/978-3-540-74591-4_5
<https://www21.in.tum.de/~nipkow/pubs/tphols07.pdf>.
We got the idea of considering the measure `e₂ - e₁` if we see `e₁ < e₂` from
“Termination Analysis with Calling Context Graphs” by Panagiotis Manolios &
Daron Vroon, https://doi.org/10.1007/11817963_36.
-/
set_option autoImplicit false
@@ -353,6 +358,51 @@ def collectRecCalls (unaryPreDef : PreDefinition) (fixedPrefixSize : Nat)
      let (callee, args) ← argsPacker.unpack arg
      RecCallWithContext.create (← getRef) caller (ys ++ params) callee (ys ++ args)

/-- Is the expression a `<`-like comparison of `Nat` expressions -/
def isNatCmp (e : Expr) : Option (Expr × Expr) :=
  match_expr e with
  | LT.lt α _ e₁ e₂ => if α.isConstOf ``Nat then some (e₁, e₂) else none
  | LE.le α _ e₁ e₂ => if α.isConstOf ``Nat then some (e₁, e₂) else none
  | GT.gt α _ e₁ e₂ => if α.isConstOf ``Nat then some (e₂, e₁) else none
  | GE.ge α _ e₁ e₂ => if α.isConstOf ``Nat then some (e₂, e₁) else none
  | _ => none

def complexMeasures (preDefs : Array PreDefinition) (fixedPrefixSize : Nat)
    (userVarNamess : Array (Array Name)) (recCalls : Array RecCallWithContext) :
    MetaM (Array (Array Measure)) := do
  preDefs.mapIdxM fun funIdx preDef => do
    let arity ← lambdaTelescope preDef.value fun xs _ => pure xs.size
    let mut measures := #[]
    for rc in recCalls do
      -- Only look at calls from the current function
      unless rc.caller = funIdx do continue
      -- Only look at calls where the parameters have not been refined
      unless rc.params.all (·.isFVar) do continue
      let xs := rc.params.map (·.fvarId!)
      let varyingParams : Array FVarId := xs[fixedPrefixSize:]
      measures ← rc.ctxt.run do
        withUserNames rc.params[fixedPrefixSize:] userVarNamess[funIdx]! do
          trace[Elab.definition.wf] "rc: {rc.caller} ({rc.params}) → {rc.callee} ({rc.args})"
          let mut measures := measures
          for ldecl in ← getLCtx do
            if let some (e₁, e₂) := isNatCmp ldecl.type then
              -- We only want to consider these expressions if they depend only on the function's
              -- immediate arguments, so check that
              if e₁.hasAnyFVar (! xs.contains ·) then continue
              if e₂.hasAnyFVar (! xs.contains ·) then continue
              -- If e₁ does not depend on any varying parameters, simply ignore it
              let e₁_is_const := ! e₁.hasAnyFVar (varyingParams.contains ·)
              let body := if e₁_is_const then e₂ else mkNatSub e₂ e₁
              -- Avoid adding simple measures
              unless body.isFVar do
                let fn ← mkLambdaFVars rc.params body
                -- Avoid duplicates
                unless ← measures.anyM (isDefEq ·.fn fn) do
                  let extraParams := preDef.termination.extraParams
                  measures := measures.push { ref := .missing, fn, natFn := fn, arity, extraParams }
          return measures
    return measures

/-- A `GuessLexRel` described how a recursive call affects a measure; whether it
decreases strictly, non-strictly, is equal, or else. -/
inductive GuessLexRel | lt | eq | le | no_idea
@@ -603,21 +653,43 @@ def RecCallWithContext.posString (rcc : RecCallWithContext) : MetaM String := do
return s!"{position.line}:{position.column}{endPosStr}"
/-- How to present the measure in the table header, possibly abbreviated. -/
def measureHeader (measure : Measure) : StateT (Nat × String) MetaM String := do
  let s ← measure.toString
  if s.length > 5 then
    let (i, footer) ← get
    let i := i + 1
    let footer := footer ++ s!"#{i}: {s}\n"
    set (i, footer)
    pure s!"#{i}"
  else
    pure s

def collectHeaders {α} (a : StateT (Nat × String) MetaM α) : MetaM (α × String) := do
  let (x, (_, footer)) ← a.run (0, "")
  pure (x,footer)

/-- Explain what we found out about the recursive calls (non-mutual case) -/
def explainNonMutualFailure (measures : Array Measure) (rcs : Array RecCallCache) : MetaM Format := do
-  let header ← measures.mapM Measure.toString
+  let (header, footer) ← collectHeaders (measures.mapM measureHeader)
   let mut table : Array (Array String) := #[#[""] ++ header]
   for i in [:rcs.size], rc in rcs do
     let mut row := #[s!"{i+1}) {← rc.rcc.posString}"]
     for argIdx in [:measures.size] do
       row := row.push (← rc.prettyEntry argIdx argIdx)
     table := table.push row
-  return formatTable table
+  let out := formatTable table
+  if footer.isEmpty then
+    return out
+  else
+    return out ++ "\n\n" ++ footer

/-- Explain what we found out about the recursive calls (mutual case) -/
def explainMutualFailure (declNames : Array Name) (measuress : Array (Array Measure))
    (rcs : Array RecCallCache) : MetaM Format := do
  let (headerss, footer) ← collectHeaders (measuress.mapM (·.mapM measureHeader))
  let mut r := Format.nil
  for rc in rcs do
@@ -626,8 +698,7 @@ def explainMutualFailure (declNames : Array Name) (measuress : Array (Array Meas
     r := r ++ f!"Call from {declNames[caller]!} to {declNames[callee]!} " ++
       f!"at {← rc.rcc.posString}:\n"
-    let header ← measuress[caller]!.mapM Measure.toString
-    let mut table : Array (Array String) := #[#[""] ++ header]
+    let mut table : Array (Array String) := #[#[""] ++ headerss[caller]!]
     if caller = callee then
       -- For self-calls, only the diagonal is interesting, so put it into one row
       let mut row := #[""]
@@ -637,12 +708,15 @@ def explainMutualFailure (declNames : Array Name) (measuress : Array (Array Meas
     else
       for argIdx in [:measuress[callee]!.size] do
         let mut row := #[]
-        row := row.push (← measuress[callee]![argIdx]!.toString)
+        row := row.push headerss[callee]![argIdx]!
         for paramIdx in [:measuress[caller]!.size] do
           row := row.push (← rc.prettyEntry paramIdx argIdx)
         table := table.push row
     r := r ++ formatTable table ++ "\n"
+  unless footer.isEmpty do
+    r := r ++ "\n\n" ++ footer
   return r
def explainFailure (declNames : Array Name) (measuress : Array (Array Measure))
@@ -705,9 +779,15 @@ def guessLex (preDefs : Array PreDefinition) (unaryPreDef : PreDefinition)
   let userVarNamess ← argsPacker.varNamess.mapM (naryVarNames ·)
   trace[Elab.definition.wf] "varNames is: {userVarNamess}"
-  -- For every function, the meaures we want to use
+  -- Collect all recursive calls and extract their context
+  let recCalls ← collectRecCalls unaryPreDef fixedPrefixSize argsPacker
+  let recCalls := filterSubsumed recCalls
+  -- For every function, the measures we want to use
   -- (One for each non-forbiddend arg)
-  let measuress ← simpleMeasures preDefs fixedPrefixSize userVarNamess
+  let meassures₁ ← simpleMeasures preDefs fixedPrefixSize userVarNamess
+  let meassures₂ ← complexMeasures preDefs fixedPrefixSize userVarNamess recCalls
+  let measuress := Array.zipWith meassures₁ meassures₂ (· ++ ·)
   -- The list of measures, including the measures that order functions.
   -- The function ordering measures come last
@@ -719,9 +799,6 @@ def guessLex (preDefs : Array PreDefinition) (unaryPreDef : PreDefinition)
     reportTermArgs preDefs termArgs
     return termArgs
-  -- Collect all recursive calls and extract their context
-  let recCalls ← collectRecCalls unaryPreDef fixedPrefixSize argsPacker
-  let recCalls := filterSubsumed recCalls
   let rcs ← recCalls.mapM (RecCallCache.mk (preDefs.map (·.termination.decreasingBy?)) measuress ·)
   let callMatrix := rcs.map (inspectCall ·)


@@ -2007,6 +2007,10 @@ private def natAddFn : Expr :=
  let nat := mkConst ``Nat
  mkApp4 (mkConst ``HAdd.hAdd [0, 0, 0]) nat nat nat (mkApp2 (mkConst ``instHAdd [0]) nat (mkConst ``instAddNat))

private def natSubFn : Expr :=
  let nat := mkConst ``Nat
  mkApp4 (mkConst ``HSub.hSub [0, 0, 0]) nat nat nat (mkApp2 (mkConst ``instHSub [0]) nat (mkConst ``instSubNat))

private def natMulFn : Expr :=
  let nat := mkConst ``Nat
  mkApp4 (mkConst ``HMul.hMul [0, 0, 0]) nat nat nat (mkApp2 (mkConst ``instHMul [0]) nat (mkConst ``instMulNat))
@@ -2019,6 +2023,10 @@ def mkNatSucc (a : Expr) : Expr :=
def mkNatAdd (a b : Expr) : Expr :=
  mkApp2 natAddFn a b

/-- Given `a b : Nat`, returns `a - b` -/
def mkNatSub (a b : Expr) : Expr :=
  mkApp2 natSubFn a b

/-- Given `a b : Nat`, returns `a * b` -/
def mkNatMul (a b : Expr) : Expr :=
  mkApp2 natMulFn a b


@@ -0,0 +1,136 @@
set_option showInferredTerminationBy true

def countUp (n i acc : Nat) : Nat :=
  if i < n then
    countUp n (i+1) (acc + i)
  else
    acc

def all42 (xs : Array Nat) (i : Nat) : Bool :=
  if h : i < xs.size then
    if xs[i] = 42 then
      all42 xs (i+1)
    else
      false
  else
    true

def henrik1 (xs : Array Nat) (i : Nat) : Bool :=
  if h : i < xs.size then
    if xs[i] = 42 then
      henrik1 (xs.push 42) (i+2)
    else
      false
  else
    true

def merge (xs ys : Array Nat) : Array Nat :=
  let rec loop (i j : Nat) (acc : Array Nat) : Array Nat :=
    if _ : i < xs.size then
      if _ : j < ys.size then
        if xs[i] < ys[j] then
          loop (i+1) j (acc.push xs[i])
        else
          loop i (j+1) (acc.push ys[j])
      else
        loop (i+1) j (acc.push xs[i])
    else
      if _ : j < ys.size then
        loop i (j+1) (acc.push ys[j])
      else
        acc
  loop 0 0 #[]

def distinct (xs : Array Nat) : Bool :=
  let rec loop (i j : Nat) : Bool :=
    if _ : i < xs.size then
      if _ : j < i then
        if xs[j] = xs[i] then
          false
        else
          loop i (j+1)
      else
        loop (i+1) 0
    else
      true
  loop 0 0

-- This example shows a limitation of our current `decreasing_tactic`.
-- GuessLex infers
--   termination_by (Array.size xs - i, i)
-- but because `decreasing_with` is using
--   repeat (first | apply Prod.Lex.right | apply Prod.Lex.left)
-- it cannot solve this goal. But if we leave the Prod.Lex-handling to omega, it works
def weird (xs : Array Nat) (i : Nat) : Bool :=
  if _ : i < xs.size then
    if _ : 0 < i then
      if xs[i] = 42 then
        weird xs.pop (i - 1)
      else
        weird xs (i+1)
    else
      weird xs (i+1)
  else
    true
decreasing_by all_goals simp_wf; omega

/--
This checks
* the presentation of complex measures in the table
* that multiple recursive calls do not lead to the same argument tried multiple times.
* it uses `e` instead of `e - 0`
* that we do not get measures from refined arguments
-/
def failure (xs : Array Nat) (i : Nat) : Bool :=
  if h : i < xs.size then failure xs i && failure xs i && failure xs (i + 1) else
  if h : i + 1 < xs.size then failure xs i else
  let j := i
  if h : j < xs.size then failure xs (j+1) else
  if h : 0 < i then failure xs (j+1) else
  if h : 42 < i then failure xs (j+1) else
  if h : xs.size < i then failure xs (j+1) else
  if h : 42 < i + i then failure xs (j+1) else
  match i with
    | 0 => true
    | i+1 =>
      if h : i < xs.size + 5 then
        failure xs i
      else
        false

mutual
def mutual_failure (xs : Array Nat) (i : Nat) : Bool :=
  if h : i < xs.size then
    mutual_failure2 xs i && mutual_failure2 xs i && mutual_failure2 xs (i + 1)
  else
  if h : i + 1 < xs.size then
    mutual_failure2 xs i
  else
  let j := i
  if h : j < xs.size then
    mutual_failure2 xs (j+1)
  else
  match i with
    | 0 => true
    | i+1 =>
      if h : i < xs.size then
        mutual_failure2 xs i
      else
        false

def mutual_failure2 (xs : Array Nat) (i : Nat) : Bool :=
  if h : i < xs.size then
    mutual_failure xs i && mutual_failure xs i && mutual_failure xs (i + 1)
  else
  let j := i
  if h : j < xs.size then
    mutual_failure xs (j+1)
  else
  match i with
    | 0 => true
    | i+1 =>
      if h : i < xs.size then
        mutual_failure xs i
      else
        false
end


@@ -0,0 +1,95 @@
Inferred termination argument:
termination_by n - i
Inferred termination argument:
termination_by Array.size xs - i
Inferred termination argument:
termination_by Array.size xs - i
Inferred termination argument:
termination_by (Array.size xs - i, Array.size ys - j)
Inferred termination argument:
termination_by (Array.size xs - i, i - j)
Inferred termination argument:
termination_by (Array.size xs - i, i)
guessLexDiff.lean:85:26-85:38: error: fail to show termination for
failure
with errors
argument #2 was not used for structural recursion
failed to eliminate recursive application
_root_.failure xs i
structural recursion cannot be used
Could not find a decreasing measure.
The arguments relate at each recursive call as follows:
(<, ≤, =: relation proved, ? all proofs failed, _: no proof attempted)
            i #1 #2 i + i
1) 85:26-38 =  =  =     =
2) 85:58-76 ?  <  _     _
3) 85:26-38 =  =  =     =
4) 88:26-42 _  <  _     _
5) 88:26-42 ?  ≤  ≤     ?
6) 88:26-42 _  <  _     _
7) 88:26-42 _  <  _     _
8) 88:26-42 _  <  _     _
9) 97:8-20  _  <  _     _
#1: Array.size xs - i
#2: Array.size xs - (i + 1)
Please use `termination_by` to specify a decreasing measure.
guessLexDiff.lean:102:4-102:18: error: fail to show termination for
mutual_failure
mutual_failure2
with errors
structural recursion does not handle mutually recursive functions
Could not find a decreasing measure.
The arguments relate at each recursive call as follows:
(<, ≤, =: relation proved, ? all proofs failed, _: no proof attempted)
Call from mutual_failure to mutual_failure2 at 104:4-24:
i #1 #2
i = ? ?
#3 ? = ?
Call from mutual_failure to mutual_failure2 at 104:52-78:
i #1 #2
i ? _ _
#3 _ < _
Call from mutual_failure to mutual_failure2 at 104:4-24:
i #1 #2
i _ _ _
#3 _ = _
Call from mutual_failure to mutual_failure2 at 111:4-28:
i #1 #2
i _ _ _
#3 _ < _
Call from mutual_failure to mutual_failure2 at 117:8-28:
i #1 #2
i _ _ _
#3 _ ? _
Call from mutual_failure2 to mutual_failure at 123:4-23:
i #3
i _ _
#1 _ _
#2 _ _
Call from mutual_failure2 to mutual_failure at 123:50-75:
i #3
i _ _
#1 _ _
#2 _ _
Call from mutual_failure2 to mutual_failure at 127:4-27:
i #3
i _ _
#1 _ _
#2 _ _
Call from mutual_failure2 to mutual_failure at 133:8-27:
i #3
i _ _
#1 _ _
#2 _ _
#1: Array.size xs - i
#2: Array.size xs - (i + 1)
#3: Array.size xs - i
Please use `termination_by` to specify a decreasing measure.


@@ -14,6 +14,6 @@ structural recursion cannot be used
Could not find a decreasing measure.
The arguments relate at each recursive call as follows:
(<, ≤, =: relation proved, ? all proofs failed, _: no proof attempted)
-           x y
-1) 3:12-19 ≤ ?
+           x y y - x
+1) 3:12-19 ≤ ?     ?
Please use `termination_by` to specify a decreasing measure.