Compare commits

...

1 Commits

Author SHA1 Message Date
Leonardo de Moura
09ce61cfd1 feat: safe exponentiation
- Adds configuration option `exponentiation.threshold`
- An expression `b^n` where `b` and `n` are literals is not reduced by
  `whnf`, `simp`, and `isDefEq` if `n > exponentiation.threshold`.

Motivation: prevents system from becoming irresponsive.

TODO: improve support in the kernel. It is using a hard-coded limit
for now.
2024-07-02 20:39:33 -07:00
12 changed files with 145 additions and 8 deletions

View File

@@ -519,4 +519,16 @@ instance : MonadRuntimeException CoreM where
@[inline] def mapCoreM [MonadControlT CoreM m] [Monad m] (f : forall {α}, CoreM α CoreM α) {α} (x : m α) : m α :=
controlAt CoreM fun runInBase => f <| runInBase x
/--
Returns `true` if the given message kind has not been reported in the message log,
and then mark it as reported. Otherwise, returns `false`.
We use this API to ensure we don't report the same kind of warning multiple times.
-/
def reportMessageKind (kind : Name) : CoreM Bool := do
if ( get).messages.reportedKinds.contains kind then
return false
else
modify fun s => { s with messages.reportedKinds := s.messages.reportedKinds.insert kind }
return true
end Lean

View File

@@ -350,6 +350,13 @@ structure MessageLog where
hadErrors : Bool := false
/-- The list of messages not already reported, in insertion order. -/
unreported : PersistentArray Message := {}
/--
Set of message kinds that have been added to the log.
For example, we have the kind `unsafe.exponentiation.warning` for warning messages associated with
the configuration option `exponentiation.threshold`.
We don't produce a warning if the kind is already in the following set.
-/
reportedKinds : NameSet := {}
deriving Inhabited
namespace MessageLog
@@ -403,7 +410,7 @@ def indentExpr (e : Expr) : MessageData :=
indentD e
class AddMessageContext (m : Type Type) where
/--
/--
Without context, a `MessageData` object may be be missing information
(e.g. hover info) for pretty printing, or may print an error. Hence,
`addMessageContext` should be called on all constructed `MessageData`

View File

@@ -7,6 +7,8 @@ prelude
import Lean.Data.LBool
import Lean.Meta.InferType
import Lean.Meta.NatInstTesters
import Lean.Meta.NatInstTesters
import Lean.Util.SafeExponentiation
namespace Lean.Meta
@@ -29,6 +31,10 @@ partial def evalNat (e : Expr) : OptionT MetaM Nat := do
| .mvar .. => visit e
| _ => failure
where
evalPow (b n : Expr) : OptionT MetaM Nat := do
let n evalNat n
guard ( checkExponent n)
return ( evalNat b) ^ n
visit e := do
match_expr e with
| OfNat.ofNat _ n i => guard ( isInstOfNatNat i); evalNat n
@@ -48,10 +54,10 @@ where
| Nat.mod a b => return ( evalNat a) % ( evalNat b)
| Mod.mod _ i a b => guard ( isInstModNat i); return ( evalNat a) % ( evalNat b)
| HMod.hMod _ _ _ i a b => guard ( isInstHModNat i); return ( evalNat a) % ( evalNat b)
| Nat.pow a b => return ( evalNat a) ^ ( evalNat b)
| NatPow.pow _ i a b => guard ( isInstNatPowNat i); return ( evalNat a) ^ ( evalNat b)
| Pow.pow _ _ i a b => guard ( isInstPowNat i); return ( evalNat a) ^ ( evalNat b)
| HPow.hPow _ _ _ i a b => guard ( isInstHPowNat i); return ( evalNat a) ^ ( evalNat b)
| Nat.pow a b => evalPow a b
| NatPow.pow _ i a b => guard ( isInstNatPowNat i); evalPow a b
| Pow.pow _ _ i a b => guard ( isInstPowNat i); evalPow a b
| HPow.hPow _ _ _ i a b => guard ( isInstHPowNat i); evalPow a b
| _ => failure
/--

View File

@@ -82,6 +82,7 @@ builtin_dsimproc [simp, seval] reducePow ((_ : Int) ^ (_ : Nat)) := fun e => do
let_expr HPow.hPow _ _ _ _ a b e | return .continue
let some v₁ fromExpr? a | return .continue
let some v₂ Nat.fromExpr? b | return .continue
unless ( checkExponent v₂) do return .continue
return .done <| toExpr (v₁ ^ v₂)
builtin_simproc [simp, seval] reduceLT (( _ : Int) < _) := reduceBinPred ``LT.lt 4 (. < .)

