py/objarray: Prohibit comparison of mismatching types.

Array equality is defined as each element being equal but to keep
code size down MicroPython implements a binary comparison.  This
can only be used correctly for elements with the same binary layout
though so turn it into an NotImplementedError when comparing types
for which the binary comparison yielded incorrect results: types
with different sizes, and floating point numbers because nan != nan.
This commit is contained in:
stijn 2021-05-12 17:02:06 +02:00 committed by Damien George
parent 6affcb0104
commit 57365d8557
4 changed files with 54 additions and 1 deletions

View File

@ -258,6 +258,16 @@ STATIC mp_obj_t array_unary_op(mp_unary_op_t op, mp_obj_t o_in) {
}
}
STATIC int typecode_for_comparison(int typecode) {
if (typecode == BYTEARRAY_TYPECODE) {
typecode = 'B';
}
if (typecode <= 'Z') {
typecode += 32; // to lowercase
}
return typecode;
}
STATIC mp_obj_t array_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
mp_obj_array_t *lhs = MP_OBJ_TO_PTR(lhs_in);
switch (op) {
@ -319,7 +329,20 @@ STATIC mp_obj_t array_binary_op(mp_binary_op_t op, mp_obj_t lhs_in, mp_obj_t rhs
if (!mp_get_buffer(rhs_in, &rhs_bufinfo, MP_BUFFER_READ)) {
return mp_const_false;
}
return mp_obj_new_bool(mp_seq_cmp_bytes(op, lhs_bufinfo.buf, lhs_bufinfo.len, rhs_bufinfo.buf, rhs_bufinfo.len));
// mp_seq_cmp_bytes is used so only compatible representations can be correctly compared.
// The type doesn't matter: array/bytearray/str/bytes all have the same buffer layout, so
// just check if the typecodes are compatible; for testing equality the types should have the
// same code except for signedness, and not be floating point because nan never equals nan.
// Note that typecode_for_comparison always returns lowercase letters to save code size.
// No need for (& TYPECODE_MASK) here: xxx_get_buffer already takes care of that.
const int lhs_code = typecode_for_comparison(lhs_bufinfo.typecode);
const int rhs_code = typecode_for_comparison(rhs_bufinfo.typecode);
if (lhs_code == rhs_code && lhs_code != 'f' && lhs_code != 'd') {
return mp_obj_new_bool(mp_seq_cmp_bytes(op, lhs_bufinfo.buf, lhs_bufinfo.len, rhs_bufinfo.buf, rhs_bufinfo.len));
}
// mp_obj_equal_not_equal treats returning MP_OBJ_NULL as 'fall back to pointer comparison'
// for MP_BINARY_OP_EQUAL but that is incompatible with CPython.
mp_raise_NotImplementedError(NULL);
}
default:

View File

@ -41,14 +41,23 @@ except ValueError:
# equality (CPython requires both sides are array)
print(bytes(array.array('b', [0x61, 0x62, 0x63])) == b'abc')
print(array.array('b', [0x61, 0x62, 0x63]) == b'abc')
print(array.array('B', [0x61, 0x62, 0x63]) == b'abc')
print(array.array('b', [0x61, 0x62, 0x63]) != b'abc')
print(array.array('b', [0x61, 0x62, 0x63]) == b'xyz')
print(array.array('b', [0x61, 0x62, 0x63]) != b'xyz')
print(b'abc' == array.array('b', [0x61, 0x62, 0x63]))
print(b'abc' == array.array('B', [0x61, 0x62, 0x63]))
print(b'abc' != array.array('b', [0x61, 0x62, 0x63]))
print(b'xyz' == array.array('b', [0x61, 0x62, 0x63]))
print(b'xyz' != array.array('b', [0x61, 0x62, 0x63]))
compatible_typecodes = []
for t in ["b", "h", "i", "l", "q"]:
compatible_typecodes.append((t, t))
compatible_typecodes.append((t, t.upper()))
for a, b in compatible_typecodes:
print(array.array(a, [1, 2]) == array.array(b, [1, 2]))
class X(array.array):
pass

View File

@ -17,3 +17,15 @@ print(a[0])
a = array.array('P')
a.append(1)
print(a[0])
# comparison between mismatching binary layouts is not implemented
typecodes = ["b", "h", "i", "l", "q", "P", "O", "S", "f", "d"]
for a in typecodes:
for b in typecodes:
if a == b and a not in ["f", "d"]:
continue
try:
array.array(a) == array.array(b)
print('FAIL')
except NotImplementedError:
pass

View File

@ -0,0 +1,9 @@
"""
categories: Modules,array
description: Comparison between different typecodes not supported
cause: Code size
workaround: Compare individual elements
"""
import array
array.array("b", [1, 2]) == array.array("i", [1, 2])