Compare commits

...

7 Commits

Author SHA1 Message Date
Joe Hendrix
6d117b725b fix: fix non-gmp path 2024-02-18 22:37:45 -08:00
Joe Hendrix
299c162df4 fix: fix reference counting on ediv/emod 2024-02-18 21:35:26 -08:00
Joe Hendrix
317e1bd5c2 chore: add empty expected file 2024-02-18 12:18:19 -08:00
Joe Hendrix
97b0d896d1 chore: improve bigint tests 2024-02-18 00:04:20 -08:00
Joe Hendrix
effc76aebb chore: add integer division tests 2024-02-18 00:04:20 -08:00
Joe Hendrix
d54597bce2 Apply suggestions from code review 2024-02-16 18:04:02 -08:00
Joe Hendrix
a8178aafb3 feat: Add native ediv and emod implementations. 2024-02-16 18:00:11 -08:00
7 changed files with 266 additions and 2 deletions

View File

@@ -131,7 +131,8 @@ Integer division. This version of `Int.div` uses the E-rounding convention
(euclidean division), in which `Int.emod x y` satisfies `0 ≤ mod x y < natAbs y` for `y ≠ 0`
and `Int.ediv` is the unique function satisfying `emod x y + (ediv x y) * y = x`.
-/
def ediv : Int Int Int
@[extern "lean_int_ediv"]
def ediv : (@& Int) (@& Int) Int
| ofNat m, ofNat n => ofNat (m / n)
| ofNat m, -[n+1] => -ofNat (m / succ n)
| -[_+1], 0 => 0
@@ -143,7 +144,8 @@ Integer modulus. This version of `Int.mod` uses the E-rounding convention
(euclidean division), in which `Int.emod x y` satisfies `0 ≤ emod x y < natAbs y` for `y ≠ 0`
and `Int.ediv` is the unique function satisfying `emod x y + (ediv x y) * y = x`.
-/
def emod : Int Int Int
@[extern "lean_int_emod"]
def emod : (@& Int) (@& Int) Int
| ofNat m, n => ofNat (m % natAbs n)
| -[m+1], n => subNatNat (natAbs n) (succ (m % natAbs n))

View File