View File

@@ -6,6 +6,7 @@ Authors: Leonardo de Moura
prelude
import Init.Simproc
import Init.Data.Nat.Simproc
import Lean.Util.SafeExponentiation
import Lean.Meta.LitValues
import Lean.Meta.Offset
import Lean.Meta.Tactic.Simp.Simproc
@@ -52,7 +53,13 @@ builtin_dsimproc [simp, seval] reduceMul ((_ * _ : Nat)) := reduceBin ``HMul.hMu
builtin_dsimproc [simp, seval] reduceSub ((_ - _ : Nat)) := reduceBin ``HSub.hSub 6 (· - ·)
builtin_dsimproc [simp, seval] reduceDiv ((_ / _ : Nat)) := reduceBin ``HDiv.hDiv 6 (· / ·)
builtin_dsimproc [simp, seval] reduceMod ((_ % _ : Nat)) := reduceBin ``HMod.hMod 6 (· % ·)
builtin_dsimproc [simp, seval] reducePow ((_ ^ _ : Nat)) := reduceBin ``HPow.hPow 6 (· ^ ·)
builtin_dsimproc [simp, seval] reducePow ((_ ^ _ : Nat)) := fun e => do
let some n fromExpr? e.appFn!.appArg! | return .continue
let some m fromExpr? e.appArg! | return .continue
unless ( checkExponent m) do return .continue
return .done <| toExpr (n ^ m)
builtin_dsimproc [simp, seval] reduceGcd (gcd _ _) := reduceBin ``gcd 2 gcd
builtin_simproc [simp, seval] reduceLT (( _ : Nat) < _) := reduceBinPred ``LT.lt 4 (. < .)

View File

@@ -6,6 +6,7 @@ Authors: Leonardo de Moura
prelude
import Lean.Structure
import Lean.Util.Recognizers
import Lean.Util.SafeExponentiation
import Lean.Meta.GetUnfoldableConst
import Lean.Meta.FunInfo
import Lean.Meta.Offset
@@ -885,6 +886,13 @@ def reduceBinNatOp (f : Nat → Nat → Nat) (a b : Expr) : MetaM (Option Expr)
trace[Meta.isDefEq.whnf.reduceBinOp] "{a} op {b}"
return mkRawNatLit <| f a b
def reducePow (a b : Expr) : MetaM (Option Expr) :=
withNatValue a fun a =>
withNatValue b fun b => OptionT.run do
guard ( checkExponent b)
trace[Meta.isDefEq.whnf.reduceBinOp] "{a} ^ {b}"
return mkRawNatLit <| a ^ b
def reduceBinNatPred (f : Nat Nat Bool) (a b : Expr) : MetaM (Option Expr) := do
withNatValue a fun a =>
withNatValue b fun b =>
@@ -904,7 +912,7 @@ def reduceNat? (e : Expr) : MetaM (Option Expr) :=
| ``Nat.mul => reduceBinNatOp Nat.mul a1 a2
| ``Nat.div => reduceBinNatOp Nat.div a1 a2
| ``Nat.mod => reduceBinNatOp Nat.mod a1 a2
| ``Nat.pow => reduceBinNatOp Nat.pow a1 a2
| ``Nat.pow => reducePow a1 a2
| ``Nat.gcd => reduceBinNatOp Nat.gcd a1 a2
| ``Nat.beq => reduceBinNatPred Nat.beq a1 a2
| ``Nat.ble => reduceBinNatPred Nat.ble a1 a2

View File

@@ -29,3 +29,4 @@ import Lean.Util.OccursCheck
import Lean.Util.HasConstCache
import Lean.Util.FileSetupInfo
import Lean.Util.Heartbeats
import Lean.Util.SafeExponentiation

View File

