py/parse: Improve constant folding to operate on small and big ints.

Constant folding in the parser can now operate on big ints, whatever
their representation.  This is now possible because the parser can create
parse nodes holding arbitrary objects.  For the case of small ints the
folding is still efficient in RAM because the folded small int is stored
inplace in the parse node.

Adds 48 bytes to code size on Thumb2 architecture.  Helps reduce heap
usage because more constants can be computed at compile time, leading to
a smaller parse tree, and most importantly means that the constants don't
have to be computed at runtime (perhaps more than once).  Parser will now
be a little slower when folding due to calls to runtime to do the
arithmetic.
This commit is contained in:
Damien George 2016-01-07 14:40:35 +00:00
parent d6b31e4578
commit 22b2265053
3 changed files with 85 additions and 83 deletions

View File

@ -34,8 +34,9 @@
#include "py/lexer.h"
#include "py/parse.h"
#include "py/parsenum.h"
#include "py/smallint.h"
#include "py/runtime0.h"
#include "py/runtime.h"
#include "py/objint.h"
#include "py/builtin.h"
#if MICROPY_ENABLE_COMPILER
@ -234,6 +235,24 @@ mp_parse_node_t mp_parse_node_new_leaf(size_t kind, mp_int_t arg) {
return (mp_parse_node_t)(kind | (arg << 4));
}
bool mp_parse_node_get_int_maybe(mp_parse_node_t pn, mp_obj_t *o) {
if (MP_PARSE_NODE_IS_SMALL_INT(pn)) {
*o = MP_OBJ_NEW_SMALL_INT(MP_PARSE_NODE_LEAF_SMALL_INT(pn));
return true;
} else if (MP_PARSE_NODE_IS_STRUCT_KIND(pn, RULE_const_object)) {
mp_parse_node_struct_t *pns = (mp_parse_node_struct_t*)pn;
#if MICROPY_OBJ_REPR == MICROPY_OBJ_REPR_D
// nodes are 32-bit pointers, but need to extract 64-bit object
*o = (uint64_t)pns->nodes[0] | ((uint64_t)pns->nodes[1] << 32);
#else
*o = (mp_obj_t)pns->nodes[0];
#endif
return MP_OBJ_IS_INT(*o);
} else {
return false;
}
}
int mp_parse_node_extract_list(mp_parse_node_t *pn, size_t pn_kind, mp_parse_node_t **nodes) {
if (MP_PARSE_NODE_IS_NULL(*pn)) {
*nodes = NULL;
@ -445,119 +464,94 @@ STATIC bool fold_constants(parser_t *parser, const rule_t *rule, size_t num_args
// this code does folding of arbitrary integer expressions, eg 1 + 2 * 3 + 4
// it does not do partial folding, eg 1 + 2 + x -> 3 + x
mp_int_t arg0;
mp_obj_t arg0;
if (rule->rule_id == RULE_expr
|| rule->rule_id == RULE_xor_expr
|| rule->rule_id == RULE_and_expr) {
// folding for binary ops: | ^ &
mp_parse_node_t pn = peek_result(parser, num_args - 1);
if (!MP_PARSE_NODE_IS_SMALL_INT(pn)) {
if (!mp_parse_node_get_int_maybe(pn, &arg0)) {
return false;
}
arg0 = MP_PARSE_NODE_LEAF_SMALL_INT(pn);
mp_binary_op_t op;
if (rule->rule_id == RULE_expr) {
op = MP_BINARY_OP_OR;
} else if (rule->rule_id == RULE_xor_expr) {
op = MP_BINARY_OP_XOR;
} else {
op = MP_BINARY_OP_AND;
}
for (ssize_t i = num_args - 2; i >= 0; --i) {
pn = peek_result(parser, i);
if (!MP_PARSE_NODE_IS_SMALL_INT(pn)) {
mp_obj_t arg1;
if (!mp_parse_node_get_int_maybe(pn, &arg1)) {
return false;
}
mp_int_t arg1 = MP_PARSE_NODE_LEAF_SMALL_INT(pn);
if (rule->rule_id == RULE_expr) {
// int | int
arg0 |= arg1;
} else if (rule->rule_id == RULE_xor_expr) {
// int ^ int
arg0 ^= arg1;
} else if (rule->rule_id == RULE_and_expr) {
// int & int
arg0 &= arg1;
}
arg0 = mp_binary_op(op, arg0, arg1);
}
} else if (rule->rule_id == RULE_shift_expr
|| rule->rule_id == RULE_arith_expr
|| rule->rule_id == RULE_term) {
// folding for binary ops: << >> + - * / % //
mp_parse_node_t pn = peek_result(parser, num_args - 1);
if (!MP_PARSE_NODE_IS_SMALL_INT(pn)) {
if (!mp_parse_node_get_int_maybe(pn, &arg0)) {
return false;
}
arg0 = MP_PARSE_NODE_LEAF_SMALL_INT(pn);
for (ssize_t i = num_args - 2; i >= 1; i -= 2) {
pn = peek_result(parser, i - 1);
if (!MP_PARSE_NODE_IS_SMALL_INT(pn)) {
mp_obj_t arg1;
if (!mp_parse_node_get_int_maybe(pn, &arg1)) {
return false;
}
mp_int_t arg1 = MP_PARSE_NODE_LEAF_SMALL_INT(pn);
mp_token_kind_t tok = MP_PARSE_NODE_LEAF_ARG(peek_result(parser, i));
if (tok == MP_TOKEN_OP_DBL_LESS) {
// int << int
if (arg1 >= (mp_int_t)BITS_PER_WORD
|| arg0 > (MP_SMALL_INT_MAX >> arg1)
|| arg0 < (MP_SMALL_INT_MIN >> arg1)) {
return false;
}
arg0 <<= arg1;
} else if (tok == MP_TOKEN_OP_DBL_MORE) {
// int >> int
if (arg1 >= (mp_int_t)BITS_PER_WORD) {
// Shifting to big amounts is underfined behavior
// in C and is CPU-dependent; propagate sign bit.
arg1 = BITS_PER_WORD - 1;
}
arg0 >>= arg1;
} else if (tok == MP_TOKEN_OP_PLUS) {
// int + int
arg0 += arg1;
} else if (tok == MP_TOKEN_OP_MINUS) {
// int - int
arg0 -= arg1;
} else if (tok == MP_TOKEN_OP_STAR) {
// int * int
if (mp_small_int_mul_overflow(arg0, arg1)) {
return false;
}
arg0 *= arg1;
} else if (tok == MP_TOKEN_OP_SLASH) {
// int / int
return false;
} else if (tok == MP_TOKEN_OP_PERCENT) {
// int % int
if (arg1 == 0) {
return false;
}
arg0 = mp_small_int_modulo(arg0, arg1);
} else {
assert(tok == MP_TOKEN_OP_DBL_SLASH); // should be
// int // int
if (arg1 == 0) {
return false;
}
arg0 = mp_small_int_floor_divide(arg0, arg1);
}
if (!MP_SMALL_INT_FITS(arg0)) {
static const uint8_t token_to_op[] = {
MP_BINARY_OP_ADD,
MP_BINARY_OP_SUBTRACT,
MP_BINARY_OP_MULTIPLY,
255,//MP_BINARY_OP_POWER,
255,//MP_BINARY_OP_TRUE_DIVIDE,
MP_BINARY_OP_FLOOR_DIVIDE,
MP_BINARY_OP_MODULO,
255,//MP_BINARY_OP_LESS
MP_BINARY_OP_LSHIFT,
255,//MP_BINARY_OP_MORE
MP_BINARY_OP_RSHIFT,
};
mp_binary_op_t op = token_to_op[tok - MP_TOKEN_OP_PLUS];
if (op == 255) {
return false;
}
int rhs_sign = mp_obj_int_sign(arg1);
if (op <= MP_BINARY_OP_RSHIFT) {
// << and >> can't have negative rhs
if (rhs_sign < 0) {
return false;
}
} else if (op >= MP_BINARY_OP_FLOOR_DIVIDE) {
// % and // can't have zero rhs
if (rhs_sign == 0) {
return false;
}
}
arg0 = mp_binary_op(op, arg0, arg1);
}
} else if (rule->rule_id == RULE_factor_2) {
// folding for unary ops: + - ~
mp_parse_node_t pn = peek_result(parser, 0);
if (!MP_PARSE_NODE_IS_SMALL_INT(pn)) {
if (!mp_parse_node_get_int_maybe(pn, &arg0)) {
return false;
}
arg0 = MP_PARSE_NODE_LEAF_SMALL_INT(pn);
mp_token_kind_t tok = MP_PARSE_NODE_LEAF_ARG(peek_result(parser, 1));
mp_binary_op_t op;
if (tok == MP_TOKEN_OP_PLUS) {
// +int
op = MP_UNARY_OP_POSITIVE;
} else if (tok == MP_TOKEN_OP_MINUS) {
// -int
arg0 = -arg0;
if (!MP_SMALL_INT_FITS(arg0)) {
return false;
}
op = MP_UNARY_OP_NEGATIVE;
} else {
assert(tok == MP_TOKEN_OP_TILDE); // should be
// ~int
arg0 = ~arg0;
op = MP_UNARY_OP_INVERT;
}
arg0 = mp_unary_op(op, arg0);
#if MICROPY_COMP_CONST
} else if (rule->rule_id == RULE_expr_stmt) {
@ -625,10 +619,10 @@ STATIC bool fold_constants(parser_t *parser, const rule_t *rule, size_t num_args
}
mp_obj_t dest[2];
mp_load_method_maybe(elem->value, q_attr, dest);
if (!(MP_OBJ_IS_SMALL_INT(dest[0]) && dest[1] == MP_OBJ_NULL)) {
if (!(MP_OBJ_IS_INT(dest[0]) && dest[1] == MP_OBJ_NULL)) {
return false;
}
arg0 = MP_OBJ_SMALL_INT_VALUE(dest[0]);
arg0 = dest[0];
#endif
} else {
@ -640,7 +634,12 @@ STATIC bool fold_constants(parser_t *parser, const rule_t *rule, size_t num_args
for (size_t i = num_args; i > 0; i--) {
pop_result(parser);
}
push_result_node(parser, mp_parse_node_new_leaf(MP_PARSE_NODE_SMALL_INT, arg0));
if (MP_OBJ_IS_SMALL_INT(arg0)) {
push_result_node(parser, mp_parse_node_new_leaf(MP_PARSE_NODE_SMALL_INT, MP_OBJ_SMALL_INT_VALUE(arg0)));
} else {
// TODO reuse memory for parse node struct?
push_result_node(parser, make_node_const_object(parser, 0, arg0));
}
return true;
}

View File

@ -29,7 +29,7 @@
#include <stddef.h>
#include <stdint.h>
#include "py/mpconfig.h"
#include "py/obj.h"
struct _mp_lexer_t;
@ -77,6 +77,7 @@ typedef struct _mp_parse_node_struct_t {
#define MP_PARSE_NODE_STRUCT_NUM_NODES(pns) ((pns)->kind_num_nodes >> 8)
mp_parse_node_t mp_parse_node_new_leaf(size_t kind, mp_int_t arg);
bool mp_parse_node_get_int_maybe(mp_parse_node_t pn, mp_obj_t *o);
int mp_parse_node_extract_list(mp_parse_node_t *pn, size_t pn_kind, mp_parse_node_t **nodes);
void mp_parse_node_print(mp_parse_node_t pn, size_t indent);

View File

@ -70,20 +70,22 @@ try:
except NotImplementedError:
print('NotImplementedError')
mpz = 1 << 70
# mpz and with both args negative
try:
-(1<<70) & -2
-mpz & -2
except NotImplementedError:
print('NotImplementedError')
# mpz or with args opposite sign
try:
-(1<<70) | 2
-mpz | 2
except NotImplementedError:
print('NotImplementedError')
# mpz xor with args opposite sign
try:
-(1<<70) ^ 2
-mpz ^ 2
except NotImplementedError:
print('NotImplementedError')