Compare commits

...

1 Commits

Author SHA1 Message Date
Kim Morrison
2fe74ba5bc feat: simprocs for other Fin operations 2024-12-03 15:24:21 +11:00
3 changed files with 165 additions and 1 deletions

View File

@@ -62,7 +62,7 @@ def getStringValue? (e : Expr) : (Option String) :=
| .lit (.strVal s) => some s
| _ => none
/-- Return `some ⟨n, v⟩` if `e` is af `OfNat.ofNat` application encoding a `Fin n` with value `v` -/
/-- Return `some ⟨n, v⟩` if `e` is an `OfNat.ofNat` application encoding a `Fin n` with value `v` -/
def getFinValue? (e : Expr) : MetaM (Option ((n : Nat) × Fin n)) := OptionT.run do
let (v, type) getOfNatValue? e ``Fin
let n getNatValue? ( whnfD type.appArg!)

View File

@@ -20,6 +20,18 @@ def fromExpr? (e : Expr) : SimpM (Option Value) := do
let some n, value getFinValue? e | return none
return some { n, value }
@[inline] def reduceOp (declName : Name) (arity : Nat) (f : Nat Nat) (op : {n : Nat} Fin n Fin (f n)) (e : Expr) : SimpM DStep := do
unless e.isAppOfArity declName arity do return .continue
let some v fromExpr? e.appArg! | return .continue
let v' := op v.value
return .done <| toExpr v'
@[inline] def reduceNatOp (declName : Name) (arity : Nat) (f : Nat Nat) (op : (n : Nat) Fin (f n)) (e : Expr) : SimpM DStep := do
unless e.isAppOfArity declName arity do return .continue
let some v getNatValue? e.appArg! | return .continue
let v' := op v
return .done <| toExpr v'
@[inline] def reduceBin (declName : Name) (arity : Nat) (op : {n : Nat} Fin n Fin n Fin n) (e : Expr) : SimpM DStep := do
unless e.isAppOfArity declName arity do return .continue
let some v₁ fromExpr? e.appFn!.appArg! | return .continue
@@ -47,12 +59,23 @@ The following code assumes users did not override the `Fin n` instances for the
If they do, they must disable the following `simprocs`.
-/
builtin_dsimproc [simp, seval] reduceSucc (Fin.succ _) := reduceOp ``Fin.succ 2 (· + 1) Fin.succ
builtin_dsimproc [simp, seval] reduceRev (Fin.rev _) := reduceOp ``Fin.rev 2 (·) Fin.rev
builtin_dsimproc [simp, seval] reduceLast (Fin.last _) := reduceNatOp ``Fin.last 1 (· + 1) Fin.last
builtin_dsimproc [simp, seval] reduceAdd ((_ + _ : Fin _)) := reduceBin ``HAdd.hAdd 6 (· + ·)
builtin_dsimproc [simp, seval] reduceMul ((_ * _ : Fin _)) := reduceBin ``HMul.hMul 6 (· * ·)
builtin_dsimproc [simp, seval] reduceSub ((_ - _ : Fin _)) := reduceBin ``HSub.hSub 6 (· - ·)
builtin_dsimproc [simp, seval] reduceDiv ((_ / _ : Fin _)) := reduceBin ``HDiv.hDiv 6 (· / ·)
builtin_dsimproc [simp, seval] reduceMod ((_ % _ : Fin _)) := reduceBin ``HMod.hMod 6 (· % ·)
builtin_dsimproc [simp, seval] reduceAnd ((_ &&& _ : Fin _)) := reduceBin ``HAnd.hAnd 6 (· &&& ·)
builtin_dsimproc [simp, seval] reduceOr ((_ ||| _ : Fin _)) := reduceBin ``HOr.hOr 6 (· ||| ·)
builtin_dsimproc [simp, seval] reduceXor ((_ ^^^ _ : Fin _)) := reduceBin ``HXor.hXor 6 (· ^^^ ·)
builtin_dsimproc [simp, seval] reduceShiftLeft ((_ <<< _ : Fin _)) := reduceBin ``HShiftLeft.hShiftLeft 6 (· <<< ·)
builtin_dsimproc [simp, seval] reduceShiftRight ((_ >>> _ : Fin _)) := reduceBin ``HShiftRight.hShiftRight 6 (· >>> ·)
builtin_simproc [simp, seval] reduceLT (( _ : Fin _) < _) := reduceBinPred ``LT.lt 4 (. < .)
builtin_simproc [simp, seval] reduceLE (( _ : Fin _) _) := reduceBinPred ``LE.le 4 (. .)
builtin_simproc [simp, seval] reduceGT (( _ : Fin _) > _) := reduceBinPred ``GT.gt 4 (. > .)
@@ -83,4 +106,70 @@ builtin_dsimproc [simp, seval] reduceFinMk (Fin.mk _ _) := fun e => do
else
return .continue
builtin_dsimproc [simp, seval] reduceOfNat' (Fin.ofNat' _ _) := fun e => do
unless e.isAppOfArity ``Fin.ofNat' 3 do return .continue
let some (n + 1) getNatValue? e.appFn!.appFn!.appArg! | return .continue
let some k getNatValue? e.appArg! | return .continue
return .done <| toExpr (Fin.ofNat' (n + 1) k)
builtin_dsimproc [simp, seval] reduceCastSucc (Fin.castSucc _) := fun e => do
unless e.isAppOfArity ``Fin.castSucc 2 do return .continue
let some k fromExpr? e.appArg! | return .continue
return .done <| toExpr (castSucc k.value)
builtin_dsimproc [simp, seval] reduceCastAdd (Fin.castAdd _ _) := fun e => do
unless e.isAppOfArity ``Fin.castAdd 3 do return .continue
let some m getNatValue? e.appFn!.appArg! | return .continue
let some k fromExpr? e.appArg! | return .continue
return .done <| toExpr (castAdd m k.value)
builtin_dsimproc [simp, seval] reduceAddNat (Fin.addNat _ _) := fun e => do
unless e.isAppOfArity ``Fin.addNat 3 do return .continue
let some k fromExpr? e.appFn!.appArg! | return .continue
let some m getNatValue? e.appArg! | return .continue
return .done <| toExpr (addNat k.value m)
builtin_dsimproc [simp, seval] reduceNatAdd (Fin.natAdd _ _) := fun e => do
unless e.isAppOfArity ``Fin.natAdd 3 do return .continue
let some m getNatValue? e.appFn!.appArg! | return .continue
let some k fromExpr? e.appArg! | return .continue
return .done <| toExpr (natAdd m k.value)
builtin_dsimproc [simp, seval] reduceCastLT (Fin.castLT _ _) := fun e => do
unless e.isAppOfArity ``Fin.castLT 4 do return .continue
let some n getNatValue? e.appFn!.appFn!.appFn!.appArg! | return .continue
let some i fromExpr? e.appFn!.appArg! | return .continue
if h : i.value < n then
return .done <| toExpr (castLT i.value h)
else
return .continue
builtin_dsimproc [simp, seval] reduceCastLE (Fin.castLE _ _) := fun e => do
unless e.isAppOfArity ``Fin.castLE 4 do return .continue
let some m getNatValue? e.appFn!.appFn!.appArg! | return .continue
let some i fromExpr? e.appArg! | return .continue
if h : i.n m then
return .done <| toExpr (castLE h i.value)
else
return .continue
-- No simproc is needed for `Fin.cast`, as for explicit numbers `Fin.cast_refl` will apply.
builtin_dsimproc [simp, seval] reduceSubNat (Fin.subNat _ _ _) := fun e => do
unless e.isAppOfArity ``Fin.subNat 4 do return .continue
let some m getNatValue? e.appFn!.appFn!.appArg! | return .continue
let some i fromExpr? e.appFn!.appArg! | return .continue
if h : m i.value then
return .done <| toExpr (subNat m (i.value.cast (by omega : i.n = (i.n - m) + m)) h)
else
return .continue
builtin_dsimproc [simp, seval] reducePred (Fin.pred _ _) := fun e => do
unless e.isAppOfArity ``Fin.pred 3 do return .continue
let some (_ + 1), i fromExpr? e.appFn!.appArg! | return .continue
if h : i 0 then
return .done <| toExpr (pred i h)
else
return .continue
end Fin

View File

@@ -0,0 +1,75 @@
variable (n : Nat) [NeZero n]
/- basic operations -/
#check_simp (3 : Fin 7).succ ~> (4 : Fin 8)
#check_simp (6 : Fin 7).succ ~> (7 : Fin 8)
#check_simp Fin.last 0 ~> (0 : Fin 1)
#check_simp Fin.last 6 ~> (6 : Fin 7)
#check_simp Fin.ofNat' 6 3 ~> (3 : Fin 6)
#check_simp Fin.ofNat' 6 37 ~> (1 : Fin 6)
#check_simp Fin.rev (0 : Fin 7) ~> (6 : Fin 7)
#check_simp Fin.rev (3 : Fin 7) ~> (3 : Fin 7)
#check_simp Fin.castSucc (0 : Fin 7) ~> (0 : Fin 8)
#check_simp Fin.castSucc (3 : Fin 7) ~> (3 : Fin 8)
#check_simp Fin.castAdd 3 (0 : Fin 7) ~> (0 : Fin 10)
#check_simp Fin.castAdd 3 (3 : Fin 7) ~> (3 : Fin 10)
#check_simp Fin.castLT (3 : Fin 10) (by decide : 3 < 5) ~> (3 : Fin 5)
#check_simp Fin.castLE (by decide : 5 37) (3 : Fin 5) ~> (3 : Fin 37)
#check_simp Fin.pred (3 : Fin 7) (by decide) ~> (2 : Fin 6)
/- arithmetic operation tests -/
#check_simp (3 : Fin 7) + (1 : Fin 7) ~> 4
#check_simp (3 : Fin 7) + (5 : Fin 7) ~> 1
#check_simp (3 : Fin 7) * (1 : Fin 7) ~> 3
#check_simp (3 : Fin 7) * (3 : Fin 7) ~> 2
#check_simp (3 : Fin 7) - (1 : Fin 7) ~> 2
#check_simp (3 : Fin 7) - (5 : Fin 7) ~> 5
#check_simp (3 : Fin 7) / (1 : Fin 7) ~> 3
#check_simp (3 : Fin 7) / (5 : Fin 7) ~> 0
#check_simp (3 : Fin 7) % (0 : Fin 7) ~> 3
#check_simp (3 : Fin 7) % (1 : Fin 7) ~> 0
#check_simp (3 : Fin 7) % (5 : Fin 7) ~> 3
#check_simp (3 : Fin n) + (5 : Fin n) !~>
#check_simp (3 : Fin n) * (5 : Fin n) !~>
#check_simp (3 : Fin n) - (5 : Fin n) !~>
#check_simp (3 : Fin n) / (5 : Fin n) !~>
#check_simp (3 : Fin n) % (5 : Fin n) !~>
#check_simp Fin.addNat (3 : Fin 7) 3 ~> (6 : Fin 10)
#check_simp Fin.natAdd 3 (3 : Fin 7) ~> (6 : Fin 10)
#check_simp Fin.subNat 2 (3 : Fin 7) (by decide) ~> (1 : Fin 5)
/- bitwise operations -/
#check_simp (3 : Fin 7) &&& (1 : Fin 7) ~> 1
#check_simp (3 : Fin 7) ||| (1 : Fin 7) ~> 3
#check_simp (3 : Fin 7) ^^^ (1 : Fin 7) ~> 2
#check_simp (3 : Fin 7) <<< (1 : Fin 7) ~> 6
#check_simp (3 : Fin 7) >>> (1 : Fin 7) ~> 1
/- predicate tests -/
#check_simp (3 : Fin 7) < (1 : Fin 7) ~> False
#check_simp (3 : Fin 7) < (5 : Fin 7) ~> True
#check_simp (3 : Fin 7) (1 : Fin 7) ~> False
#check_simp (3 : Fin 7) (5 : Fin 7) ~> True
#check_simp (3 : Fin 7) > (1 : Fin 7) ~> True
#check_simp (3 : Fin 7) > (5 : Fin 7) ~> False
#check_simp (3 : Fin 7) (1 : Fin 7) ~> True
#check_simp (3 : Fin 7) (5 : Fin 7) ~> False
#check_simp (3 : Fin 7) = (1 : Fin 7) ~> False
#check_simp (3 : Fin 7) = (5 : Fin 7) ~> False
#check_simp (3 : Fin 7) = (3 : Fin 7) ~> True
#check_simp (3 : Fin 7) (1 : Fin 7) ~> True
#check_simp (3 : Fin 7) (3 : Fin 7) ~> False
#check_simp (3 : Fin 7) (5 : Fin 7) ~> True
#check_simp (3 : Fin 7) == (1 : Fin 7) ~> false
#check_simp (3 : Fin 7) == (3 : Fin 7) ~> true
#check_simp (3 : Fin 7) == (5 : Fin 7) ~> false
#check_simp (3 : Fin 7) != (1 : Fin 7) ~> true
#check_simp (3 : Fin 7) != (3 : Fin 7) ~> false
#check_simp (3 : Fin 7) != (5 : Fin 7) ~> true