From 9e1e8cd6428e875eb29be98124ee3b1ba2bace30 Mon Sep 17 00:00:00 2001 From: xbe Date: Wed, 12 Mar 2014 22:57:16 -0700 Subject: [PATCH] Implement str.count and add tests for it. Also modify mp_get_index to accept: 1. Indices that are or evaluate to a boolean. 2. Slice indices. Add tests for these two cases. --- py/obj.c | 32 ++++++++++++++++-------- py/obj.h | 2 +- py/objarray.c | 4 +-- py/objlist.c | 6 ++--- py/objstr.c | 48 +++++++++++++++++++++++++++++++++--- py/objtuple.c | 2 +- py/sequence.c | 4 +-- tests/basics/list_index.py | 10 ++++++++ tests/basics/string_count.py | 21 ++++++++++++++++ tests/basics/string_find.py | 12 +++++++++ 10 files changed, 119 insertions(+), 22 deletions(-) create mode 100644 tests/basics/string_count.py diff --git a/py/obj.c b/py/obj.c index 0c97ee5aa..bdbf8f797 100644 --- a/py/obj.c +++ b/py/obj.c @@ -218,20 +218,32 @@ mp_obj_t *mp_obj_get_array_fixed_n(mp_obj_t o_in, machine_int_t n) { } } -uint mp_get_index(const mp_obj_type_t *type, machine_uint_t len, mp_obj_t index) { - // TODO False and True are considered 0 and 1 for indexing purposes +// is_slice determines whether the index is a slice index +uint mp_get_index(const mp_obj_type_t *type, machine_uint_t len, mp_obj_t index, bool is_slice) { + int i; if (MP_OBJ_IS_SMALL_INT(index)) { - int i = MP_OBJ_SMALL_INT_VALUE(index); - if (i < 0) { - i += len; - } - if (i < 0 || i >= len) { - nlr_jump(mp_obj_new_exception_msg_varg(&mp_type_IndexError, "%s index out of range", qstr_str(type->name))); - } - return i; + i = MP_OBJ_SMALL_INT_VALUE(index); + } else if (MP_OBJ_IS_TYPE(index, &bool_type)) { + i = index == mp_const_true ? 1 : 0; } else { nlr_jump(mp_obj_new_exception_msg_varg(&mp_type_TypeError, "%s indices must be integers, not %s", qstr_str(type->name), mp_obj_get_type_str(index))); } + + if (i < 0) { + i += len; + } + if (is_slice) { + if (i < 0) { + i = 0; + } else if (i > len) { + i = len; + } + } else { + if (i < 0 || i >= len) { + nlr_jump(mp_obj_new_exception_msg_varg(&mp_type_IndexError, "%s index out of range", qstr_str(type->name))); + } + } + return i; } // may return MP_OBJ_NULL diff --git a/py/obj.h b/py/obj.h index d41db37c0..dd80b3f02 100644 --- a/py/obj.h +++ b/py/obj.h @@ -267,7 +267,7 @@ void mp_obj_get_complex(mp_obj_t self_in, mp_float_t *real, mp_float_t *imag); #endif //qstr mp_obj_get_qstr(mp_obj_t arg); mp_obj_t *mp_obj_get_array_fixed_n(mp_obj_t o, machine_int_t n); -uint mp_get_index(const mp_obj_type_t *type, machine_uint_t len, mp_obj_t index); +uint mp_get_index(const mp_obj_type_t *type, machine_uint_t len, mp_obj_t index, bool is_slice); mp_obj_t mp_obj_len_maybe(mp_obj_t o_in); /* may return NULL */ // none diff --git a/py/objarray.c b/py/objarray.c index 30a218311..d0b3e003b 100644 --- a/py/objarray.c +++ b/py/objarray.c @@ -113,7 +113,7 @@ STATIC mp_obj_t array_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) { switch (op) { case RT_BINARY_OP_SUBSCR: { - uint index = mp_get_index(o->base.type, o->len, rhs); + uint index = mp_get_index(o->base.type, o->len, rhs, false); return mp_binary_get_val(o->typecode, o->items, index); } @@ -140,7 +140,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(array_append_obj, array_append); STATIC bool array_store_item(mp_obj_t self_in, mp_obj_t index_in, mp_obj_t value) { mp_obj_array_t *o = self_in; - uint index = mp_get_index(o->base.type, o->len, index_in); + uint index = mp_get_index(o->base.type, o->len, index_in, false); mp_binary_set_val(o->typecode, o->items, index, value); return true; } diff --git a/py/objlist.c b/py/objlist.c index a6fbe4e42..aa082ea34 100644 --- a/py/objlist.c +++ b/py/objlist.c @@ -104,7 +104,7 @@ STATIC mp_obj_t list_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) { return res; } #endif - uint index = mp_get_index(o->base.type, o->len, rhs); + uint index = mp_get_index(o->base.type, o->len, rhs, false); return o->items[index]; } case RT_BINARY_OP_ADD: @@ -190,7 +190,7 @@ STATIC mp_obj_t list_pop(uint n_args, const mp_obj_t *args) { if (self->len == 0) { nlr_jump(mp_obj_new_exception_msg(&mp_type_IndexError, "pop from empty list")); } - uint index = mp_get_index(self->base.type, self->len, n_args == 1 ? mp_obj_new_int(-1) : args[1]); + uint index = mp_get_index(self->base.type, self->len, n_args == 1 ? mp_obj_new_int(-1) : args[1], false); mp_obj_t ret = self->items[index]; self->len -= 1; memcpy(self->items + index, self->items + index + 1, (self->len - index) * sizeof(mp_obj_t)); @@ -383,7 +383,7 @@ void mp_obj_list_get(mp_obj_t self_in, uint *len, mp_obj_t **items) { void mp_obj_list_store(mp_obj_t self_in, mp_obj_t index, mp_obj_t value) { mp_obj_list_t *self = self_in; - uint i = mp_get_index(self->base.type, self->len, index); + uint i = mp_get_index(self->base.type, self->len, index, false); self->items[i] = value; } diff --git a/py/objstr.c b/py/objstr.c index 130c3266a..c5c7f87f6 100644 --- a/py/objstr.c +++ b/py/objstr.c @@ -107,7 +107,7 @@ STATIC mp_obj_t str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { // TODO: need predicate to check for int-like type (bools are such for example) // ["no", "yes"][1 == 2] is common idiom if (MP_OBJ_IS_SMALL_INT(rhs_in)) { - uint index = mp_get_index(mp_obj_get_type(lhs_in), lhs_len, rhs_in); + uint index = mp_get_index(mp_obj_get_type(lhs_in), lhs_len, rhs_in, false); if (MP_OBJ_IS_TYPE(lhs_in, &bytes_type)) { return MP_OBJ_NEW_SMALL_INT((mp_small_int_t)lhs_data[index]); } else { @@ -290,10 +290,10 @@ STATIC mp_obj_t str_find(uint n_args, const mp_obj_t *args) { size_t end = haystack_len; /* TODO use a non-exception-throwing mp_get_index */ if (n_args >= 3 && args[2] != mp_const_none) { - start = mp_get_index(&str_type, haystack_len, args[2]); + start = mp_get_index(&str_type, haystack_len, args[2], true); } if (n_args >= 4 && args[3] != mp_const_none) { - end = mp_get_index(&str_type, haystack_len, args[3]); + end = mp_get_index(&str_type, haystack_len, args[3], true); } const byte *p = find_subbytes(haystack + start, haystack_len - start, needle, needle_len); @@ -487,6 +487,46 @@ STATIC mp_obj_t str_replace(uint n_args, const mp_obj_t *args) { return mp_obj_str_builder_end(replaced_str); } +STATIC mp_obj_t str_count(uint n_args, const mp_obj_t *args) { + assert(2 <= n_args && n_args <= 4); + assert(MP_OBJ_IS_STR(args[0])); + assert(MP_OBJ_IS_STR(args[1])); + + GET_STR_DATA_LEN(args[0], haystack, haystack_len); + GET_STR_DATA_LEN(args[1], needle, needle_len); + + size_t start = 0; + size_t end = haystack_len; + /* TODO use a non-exception-throwing mp_get_index */ + if (n_args >= 3 && args[2] != mp_const_none) { + start = mp_get_index(&str_type, haystack_len, args[2], true); + } + if (n_args >= 4 && args[3] != mp_const_none) { + end = mp_get_index(&str_type, haystack_len, args[3], true); + } + + machine_int_t num_occurrences = 0; + + // needle won't exist in haystack if it's longer, so nothing to count + if (needle_len > haystack_len) { + MP_OBJ_NEW_SMALL_INT(0); + } + + for (machine_uint_t haystack_index = start; haystack_index <= end; haystack_index++) { + for (machine_uint_t needle_index = 0; needle_index < needle_len; needle_index++) { + if ((haystack_index + needle_len) > end) { + return MP_OBJ_NEW_SMALL_INT(num_occurrences); + } + if (haystack[haystack_index + needle_index] == needle[needle_index] && needle_index == (needle_len - 1)) { + num_occurrences++; + } + + } + } + + return MP_OBJ_NEW_SMALL_INT(num_occurrences); +} + STATIC machine_int_t str_get_buffer(mp_obj_t self_in, buffer_info_t *bufinfo, int flags) { if (flags == BUFFER_READ) { GET_STR_DATA_LEN(self_in, str_data, str_len); @@ -508,6 +548,7 @@ STATIC MP_DEFINE_CONST_FUN_OBJ_2(str_startswith_obj, str_startswith); STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(str_strip_obj, 1, 2, str_strip); STATIC MP_DEFINE_CONST_FUN_OBJ_VAR(str_format_obj, 1, str_format); STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(str_replace_obj, 3, 4, str_replace); +STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(str_count_obj, 2, 4, str_count); STATIC const mp_method_t str_type_methods[] = { { "find", &str_find_obj }, @@ -517,6 +558,7 @@ STATIC const mp_method_t str_type_methods[] = { { "strip", &str_strip_obj }, { "format", &str_format_obj }, { "replace", &str_replace_obj }, + { "count", &str_count_obj }, { NULL, NULL }, // end-of-list sentinel }; diff --git a/py/objtuple.c b/py/objtuple.c index d39b36d9f..827441f70 100644 --- a/py/objtuple.c +++ b/py/objtuple.c @@ -111,7 +111,7 @@ mp_obj_t tuple_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) { return res; } #endif - uint index = mp_get_index(o->base.type, o->len, rhs); + uint index = mp_get_index(o->base.type, o->len, rhs, false); return o->items[index]; } case RT_BINARY_OP_ADD: diff --git a/py/sequence.c b/py/sequence.c index d3c3e3285..1723bb416 100644 --- a/py/sequence.c +++ b/py/sequence.c @@ -149,9 +149,9 @@ mp_obj_t mp_seq_index_obj(const mp_obj_t *items, uint len, uint n_args, const mp uint stop = len; if (n_args >= 3) { - start = mp_get_index(type, len, args[2]); + start = mp_get_index(type, len, args[2], true); if (n_args >= 4) { - stop = mp_get_index(type, len, args[3]); + stop = mp_get_index(type, len, args[3], true); } } diff --git a/tests/basics/list_index.py b/tests/basics/list_index.py index f28263fba..a669e69c4 100644 --- a/tests/basics/list_index.py +++ b/tests/basics/list_index.py @@ -3,6 +3,16 @@ print(a.index(1)) print(a.index(2)) print(a.index(3)) print(a.index(3, 2)) +print(a.index(1, -100)) +print(a.index(1, False)) + +try: + print(a.index(1, True)) +except ValueError: + print("Raised ValueError") +else: + print("Did not raise ValueError") + try: print(a.index(3, 2, 2)) except ValueError: diff --git a/tests/basics/string_count.py b/tests/basics/string_count.py new file mode 100644 index 000000000..42f807c93 --- /dev/null +++ b/tests/basics/string_count.py @@ -0,0 +1,21 @@ +print("asdfasdfaaa".count("asdf", -100)) +print("asdfasdfaaa".count("asdf", -8)) +print("asdf".count('s', True)) +print("asdf".count('a', True)) +print("asdf".count('a', False)) +print("asdf".count('a', 1 == 2)) +print("hello world".count('l')) +print("hello world".count('l', 5)) +print("hello world".count('l', 3)) +print("hello world".count('z', 3, 6)) +print("aaaa".count('a')) +print("aaaa".count('a', 0, 3)) +print("aaaa".count('a', 0, 4)) +print("aaaa".count('a', 0, 5)) +print("aaaa".count('a', 1, 5)) +print("aaaa".count('a', -1, 5)) + +def t(): + return True + +print("0000".count('0', t())) diff --git a/tests/basics/string_find.py b/tests/basics/string_find.py index 90063228f..df65fd6e5 100644 --- a/tests/basics/string_find.py +++ b/tests/basics/string_find.py @@ -9,3 +9,15 @@ print("hello world".find("ll", 1, 2)) print("hello world".find("ll", 1, 3)) print("hello world".find("ll", 1, 4)) print("hello world".find("ll", 1, 5)) +print("hello world".find("ll", -100)) +print("0000".find('0')) +print("0000".find('0', 0)) +print("0000".find('0', 1)) +print("0000".find('0', 2)) +print("0000".find('0', 3)) +print("0000".find('0', 4)) +print("0000".find('0', 5)) +print("0000".find('-1', 3)) +print("0000".find('1', 3)) +print("0000".find('1', 4)) +print("0000".find('1', 5))