@@ -1320,6 +1320,8 @@ LEAN_SHARED lean_object * lean_int_big_sub(lean_object * a1, lean_object * a2);
LEAN_SHARED lean_object * lean_int_big_mul(lean_object * a1, lean_object * a2);
LEAN_SHARED lean_object * lean_int_big_div(lean_object * a1, lean_object * a2);
LEAN_SHARED lean_object * lean_int_big_mod(lean_object * a1, lean_object * a2);
LEAN_SHARED lean_object * lean_int_big_ediv(lean_object * a1, lean_object * a2);
LEAN_SHARED lean_object * lean_int_big_emod(lean_object * a1, lean_object * a2);
LEAN_SHARED bool lean_int_big_eq(lean_object * a1, lean_object * a2);
LEAN_SHARED bool lean_int_big_le(lean_object * a1, lean_object * a2);
LEAN_SHARED bool lean_int_big_lt(lean_object * a1, lean_object * a2);
@@ -1461,6 +1463,81 @@ static inline lean_obj_res lean_int_mod(b_lean_obj_arg a1, b_lean_obj_arg a2) {
}
}
/*
lean_int_ediv and lean_int_emod implement "Euclidean" division and modulus using the
algorithm in:
Division and Modulus for Computer Scientists
Daan Leijen
https://www.microsoft.com/en-us/research/publication/division-and-modulus-for-computer-scientists/
*/
static inline lean_obj_res lean_int_ediv(b_lean_obj_arg a1, b_lean_obj_arg a2) {
if (LEAN_LIKELY(lean_is_scalar(a1) && lean_is_scalar(a2))) {
if (sizeof(void*) == 8) {
/* 64-bit version, we use 64-bit numbers to avoid overflow when v1 == LEAN_MIN_SMALL_INT. */
int64_t n = lean_scalar_to_int(a1);
int64_t d = lean_scalar_to_int(a2);
if (d == 0)
return lean_box(0);
else {
int64_t q = n / d;
int64_t r = n % d;
if (r < 0)
q = (d > 0) ? q - 1 : q + 1;
return lean_int64_to_int(q);
}
} else {
/* 32-bit version */
int n = lean_scalar_to_int(a1);
int d = lean_scalar_to_int(a2);
if (d == 0) {
return lean_box(0);
} else {
int q = n / d;
int r = n % d;
if (r < 0)
q = (d > 0) ? q - 1 : q + 1;
return lean_int_to_int(q);
}
}
} else {
return lean_int_big_ediv(a1, a2);
}
}
static inline lean_obj_res lean_int_emod(b_lean_obj_arg a1, b_lean_obj_arg a2) {
if (LEAN_LIKELY(lean_is_scalar(a1) && lean_is_scalar(a2))) {
if (sizeof(void*) == 8) {
/* 64-bit version, we use 64-bit numbers to avoid overflow when v1 == LEAN_MIN_SMALL_INT. */
int64_t n = lean_scalar_to_int64(a1);
int64_t d = lean_scalar_to_int64(a2);
if (d == 0) {
return a1;
} else {
int64_t r = n % d;
if (r < 0)
r = (d > 0) ? r + d : r - d;
return lean_int64_to_int(r);
}
} else {
/* 32-bit version */
int n = lean_scalar_to_int(a1);
int d = lean_scalar_to_int(a2);
if (d == 0)
return a1;
else {
int r = n % d;
if (r < 0)
r = (d > 0) ? r + d : r - d;
return lean_int_to_int(r);
}
}
} else {
return lean_int_big_emod(a1, a2);
}
}
static inline bool lean_int_eq(b_lean_obj_arg a1, b_lean_obj_arg a2) {
if (LEAN_LIKELY(lean_is_scalar(a1) && lean_is_scalar(a2))) {
return a1 == a2;

View File

@@ -160,6 +160,43 @@ mpz & mpz::operator*=(unsigned u) { mpz_mul_ui(m_val, m_val, u); return *this; }
mpz & mpz::operator*=(int u) { mpz_mul_si(m_val, m_val, u); return *this; }
mpz mpz::ediv(mpz const & n, mpz const & d) {
mpz q;
mpz_t r;
mpz_init(r);
/* (q,r) = (n/d, n%d) */
mpz_tdiv_qr(q.m_val, r, n.m_val, d.m_val);
/* if (r < 0) */
if (mpz_sgn(r) < 0) {
if (mpz_sgn(d.m_val) > 0) {
/* q = q - 1. */
mpz_sub_ui(q.m_val, q.m_val, 1);
} else {
/* q = q + 1. */
mpz_add_ui(q.m_val, q.m_val, 1);
}
}
mpz_clear(r);
return q;
}
mpz mpz::emod(mpz const & n, mpz const & d) {
mpz r;
/* (q,r) = (n/d, n%d) */
mpz_tdiv_r(r.m_val, n.m_val, d.m_val);
/* if (r < 0) */
if (mpz_sgn(r.m_val) < 0) {
if (mpz_sgn(d.m_val) > 0) {
/* r = r + d. */
mpz_add(r.m_val, r.m_val, d.m_val);
} else {
/* r = r - d. */
mpz_sub(r.m_val, r.m_val, d.m_val);
}
}
return r;
}
mpz & mpz::operator/=(mpz const & o) { mpz_tdiv_q(m_val, m_val, o.m_val); return *this; }
mpz & mpz::operator/=(unsigned u) { mpz_tdiv_q_ui(m_val, m_val, u); return *this; }
@@ -630,6 +667,7 @@ mpz & mpz::rem(size_t sz, mpn_digit const * digits) {
digits, sz,
q1.begin(), r1.begin());
set(r_sz, r1.begin());
m_sign = m_sign && !is_zero();
return *this;
}
@@ -699,6 +737,53 @@ mpz & mpz::operator%=(mpz const & o) {
return rem(o.m_size, o.m_digits);
}
mpz mpz::ediv(mpz const & n, mpz const & d) {
if (d.m_size > n.m_size) {
if (n.is_neg()) {
int64_t r = d.is_pos() ? -1 : 1;
return mpz(r);
} else {
return mpz(0);
}
} else {
digit_buffer q1, r1;
size_t q_sz = n.m_size - d.m_size + 1;
size_t r_sz = d.m_size;
q1.ensure_capacity(q_sz);
r1.ensure_capacity(r_sz);
mpn_div(n.m_digits, n.m_size,
d.m_digits, d.m_size,
q1.begin(), r1.begin());
mpz q;
q.set(q_sz, q1.begin());
q.m_sign = !q.is_zero() && n.m_sign != d.m_sign;
mpz r;
r.set(r_sz, r1.begin());
r.m_sign = n.m_sign && !r.is_zero();
if (r.is_neg()) {
if (d.is_pos()) {
q -= 1;
} else {
q += 1;
}
}
return q;
}
}
mpz mpz::emod(mpz const & n, mpz const & d) {
mpz r(n);
r.rem(d.m_size, d.m_digits);
if (r.is_neg()) {
if (d.is_pos()) {
r += d;
} else {
r -= d;
}
}
return r;
}
mpz mpz::pow(unsigned int p) const {
unsigned mask = 1;
mpz power(*this);

View File

@@ -245,6 +245,14 @@ public:
friend mpz operator%(mpz a, mpz const & b) { return a %= b; }
static mpz ediv(mpz const & n, mpz const & d);
static mpz ediv(int n, mpz const & d) { return ediv(mpz(n), d); }
static mpz ediv(mpz const& n, int d) { return ediv(n, mpz(d)); }
static mpz emod(mpz const & n, mpz const & d);
static mpz emod(int n, mpz const & d) { return emod(mpz(n), d); }
static mpz emod(mpz const & n, int d) { return emod(n, mpz(d)); };
mpz & operator&=(mpz const & o);
mpz & operator|=(mpz const & o);
mpz & operator^=(mpz const & o);

View File

@@ -1432,6 +1432,36 @@ extern "C" LEAN_EXPORT object * lean_int_big_mod(object * a1, object * a2) {
}
}
extern "C" LEAN_EXPORT object * lean_int_big_ediv(object * a1, object * a2) {
if (lean_is_scalar(a1)) {
return mpz_to_int(mpz::ediv(lean_scalar_to_int(a1), mpz_value(a2)));
} else if (lean_is_scalar(a2)) {
int d = lean_scalar_to_int(a2);
if (d == 0)
return a2;
else
return mpz_to_int(mpz::ediv(mpz_value(a1), d));
} else {
return mpz_to_int(mpz::ediv(mpz_value(a1), mpz_value(a2)));
}
}
extern "C" LEAN_EXPORT object * lean_int_big_emod(object * a1, object * a2) {
if (lean_is_scalar(a1)) {
return mpz_to_int(mpz::emod(lean_scalar_to_int(a1), mpz_value(a2)));
} else if (lean_is_scalar(a2)) {
int i2 = lean_scalar_to_int(a2);
if (i2 == 0) {
lean_inc(a1);
return a1;
} else {
return mpz_to_int(mpz::emod(mpz_value(a1), i2));
}
} else {
return mpz_to_int(mpz::emod(mpz_value(a1), mpz_value(a2)));
}
}
extern "C" LEAN_EXPORT bool lean_int_big_eq(object * a1, object * a2) {
if (lean_is_scalar(a1)) {
lean_assert(lean_scalar_to_int(a1) != mpz_value(a2))

View File

@@ -0,0 +1,62 @@
-- Divide by zero tests
#guard ( 0 : Int) / 0 = 0
#guard ( 0 : Int) % 0 == 0
#guard ( 4 : Int) / 0 == 0
#guard ( 4 : Int) % 0 == 4
#guard (-4 : Int) / 0 == 0
#guard (-4 : Int) % 0 == -4
#guard ( 0 : Int) / 4 == 0
#guard ( 0 : Int) % 4 == 0
#guard ( 0 : Int) / -4 == 0
#guard ( 0 : Int) % -4 == 0
-- Euclidean division tests
#guard ( 4 : Int) / 3 == 1
#guard ( 4 : Int) % 3 == 1
#guard ( 5 : Int) / 3 == 1
#guard ( 5 : Int) % 3 == 2
#guard ( 6 : Int) / 3 == 2
#guard ( 6 : Int) % 3 == 0
#guard ( 7 : Int) / 4 == 1
#guard ( 7 : Int) % 4 == 3
#guard ( 4 : Int) / -3 == -1
#guard ( 4 : Int) % -3 == 1
#guard ( 5 : Int) / -3 == -1
#guard ( 5 : Int) % -3 == 2
#guard ( 6 : Int) / -3 == -2
#guard ( 6 : Int) % -3 == 0
#guard ( 7 : Int) / -4 == -1
#guard ( 7 : Int) % -4 == 3
#guard (-4 : Int) / 3 == -2
#guard (-4 : Int) % 3 == 2
#guard (-5 : Int) / 3 == -2
#guard (-5 : Int) % 3 == 1
#guard (-6 : Int) / 3 == -2
#guard (-6 : Int) % 3 == 0
#guard (-7 : Int) / 4 == -2
#guard (-7 : Int) % 4 == 1
#guard (-4 : Int) / -3 == 2
#guard (-4 : Int) % -3 == 2
#guard (-5 : Int) / -3 == 2
#guard (-5 : Int) % -3 == 1
#guard (-6 : Int) / -3 == 2
#guard (-6 : Int) % -3 == 0
#guard (-7 : Int) / -4 == 2
#guard (-7 : Int) % -4 == 1
-- Basic big integer tests
#guard let n : Int := 0; let d : Int := 2^64; n / d = 0 n % d = n
#guard let n : Int := 1; let d : Int := 2^64; n / d = 0 n % d = n
#guard let n : Int := -1; let d : Int := 2^64; n / d = -1 n % d = (d + n)
#guard let n : Int := 2^128; let d : Int := 3; d * (n / d) + n % d = n n % d 0 n % d < d
#guard let n : Int := 2^128; let d : Int := 2^64; d * (n / d) + n % d = n n % d 0 n % d < d
#guard let n : Int := -2^128; let d : Int := 2^64; d * (n / d) + n % d = n n % d 0 n % d < d
#guard let n : Int := 2^128; let d : Int := -2^64; d * (n / d) + n % d = n n % d 0 n % d < d.natAbs
#guard let n : Int := -2^128; let d : Int := -2^64; d * (n / d) + n % d = n n % d 0 n % d < d.natAbs
#guard let n : Int := 2^128+7; let d : Int := 2^64; d * (n / d) + n % d = n n % d 0 n % d < d
#guard let n : Int := -2^128+3; let d : Int := 2^64; d * (n / d) + n % d = n n % d 0 n % d < d
#guard let n : Int := 2^128+2; let d : Int := -2^64; d * (n / d) + n % d = n n % d 0 n % d < d.natAbs
#guard let n : Int := -2^128+2; let d : Int := -2^64; d * (n / d) + n % d = n n % d 0 n % d < d.natAbs

View File