diff --git a/py/compile.c b/py/compile.c index 76d4c1bf5..7a1660b1b 100644 --- a/py/compile.c +++ b/py/compile.c @@ -1274,6 +1274,14 @@ STATIC void compile_declare_nonlocal(compiler_t *comp, mp_parse_node_t pn, id_in } } +STATIC void compile_declare_global_or_nonlocal(compiler_t *comp, mp_parse_node_t pn, id_info_t *id_info, bool is_global) { + if (is_global) { + compile_declare_global(comp, pn, id_info); + } else { + compile_declare_nonlocal(comp, pn, id_info); + } +} + STATIC void compile_global_nonlocal_stmt(compiler_t *comp, mp_parse_node_struct_t *pns) { if (comp->pass == MP_PASS_SCOPE) { bool is_global = MP_PARSE_NODE_STRUCT_KIND(pns) == PN_global_stmt; @@ -1288,11 +1296,7 @@ STATIC void compile_global_nonlocal_stmt(compiler_t *comp, mp_parse_node_struct_ for (size_t i = 0; i < n; i++) { qstr qst = MP_PARSE_NODE_LEAF_ARG(nodes[i]); id_info_t *id_info = scope_find_or_add_id(comp->scope_cur, qst, ID_INFO_KIND_UNDECIDED); - if (is_global) { - compile_declare_global(comp, (mp_parse_node_t)pns, id_info); - } else { - compile_declare_nonlocal(comp, (mp_parse_node_t)pns, id_info); - } + compile_declare_global_or_nonlocal(comp, (mp_parse_node_t)pns, id_info, is_global); } } } @@ -2133,13 +2137,30 @@ STATIC void compile_namedexpr_helper(compiler_t *comp, mp_parse_node_t pn_name, } compile_node(comp, pn_expr); EMIT(dup_top); - scope_t *old_scope = comp->scope_cur; - if (SCOPE_IS_COMP_LIKE(comp->scope_cur->kind)) { - // Use parent's scope for assigned value so it can "escape" - comp->scope_cur = comp->scope_cur->parent; + + qstr target = MP_PARSE_NODE_LEAF_ARG(pn_name); + + // When a variable is assigned via := in a comprehension then that variable is bound to + // the parent scope. Any global or nonlocal declarations in the parent scope are honoured. + // For details see: https://peps.python.org/pep-0572/#scope-of-the-target + if (comp->pass == MP_PASS_SCOPE && SCOPE_IS_COMP_LIKE(comp->scope_cur->kind)) { + id_info_t *id_info_parent = mp_emit_common_get_id_for_modification(comp->scope_cur->parent, target); + if (id_info_parent->kind == ID_INFO_KIND_GLOBAL_EXPLICIT) { + scope_find_or_add_id(comp->scope_cur, target, ID_INFO_KIND_GLOBAL_EXPLICIT); + } else { + id_info_t *id_info = scope_find_or_add_id(comp->scope_cur, target, ID_INFO_KIND_UNDECIDED); + bool is_global = comp->scope_cur->parent->parent == NULL; // comprehension is defined in outer scope + if (!is_global && id_info->kind == ID_INFO_KIND_GLOBAL_IMPLICIT) { + // Variable was already referenced but now needs to be closed over, so reset the kind + // such that scope_check_to_close_over() is called in compile_declare_nonlocal(). + id_info->kind = ID_INFO_KIND_UNDECIDED; + } + compile_declare_global_or_nonlocal(comp, pn_name, id_info, is_global); + } } - compile_store_id(comp, MP_PARSE_NODE_LEAF_ARG(pn_name)); - comp->scope_cur = old_scope; + + // Do the store to the target variable. + compile_store_id(comp, target); } STATIC void compile_namedexpr(compiler_t *comp, mp_parse_node_struct_t *pns) { diff --git a/py/emit.h b/py/emit.h index 4e8a55e77..26f978ba5 100644 --- a/py/emit.h +++ b/py/emit.h @@ -191,7 +191,7 @@ static inline void mp_emit_common_get_id_for_load(scope_t *scope, qstr qst) { scope_find_or_add_id(scope, qst, ID_INFO_KIND_GLOBAL_IMPLICIT); } -void mp_emit_common_get_id_for_modification(scope_t *scope, qstr qst); +id_info_t *mp_emit_common_get_id_for_modification(scope_t *scope, qstr qst); void mp_emit_common_id_op(emit_t *emit, const mp_emit_method_table_id_ops_t *emit_method_table, scope_t *scope, qstr qst); extern const emit_method_table_t emit_bc_method_table; diff --git a/py/emitcommon.c b/py/emitcommon.c index 679ef1d97..a9eb6e202 100644 --- a/py/emitcommon.c +++ b/py/emitcommon.c @@ -86,7 +86,7 @@ size_t mp_emit_common_use_const_obj(mp_emit_common_t *emit, mp_obj_t const_obj) return emit->const_obj_list.len - 1; } -void mp_emit_common_get_id_for_modification(scope_t *scope, qstr qst) { +id_info_t *mp_emit_common_get_id_for_modification(scope_t *scope, qstr qst) { // name adding/lookup id_info_t *id = scope_find_or_add_id(scope, qst, ID_INFO_KIND_GLOBAL_IMPLICIT); if (id->kind == ID_INFO_KIND_GLOBAL_IMPLICIT) { @@ -98,6 +98,7 @@ void mp_emit_common_get_id_for_modification(scope_t *scope, qstr qst) { id->kind = ID_INFO_KIND_GLOBAL_IMPLICIT_ASSIGNED; } } + return id; } void mp_emit_common_id_op(emit_t *emit, const mp_emit_method_table_id_ops_t *emit_method_table, scope_t *scope, qstr qst) { diff --git a/tests/basics/assign_expr_scope.py b/tests/basics/assign_expr_scope.py new file mode 100644 index 000000000..69dc9f0d1 --- /dev/null +++ b/tests/basics/assign_expr_scope.py @@ -0,0 +1,81 @@ +# Test scoping rules for assignment expression :=. + +# Test that var is closed over (assigned to in the scope of scope0). +def scope0(): + any((var0 := i) for i in range(2)) + return var0 + + +print("scope0") +print(scope0()) +print(globals().get("var0", None)) + +# Test that var1 gets closed over correctly in the list comprehension. +def scope1(): + var1 = 0 + dummy1 = 1 + dummy2 = 1 + print([var1 := i for i in [0, 1] if i > var1]) + print(var1) + + +print("scope1") +scope1() +print(globals().get("var1", None)) + +# Test that var2 in the comprehension honours the global declaration. +def scope2(): + global var2 + print([var2 := i for i in range(2)]) + print(globals().get("var2", None)) + + +print("scope2") +scope2() +print(globals().get("var2", None)) + +# Test that var1 in the comprehension remains local to inner1. +def scope3(): + global var3 + + def inner3(): + print([var3 := i for i in range(2)]) + + inner3() + print(globals().get("var3", None)) + + +print("scope3") +scope3() +print(globals().get("var3", None)) + +# Test that var4 in the comprehension honours the global declarations. +def scope4(): + global var4 + + def inner4(): + global var4 + print([var4 := i for i in range(2)]) + + inner4() + print(var4) + + +print("scope4") +scope4() +print(globals().get("var4", None)) + +# Test that var5 in the comprehension honours the nonlocal declaration. +def scope5(): + def inner5(): + nonlocal var5 + print([var5 := i for i in range(2)]) + + inner5() + print(var5) + var5 = 0 # force var5 to be a local to scope5 + + +print("scope5") +scope5() +print(globals().get("var5", None)) diff --git a/tests/basics/assign_expr_scope.py.exp b/tests/basics/assign_expr_scope.py.exp new file mode 100644 index 000000000..5c780b382 --- /dev/null +++ b/tests/basics/assign_expr_scope.py.exp @@ -0,0 +1,23 @@ +scope0 +1 +None +scope1 +[1] +1 +None +scope2 +[0, 1] +1 +1 +scope3 +[0, 1] +None +None +scope4 +[0, 1] +1 +1 +scope5 +[0, 1] +1 +None diff --git a/tests/basics/assign_expr_syntaxerror.py b/tests/basics/assign_expr_syntaxerror.py index 11b350129..0c334d075 100644 --- a/tests/basics/assign_expr_syntaxerror.py +++ b/tests/basics/assign_expr_syntaxerror.py @@ -8,9 +8,9 @@ def test(code): test("x := 1") test("((x, y) := 1)") - -# these are currently all allowed in MicroPython, but not in CPython test("([i := i + 1 for i in range(4)])") test("([i := -1 for i, j in [(1, 2)]])") test("([[(i := j) for i in range(2)] for j in range(2)])") + +# this is currently allowed in MicroPython, but not in CPython test("([[(j := i) for i in range(2)] for j in range(2)])") diff --git a/tests/basics/assign_expr_syntaxerror.py.exp b/tests/basics/assign_expr_syntaxerror.py.exp index 2ba7d7df8..8b386b2a9 100644 --- a/tests/basics/assign_expr_syntaxerror.py.exp +++ b/tests/basics/assign_expr_syntaxerror.py.exp @@ -1,6 +1,6 @@ SyntaxError SyntaxError -[1, 2, 3, 4] -[-1] -[[0, 0], [1, 1]] +SyntaxError +SyntaxError +SyntaxError [[0, 1], [0, 1]]