diff --git a/py/builtintables.c b/py/builtintables.c index 9a46be6e0..ea864c6c7 100644 --- a/py/builtintables.c +++ b/py/builtintables.c @@ -17,6 +17,7 @@ STATIC const mp_map_elem_t mp_builtin_object_table[] = { // built-in types { MP_OBJ_NEW_QSTR(MP_QSTR_bool), (mp_obj_t)&mp_type_bool }, { MP_OBJ_NEW_QSTR(MP_QSTR_bytes), (mp_obj_t)&mp_type_bytes }, + { MP_OBJ_NEW_QSTR(MP_QSTR_bytearray), (mp_obj_t)&mp_type_bytearray }, #if MICROPY_ENABLE_FLOAT { MP_OBJ_NEW_QSTR(MP_QSTR_complex), (mp_obj_t)&mp_type_complex }, #endif @@ -72,7 +73,6 @@ STATIC const mp_map_elem_t mp_builtin_object_table[] = { { MP_OBJ_NEW_QSTR(MP_QSTR_repr), (mp_obj_t)&mp_builtin_repr_obj }, { MP_OBJ_NEW_QSTR(MP_QSTR_sorted), (mp_obj_t)&mp_builtin_sorted_obj }, { MP_OBJ_NEW_QSTR(MP_QSTR_sum), (mp_obj_t)&mp_builtin_sum_obj }, - { MP_OBJ_NEW_QSTR(MP_QSTR_bytearray), (mp_obj_t)&mp_builtin_bytearray_obj }, // built-in exceptions { MP_OBJ_NEW_QSTR(MP_QSTR_BaseException), (mp_obj_t)&mp_type_BaseException }, diff --git a/py/obj.h b/py/obj.h index 77cf7838e..2446f6375 100644 --- a/py/obj.h +++ b/py/obj.h @@ -259,6 +259,7 @@ extern const mp_obj_type_t mp_type_bool; extern const mp_obj_type_t mp_type_int; extern const mp_obj_type_t mp_type_str; extern const mp_obj_type_t mp_type_bytes; +extern const mp_obj_type_t mp_type_bytearray; extern const mp_obj_type_t mp_type_float; extern const mp_obj_type_t mp_type_complex; extern const mp_obj_type_t mp_type_tuple; diff --git a/py/objarray.c b/py/objarray.c index 4f9fa49bc..c65673fff 100644 --- a/py/objarray.c +++ b/py/objarray.c @@ -88,19 +88,20 @@ STATIC mp_obj_t array_make_new(mp_obj_t type_in, uint n_args, uint n_kw, const m return array_construct(*typecode, args[1]); } -// This is top-level factory function, not virtual method -// TODO: "bytearray" really should be type, not function -STATIC mp_obj_t mp_builtin_bytearray(mp_obj_t arg) { - if (MP_OBJ_IS_SMALL_INT(arg)) { - uint len = MP_OBJ_SMALL_INT_VALUE(arg); +STATIC mp_obj_t bytearray_make_new(mp_obj_t type_in, uint n_args, uint n_kw, const mp_obj_t *args) { + if (n_args > 1) { + nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_TypeError, "unexpected # of arguments, %d given", n_args)); + } + + if (MP_OBJ_IS_SMALL_INT(args[0])) { + uint len = MP_OBJ_SMALL_INT_VALUE(args[0]); mp_obj_array_t *o = array_new(BYTEARRAY_TYPECODE, len); memset(o->items, 0, len); return o; } - return array_construct(BYTEARRAY_TYPECODE, arg); + return array_construct(BYTEARRAY_TYPECODE, args[0]); } -MP_DEFINE_CONST_FUN_OBJ_1(mp_builtin_bytearray_obj, mp_builtin_bytearray); STATIC mp_obj_t array_unary_op(int op, mp_obj_t o_in) { mp_obj_array_t *o = o_in; @@ -127,7 +128,7 @@ STATIC mp_obj_t array_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) { } STATIC mp_obj_t array_append(mp_obj_t self_in, mp_obj_t arg) { - assert(MP_OBJ_IS_TYPE(self_in, &mp_type_array)); + assert(MP_OBJ_IS_TYPE(self_in, &mp_type_array) || MP_OBJ_IS_TYPE(self_in, &mp_type_bytearray)); mp_obj_array_t *self = self_in; if (self->free == 0) { int item_sz = mp_binary_get_size(self->typecode); @@ -174,9 +175,22 @@ const mp_obj_type_t mp_type_array = { .locals_dict = (mp_obj_t)&array_locals_dict, }; +const mp_obj_type_t mp_type_bytearray = { + { &mp_type_type }, + .name = MP_QSTR_bytearray, + .print = array_print, + .make_new = bytearray_make_new, + .getiter = array_iterator_new, + .unary_op = array_unary_op, + .binary_op = array_binary_op, + .store_item = array_store_item, + .buffer_p = { .get_buffer = array_get_buffer }, + .locals_dict = (mp_obj_t)&array_locals_dict, +}; + STATIC mp_obj_array_t *array_new(char typecode, uint n) { mp_obj_array_t *o = m_new_obj(mp_obj_array_t); - o->base.type = &mp_type_array; + o->base.type = (typecode == BYTEARRAY_TYPECODE) ? &mp_type_bytearray : &mp_type_array; o->typecode = typecode; o->free = 0; o->len = n; diff --git a/py/objarray.h b/py/objarray.h index 620426ce3..0f6ede86f 100644 --- a/py/objarray.h +++ b/py/objarray.h @@ -1,3 +1 @@ -MP_DECLARE_CONST_FUN_OBJ(mp_builtin_bytearray_obj); - mp_obj_t mp_obj_new_bytearray(uint n, void *items); diff --git a/tests/basics/bytearray1.py b/tests/basics/bytearray1.py index 201b5b659..e564165b9 100644 --- a/tests/basics/bytearray1.py +++ b/tests/basics/bytearray1.py @@ -1,5 +1,6 @@ print(bytearray(4)) a = bytearray([1, 2, 200]) +print(type(a)) print(a[0], a[2]) print(a[-1]) print(a)