py/objtuple: Properly implement comparison with incompatible types.

Should raise TypeError, unless it's (in)equality comparison.
This commit is contained in:
Paul Sokolovsky 2017-09-06 00:23:41 +03:00
parent e354b0a0cb
commit 1aaba5cabe
2 changed files with 17 additions and 4 deletions

View File

@ -101,7 +101,7 @@ STATIC mp_obj_t mp_obj_tuple_make_new(const mp_obj_type_t *type_in, size_t n_arg
}
// Don't pass MP_BINARY_OP_NOT_EQUAL here
STATIC bool tuple_cmp_helper(mp_uint_t op, mp_obj_t self_in, mp_obj_t another_in) {
STATIC mp_obj_t tuple_cmp_helper(mp_uint_t op, mp_obj_t self_in, mp_obj_t another_in) {
// type check is done on getiter method to allow tuple, namedtuple, attrtuple
mp_check_self(mp_obj_get_type(self_in)->getiter == mp_obj_tuple_getiter);
mp_obj_type_t *another_type = mp_obj_get_type(another_in);
@ -110,12 +110,15 @@ STATIC bool tuple_cmp_helper(mp_uint_t op, mp_obj_t self_in, mp_obj_t another_in
// Slow path for user subclasses
another_in = mp_instance_cast_to_native_base(another_in, MP_OBJ_FROM_PTR(&mp_type_tuple));
if (another_in == MP_OBJ_NULL) {
return false;
if (op == MP_BINARY_OP_EQUAL) {
return mp_const_false;
}
return MP_OBJ_NULL;
}
}
mp_obj_tuple_t *another = MP_OBJ_TO_PTR(another_in);
return mp_seq_cmp_objs(op, self->items, self->len, another->items, another->len);
return mp_obj_new_bool(mp_seq_cmp_objs(op, self->items, self->len, another->items, another->len));
}
mp_obj_t mp_obj_tuple_unary_op(mp_unary_op_t op, mp_obj_t self_in) {
@ -166,7 +169,7 @@ mp_obj_t mp_obj_tuple_binary_op(mp_binary_op_t op, mp_obj_t lhs, mp_obj_t rhs) {
case MP_BINARY_OP_LESS_EQUAL:
case MP_BINARY_OP_MORE:
case MP_BINARY_OP_MORE_EQUAL:
return mp_obj_new_bool(tuple_cmp_helper(op, lhs, rhs));
return tuple_cmp_helper(op, lhs, rhs);
default:
return MP_OBJ_NULL; // op not supported

View File

@ -53,3 +53,13 @@ print((10, 0) > (1, 1))
print((10, 0) < (1, 1))
print((0, 0, 10, 0) > (0, 0, 1, 1))
print((0, 0, 10, 0) < (0, 0, 1, 1))
print(() == {})
print(() != {})
print((1,) == [1])
try:
print(() < {})
except TypeError:
print("TypeError")