diff --git a/py/compile.c b/py/compile.c index bb7c1117f..4f91ca49b 100644 --- a/py/compile.c +++ b/py/compile.c @@ -1768,18 +1768,21 @@ STATIC void compile_await_object_method(compiler_t *comp, qstr method) { } STATIC void compile_async_for_stmt(compiler_t *comp, mp_parse_node_struct_t *pns) { - // comp->break_label |= MP_EMIT_BREAK_FROM_FOR; - - qstr context = MP_PARSE_NODE_LEAF_ARG(pns->nodes[1]); + // Allocate labels. uint while_else_label = comp_next_label(comp); uint try_exception_label = comp_next_label(comp); uint try_else_label = comp_next_label(comp); uint try_finally_label = comp_next_label(comp); + // Stack: (...) + + // Compile the iterator expression and load and call its __aiter__ method. compile_node(comp, pns->nodes[1]); // iterator + // Stack: (..., iterator) EMIT_ARG(load_method, MP_QSTR___aiter__, false); + // Stack: (..., iterator, __aiter__) EMIT_ARG(call_method, 0, 0, 0); - compile_store_id(comp, context); + // Stack: (..., iterable) START_BREAK_CONTINUE_BLOCK @@ -1787,9 +1790,15 @@ STATIC void compile_async_for_stmt(compiler_t *comp, mp_parse_node_struct_t *pns compile_increase_except_level(comp, try_exception_label, MP_EMIT_SETUP_BLOCK_EXCEPT); - compile_load_id(comp, context); + EMIT(dup_top); + // Stack: (..., iterable, iterable) + + // Compile: yield from iterable.__anext__() compile_await_object_method(comp, MP_QSTR___anext__); + // Stack: (..., iterable, yielded_value) + c_assign(comp, pns->nodes[0], ASSIGN_STORE); // variable + // Stack: (..., iterable) EMIT_ARG(pop_except_jump, try_else_label, false); EMIT_ARG(label_assign, try_exception_label); @@ -1806,6 +1815,8 @@ STATIC void compile_async_for_stmt(compiler_t *comp, mp_parse_node_struct_t *pns compile_decrease_except_level(comp); EMIT(end_except_handler); + // Stack: (..., iterable) + EMIT_ARG(label_assign, try_else_label); compile_node(comp, pns->nodes[2]); // body @@ -1817,6 +1828,10 @@ STATIC void compile_async_for_stmt(compiler_t *comp, mp_parse_node_struct_t *pns compile_node(comp, pns->nodes[3]); // else EMIT_ARG(label_assign, break_label); + // Stack: (..., iterable) + + EMIT(pop_top); + // Stack: (...) } STATIC void compile_async_with_stmt_helper(compiler_t *comp, size_t n, mp_parse_node_t *nodes, mp_parse_node_t body) { diff --git a/tests/basics/async_for.py b/tests/basics/async_for.py index 5fd054082..f54f70238 100644 --- a/tests/basics/async_for.py +++ b/tests/basics/async_for.py @@ -1,29 +1,75 @@ # test basic async for execution # example taken from PEP0492 + class AsyncIteratorWrapper: def __init__(self, obj): - print('init') - self._it = iter(obj) + print("init") + self._obj = obj + + def __repr__(self): + return "AsyncIteratorWrapper-" + self._obj def __aiter__(self): - print('aiter') - return self + print("aiter") + return AsyncIteratorWrapperIterator(self._obj) + + +class AsyncIteratorWrapperIterator: + def __init__(self, obj): + print("init") + self._it = iter(obj) async def __anext__(self): - print('anext') + print("anext") try: value = next(self._it) except StopIteration: raise StopAsyncIteration return value -async def coro(): - async for letter in AsyncIteratorWrapper('abc'): + +def run_coro(c): + print("== start ==") + try: + c.send(None) + except StopIteration: + print("== finish ==") + + +async def coro0(): + async for letter in AsyncIteratorWrapper("abc"): print(letter) -o = coro() -try: - o.send(None) -except StopIteration: - print('finished') + +run_coro(coro0()) + + +async def coro1(): + a = AsyncIteratorWrapper("def") + async for letter in a: + print(letter) + print(a) + + +run_coro(coro1()) + +a_global = AsyncIteratorWrapper("ghi") + + +async def coro2(): + async for letter in a_global: + print(letter) + print(a_global) + + +run_coro(coro2()) + + +async def coro3(a): + async for letter in a: + print(letter) + print(a) + + +run_coro(coro3(AsyncIteratorWrapper("jkl"))) diff --git a/tests/basics/async_for.py.exp b/tests/basics/async_for.py.exp index 1f728a66c..6f59979c0 100644 --- a/tests/basics/async_for.py.exp +++ b/tests/basics/async_for.py.exp @@ -1,5 +1,7 @@ +== start == init aiter +init anext a anext @@ -7,4 +9,43 @@ b anext c anext -finished +== finish == +== start == +init +aiter +init +anext +d +anext +e +anext +f +anext +AsyncIteratorWrapper-def +== finish == +init +== start == +aiter +init +anext +g +anext +h +anext +i +anext +AsyncIteratorWrapper-ghi +== finish == +init +== start == +aiter +init +anext +j +anext +k +anext +l +anext +AsyncIteratorWrapper-jkl +== finish ==