@@ -0,0 +1,34 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
prelude
import Lean.CoreM
namespace Lean
register_builtin_option exponentiation.threshold : Nat := {
defValue := 256
descr := "maximum value for \
which exponentiation operations are safe to evaluate. When an exponent \
is a value greater than this threshold, the exponentiation will not be evaluated, \
and a warning will be logged. This helps to prevent the system from becoming \
unresponsive due to excessively large computations."
}
/--
Returns `true` if `n` is `≤ exponentiation.threshold`. Otherwise,
reports a warning and returns `false`.
This method ensures there is at most one warning message of this kind in the message log.
-/
def checkExponent (n : Nat) : CoreM Bool := do
let threshold := exponentiation.threshold.get ( getOptions)
if n > threshold then
if ( reportMessageKind `unsafe.exponentiation) then
logWarning s!"exponent {n} exceeds the threshold {threshold}, exponentiation operation was not evaluated, use `set_option {exponentiation.threshold.name} <num>` to set a new threshold"
return false
else
return true
end Lean

View File

@@ -595,6 +595,18 @@ template<typename F> optional<expr> type_checker::reduce_bin_nat_op(F const & f,
return some_expr(mk_lit(literal(nat(f(v1.raw(), v2.raw())))));
}
#define ReducePowMaxExp 1<<24 // TODO: make it configurable
optional<expr> type_checker::reduce_pow(expr const & e) {
expr arg1 = whnf(app_arg(app_fn(e)));
expr arg2 = whnf(app_arg(e));
if (!is_nat_lit_ext(arg2)) return none_expr();
nat v1 = get_nat_val(arg1);
nat v2 = get_nat_val(arg2);
if (v2 > nat(ReducePowMaxExp)) return none_expr();
return some_expr(mk_lit(literal(nat(nat_pow(v1.raw(), v2.raw())))));
}
template<typename F> optional<expr> type_checker::reduce_bin_nat_pred(F const & f, expr const & e) {
expr arg1 = whnf(app_arg(app_fn(e)));
if (!is_nat_lit_ext(arg1)) return none_expr();
@@ -622,7 +634,7 @@ optional<expr> type_checker::reduce_nat(expr const & e) {
if (f == *g_nat_add) return reduce_bin_nat_op(nat_add, e);
if (f == *g_nat_sub) return reduce_bin_nat_op(nat_sub, e);
if (f == *g_nat_mul) return reduce_bin_nat_op(nat_mul, e);
if (f == *g_nat_pow) return reduce_bin_nat_op(nat_pow, e);
if (f == *g_nat_pow) return reduce_pow(e);
if (f == *g_nat_gcd) return reduce_bin_nat_op(nat_gcd, e);
if (f == *g_nat_mod) return reduce_bin_nat_op(nat_mod, e);
if (f == *g_nat_div) return reduce_bin_nat_op(nat_div, e);

View File

@@ -101,6 +101,7 @@ private:
template<typename F> optional<expr> reduce_bin_nat_op(F const & f, expr const & e);
template<typename F> optional<expr> reduce_bin_nat_pred(F const & f, expr const & e);
optional<expr> reduce_pow(expr const & e);
optional<expr> reduce_nat(expr const & e);
public:
type_checker(state & st, local_ctx const & lctx, definition_safety ds = definition_safety::safe);

View File

@@ -49,6 +49,7 @@ def p_31 := 216091
def p_32 := 756839
def p_33 := 859433
set_option exponentiation.threshold 10000000
/- GCD with large prime factors on one side, and small primes on the other. -/
example : Nat.gcd (p_29 * p_30 * p_31 * p_32 * p_33) 2^(2^20) = 1 := rfl
/- GCD with two prime factors on both sides, including one in common. -/

View File

@@ -0,0 +1,47 @@
/--
warning: exponent 10000000 exceeds the threshold 256, exponentiation operation was not evaluated, use `set_option exponentiation.threshold <num>` to set a new threshold
---
error: maximum recursion depth has been reached
use `set_option maxRecDepth <num>` to increase limit
use `set_option diagnostics true` to get diagnostic information
-/
#guard_msgs in
example : 2^2^8000000 = 3^3^10000000 :=
rfl
/--
-/
#guard_msgs in
set_option exponentiation.threshold 258 in
example : 2^257 = 2*2^256 :=
rfl
/--
warning: exponent 2008 exceeds the threshold 256, exponentiation operation was not evaluated, use `set_option exponentiation.threshold <num>` to set a new threshold
---
warning: declaration uses 'sorry'
---
error: (kernel) deep recursion detected
---
info: k : Nat
h : k = 2008 ^ 2 + 2 ^ 2008
⊢ ((4032064 + 2 ^ 2008) ^ 2 + 2 ^ (4032064 + 2 ^ 2008)) % 10 = 6
-/
#guard_msgs in
example (k : Nat) (h : k = 2008^2 + 2^2008) : (k^2 + 2^k)%10 = 6 := by
simp [h]
trace_state
sorry
/--
warning: declaration uses 'sorry'
---
info: k : Nat
h : k = 2008 ^ 2 + 2 ^ 2008
⊢ ((2008 ^ 2 + 2 ^ 2008) ^ 2 + 2 ^ (2008 ^ 2 + 2 ^ 2008)) % 10 = 6
-/
#guard_msgs in
example (k : Nat) (h : k = 2008^2 + 2^2008) : (k^2 + 2^k)%10 = 6 := by
rw [h]
trace_state
sorry