From ab4e197707cd456499d680091fdc9daeff18c62a Mon Sep 17 00:00:00 2001 From: Thorsten von Eicken Date: Tue, 28 Apr 2020 10:58:43 -0700 Subject: [PATCH] esp32/modsocket: Fix getaddrinfo to raise on error. This commit fixes the behaviour of socket.getaddrinfo on the ESP32 so it raises an OSError when the name resolution fails instead of returning a [] or a resolution for 0.0.0.0. Tests are added (generic and ESP32-specific) to verify behaviour consistent with CPython, modulo the different types of exceptions per MicroPython documentation. --- ports/esp32/modsocket.c | 19 ++++++---- tests/esp32/resolve_on_connect.py | 59 +++++++++++++++++++++++++++++++ tests/net_inet/getaddrinfo.py | 52 +++++++++++++++++++++++++++ 3 files changed, 123 insertions(+), 7 deletions(-) create mode 100644 tests/esp32/resolve_on_connect.py create mode 100644 tests/net_inet/getaddrinfo.py diff --git a/ports/esp32/modsocket.c b/ports/esp32/modsocket.c index 69a74ec25..85433e575 100644 --- a/ports/esp32/modsocket.c +++ b/ports/esp32/modsocket.c @@ -244,19 +244,24 @@ static int _socket_getaddrinfo2(const mp_obj_t host, const mp_obj_t portx, struc int res = _socket_getaddrinfo3(host_str, port_str, &hints, resp); MP_THREAD_GIL_ENTER(); + // Per docs: instead of raising gaierror getaddrinfo raises negative error number + if (res != 0) { + mp_raise_OSError(res > 0 ? -res : res); + } + // Somehow LwIP returns a resolution of 0.0.0.0 for failed lookups, traced it as far back + // as netconn_gethostbyname_addrtype returning OK instead of error. + if (*resp == NULL || + (strcmp(resp[0]->ai_canonname, "0.0.0.0") == 0 && strcmp(host_str, "0.0.0.0") != 0)) { + mp_raise_OSError(-2); // name or service not known + } + return res; } STATIC void _socket_getaddrinfo(const mp_obj_t addrtuple, struct addrinfo **resp) { mp_obj_t *elem; mp_obj_get_array_fixed_n(addrtuple, 2, &elem); - int res = _socket_getaddrinfo2(elem[0], elem[1], resp); - if (res != 0) { - mp_raise_OSError(res); - } - if (*resp == NULL) { - mp_raise_OSError(-2); // name or service not known - } + _socket_getaddrinfo2(elem[0], elem[1], resp); } STATIC mp_obj_t socket_make_new(const mp_obj_type_t *type_in, size_t n_args, size_t n_kw, const mp_obj_t *args) { diff --git a/tests/esp32/resolve_on_connect.py b/tests/esp32/resolve_on_connect.py new file mode 100644 index 000000000..068757ab2 --- /dev/null +++ b/tests/esp32/resolve_on_connect.py @@ -0,0 +1,59 @@ +# Test that the esp32's socket module performs DNS resolutions on bind and connect +import sys + +if sys.implementation.name == "micropython" and sys.platform != "esp32": + print("SKIP") + raise SystemExit + +try: + import usocket as socket, sys +except: + import socket, sys + + +def test_bind_resolves_0_0_0_0(): + s = socket.socket() + try: + s.bind(("0.0.0.0", 31245)) + print("bind actually bound") + s.close() + except Exception as e: + print("bind raised", e) + + +def test_bind_resolves_localhost(): + s = socket.socket() + try: + s.bind(("localhost", 31245)) + print("bind actually bound") + s.close() + except Exception as e: + print("bind raised", e) + + +def test_connect_resolves(): + s = socket.socket() + try: + s.connect(("micropython.org", 80)) + print("connect actually connected") + s.close() + except Exception as e: + print("connect raised", e) + + +def test_connect_non_existent(): + s = socket.socket() + try: + s.connect(("nonexistent.example.com", 80)) + print("connect actually connected") + s.close() + except OSError as e: + print("connect raised OSError") + except Exception as e: + print("connect raised", e) + + +test_funs = [n for n in dir() if n.startswith("test_")] +for f in sorted(test_funs): + print("--", f, end=": ") + eval(f + "()") diff --git a/tests/net_inet/getaddrinfo.py b/tests/net_inet/getaddrinfo.py new file mode 100644 index 000000000..765723ae7 --- /dev/null +++ b/tests/net_inet/getaddrinfo.py @@ -0,0 +1,52 @@ +try: + import usocket as socket, sys +except: + import socket, sys + + +def test_non_existent(): + try: + res = socket.getaddrinfo("nonexistent.example.com", 80) + print("getaddrinfo returned", res) + except OSError as e: + print("getaddrinfo raised") + + +def test_bogus(): + try: + res = socket.getaddrinfo("hey.!!$$", 80) + print("getaddrinfo returned", res) + except OSError as e: + print("getaddrinfo raised") + except Exception as e: + print("getaddrinfo raised") # CPython raises UnicodeError!? + + +def test_ip_addr(): + try: + res = socket.getaddrinfo("10.10.10.10", 80) + print("getaddrinfo returned resolutions") + except Exception as e: + print("getaddrinfo raised", e) + + +def test_0_0_0_0(): + try: + res = socket.getaddrinfo("0.0.0.0", 80) + print("getaddrinfo returned resolutions") + except Exception as e: + print("getaddrinfo raised", e) + + +def test_valid(): + try: + res = socket.getaddrinfo("micropython.org", 80) + print("getaddrinfo returned resolutions") + except Exception as e: + print("getaddrinfo raised", e) + + +test_funs = [n for n in dir() if n.startswith("test_")] +for f in sorted(test_funs): + print("--", f, end=": ") + eval(f + "()")