From 2c139bbf4e5724ab253b5b034ce925e04267a9c4 Mon Sep 17 00:00:00 2001 From: Damien George Date: Tue, 30 Nov 2021 00:31:46 +1100 Subject: [PATCH] py/mpz: Fix bugs with bitwise of -0 by ensuring all 0's are positive. This commit makes sure that the value zero is always encoded in an mpz_t as neg=0 and len=0 (previously it was just len=0). This invariant is needed for some of the bitwise operations that operate on negative numbers, because they cannot handle -0. For example (-((1<<100)-(1<<100)))|1 was being computed as -65535, instead of 1. Fixes issue #8042. Signed-off-by: Damien George --- py/mpz.c | 19 ++++++++------- py/mpz.h | 3 ++- tests/basics/int_big_zeroone.py | 41 +++++++++++++++++++++++++++++++-- 3 files changed, 52 insertions(+), 11 deletions(-) diff --git a/py/mpz.c b/py/mpz.c index 75e1fb1fd..b61997e2f 100644 --- a/py/mpz.c +++ b/py/mpz.c @@ -713,6 +713,7 @@ void mpz_set(mpz_t *dest, const mpz_t *src) { void mpz_set_from_int(mpz_t *z, mp_int_t val) { if (val == 0) { + z->neg = 0; z->len = 0; return; } @@ -899,10 +900,6 @@ bool mpz_is_even(const mpz_t *z) { #endif int mpz_cmp(const mpz_t *z1, const mpz_t *z2) { - // to catch comparison of -0 with +0 - if (z1->len == 0 && z2->len == 0) { - return 0; - } int cmp = (int)z2->neg - (int)z1->neg; if (cmp != 0) { return cmp; @@ -1052,7 +1049,9 @@ void mpz_neg_inpl(mpz_t *dest, const mpz_t *z) { if (dest != z) { mpz_set(dest, z); } - dest->neg = 1 - dest->neg; + if (dest->len) { + dest->neg = 1 - dest->neg; + } } /* computes dest = ~z (= -z - 1) @@ -1148,7 +1147,7 @@ void mpz_add_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) { dest->len = mpn_sub(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len); } - dest->neg = lhs->neg; + dest->neg = lhs->neg & !!dest->len; } /* computes dest = lhs - rhs @@ -1172,7 +1171,9 @@ void mpz_sub_inpl(mpz_t *dest, const mpz_t *lhs, const mpz_t *rhs) { dest->len = mpn_sub(dest->dig, lhs->dig, lhs->len, rhs->dig, rhs->len); } - if (neg) { + if (dest->len == 0) { + dest->neg = 0; + } else if (neg) { dest->neg = 1 - lhs->neg; } else { dest->neg = lhs->neg; @@ -1484,14 +1485,16 @@ void mpz_divmod_inpl(mpz_t *dest_quo, mpz_t *dest_rem, const mpz_t *lhs, const m mpz_need_dig(dest_quo, lhs->len + 1); // +1 necessary? memset(dest_quo->dig, 0, (lhs->len + 1) * sizeof(mpz_dig_t)); + dest_quo->neg = 0; dest_quo->len = 0; mpz_need_dig(dest_rem, lhs->len + 1); // +1 necessary? mpz_set(dest_rem, lhs); mpn_div(dest_rem->dig, &dest_rem->len, rhs->dig, rhs->len, dest_quo->dig, &dest_quo->len); + dest_rem->neg &= !!dest_rem->len; // check signs and do Python style modulo if (lhs->neg != rhs->neg) { - dest_quo->neg = 1; + dest_quo->neg = !!dest_quo->len; if (!mpz_is_zero(dest_rem)) { mpz_t mpzone; mpz_init_from_int(&mpzone, -1); diff --git a/py/mpz.h b/py/mpz.h index 425587ee9..d27f57240 100644 --- a/py/mpz.h +++ b/py/mpz.h @@ -91,6 +91,7 @@ typedef int8_t mpz_dbl_dig_signed_t; #define MPZ_NUM_DIG_FOR_LL ((sizeof(long long) * 8 + MPZ_DIG_SIZE - 1) / MPZ_DIG_SIZE) typedef struct _mpz_t { + // Zero has neg=0, len=0. Negative zero is not allowed. size_t neg : 1; size_t fixed_dig : 1; size_t alloc : (8 * sizeof(size_t) - 2); @@ -119,7 +120,7 @@ static inline bool mpz_is_zero(const mpz_t *z) { return z->len == 0; } static inline bool mpz_is_neg(const mpz_t *z) { - return z->len != 0 && z->neg != 0; + return z->neg != 0; } int mpz_cmp(const mpz_t *lhs, const mpz_t *rhs); diff --git a/tests/basics/int_big_zeroone.py b/tests/basics/int_big_zeroone.py index 7e0b7a720..81381526b 100644 --- a/tests/basics/int_big_zeroone.py +++ b/tests/basics/int_big_zeroone.py @@ -1,4 +1,4 @@ -# test [0,-0,1,-1] edge cases of bignum +# test [0,1,-1] edge cases of bignum long_zero = (2**64) >> 65 long_neg_zero = -long_zero @@ -13,7 +13,7 @@ print([~c for c in cases]) print([c >> 1 for c in cases]) print([c << 1 for c in cases]) -# comparison of 0/-0/+0 +# comparison of 0 print(long_zero == 0) print(long_neg_zero == 0) print(long_one - 1 == 0) @@ -26,3 +26,40 @@ print(long_neg_zero < 1) print(long_neg_zero < -1) print(long_neg_zero > 1) print(long_neg_zero > -1) + +# generate zeros that involve negative numbers +large = 1 << 70 +large_plus_one = large + 1 +zeros = ( + large - large, + -large + large, + large + -large, + -(large - large), + large - large_plus_one + 1, + -large & (large - large), + -large ^ -large, + -large * (large - large), + (large - large) // -large, + -large // -large_plus_one, + -(large + large) % large, + (large + large) % -large, + -(large + large) % -large, +) +print(zeros) + +# compute arithmetic operations that may have problems with -0 +# (this checks that -0 is never generated in the zeros tuple) +cases = (0, 1, -1) + zeros +for lhs in cases: + print("-{} = {}".format(lhs, -lhs)) + print("~{} = {}".format(lhs, ~lhs)) + print("{} >> 1 = {}".format(lhs, lhs >> 1)) + print("{} << 1 = {}".format(lhs, lhs << 1)) + for rhs in cases: + print("{} == {} = {}".format(lhs, rhs, lhs == rhs)) + print("{} + {} = {}".format(lhs, rhs, lhs + rhs)) + print("{} - {} = {}".format(lhs, rhs, lhs - rhs)) + print("{} * {} = {}".format(lhs, rhs, lhs * rhs)) + print("{} | {} = {}".format(lhs, rhs, lhs | rhs)) + print("{} & {} = {}".format(lhs, rhs, lhs & rhs)) + print("{} ^ {} = {}".format(lhs, rhs, lhs ^ rhs))