libc/winsup/cygserver/transport_sockets.cc

396 lines
9.1 KiB
C++

/* transport_sockets.cc
Copyright 2001, 2002, 2003, 2004 Red Hat Inc.
Written by Robert Collins <rbtcollins@hotmail.com>
This file is part of Cygwin.
This software is a copyrighted work licensed under the terms of the
Cygwin license. Please consult the file "CYGWIN_LICENSE" for
details. */
/* to allow this to link into cygwin and the .dll, a little magic is needed. */
#ifdef __OUTSIDE_CYGWIN__
#include "woutsup.h"
#else
#include "winsup.h"
#endif
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <assert.h>
#include <stdio.h>
#include <unistd.h>
#include "cygerrno.h"
#include "transport.h"
#include "transport_sockets.h"
#ifdef __INSIDE_CYGWIN__
#define SET_ERRNO(err) set_errno (err)
#define GET_ERRNO() get_errno ()
#else
#define SET_ERRNO(err) errno = (err)
#define GET_ERRNO() (errno)
#endif
/* to allow this to link into cygwin and the .dll, a little magic is needed. */
#ifndef __OUTSIDE_CYGWIN__
extern "C" int cygwin_accept (int fd, struct sockaddr *, int *len);
extern "C" int cygwin_bind (int fd, const struct sockaddr *, int len);
extern "C" int cygwin_connect (int fd, const struct sockaddr *, int len);
extern "C" int cygwin_listen (int fd, int backlog);
extern "C" int cygwin_shutdown (int fd, int how);
extern "C" int cygwin_socket (int af, int type, int protocol);
#else /* __OUTSIDE_CYGWIN__ */
#define cygwin_accept(A,B,C) ::accept (A,B,C)
#define cygwin_bind(A,B,C) ::bind (A,B,C)
#define cygwin_connect(A,B,C) ::connect (A,B,C)
#define cygwin_listen(A,B) ::listen (A,B)
#define cygwin_shutdown(A,B) ::shutdown (A,B)
#define cygwin_socket(A,B,C) ::socket (A,B,C)
#endif /* __OUTSIDE_CYGWIN__ */
enum
{
MAX_CONNECT_RETRY = 64
};
transport_layer_sockets::transport_layer_sockets (const int fd)
: _fd (fd),
_addr_len (0),
_is_accepted_endpoint (true),
_is_listening_endpoint (false)
{
assert (_fd != -1);
memset (&_addr, '\0', sizeof (_addr));
}
transport_layer_sockets::transport_layer_sockets ()
: _fd (-1),
_addr_len (0),
_is_accepted_endpoint (false),
_is_listening_endpoint (false)
{
memset (&_addr, '\0', sizeof (_addr));
_addr.sun_family = AF_UNIX;
strcpy (_addr.sun_path, "/tmp/cygdaemo"); // FIXME: $TMP?
_addr_len = SUN_LEN (&_addr);
}
transport_layer_sockets::~transport_layer_sockets ()
{
close ();
}
#ifndef __INSIDE_CYGWIN__
int
transport_layer_sockets::listen ()
{
assert (_fd == -1);
assert (!_is_accepted_endpoint);
assert (!_is_listening_endpoint);
debug_printf ("listen () [this = %p]", this);
struct stat sbuf;
if (stat (_addr.sun_path, &sbuf) == -1)
{
if (GET_ERRNO () != ENOENT)
{
system_printf ("cannot access socket file `%s': %s",
_addr.sun_path, strerror (GET_ERRNO ()));
return -1;
}
}
else if (S_ISSOCK (sbuf.st_mode))
{
// The socket already exists: is a duplicate cygserver running?
const int newfd = cygwin_socket (AF_UNIX, SOCK_STREAM, 0);
if (newfd == -1)
{
system_printf ("failed to create UNIX domain socket: %s",
strerror (GET_ERRNO ()));
return -1;
}
if (cygwin_connect (newfd, (struct sockaddr *) &_addr, _addr_len) == 0)
{
system_printf ("the daemon is already running");
(void) cygwin_shutdown (newfd, SHUT_WR);
char buf[BUFSIZ];
while (::read (newfd, buf, sizeof (buf)) > 0)
{}
(void) ::close (newfd);
return -1;
}
if (unlink (_addr.sun_path) == -1)
{
system_printf ("failed to remove `%s': %s",
_addr.sun_path, strerror (GET_ERRNO ()));
(void) ::close (newfd);
return -1;
}
}
else
{
system_printf ("cannot create socket `%s': File already exists",
_addr.sun_path);
return -1;
}
_fd = cygwin_socket (AF_UNIX, SOCK_STREAM, 0);
if (_fd == -1)
{
system_printf ("failed to create UNIX domain socket: %s",
strerror (GET_ERRNO ()));
return -1;
}
if (cygwin_bind (_fd, (struct sockaddr *) &_addr, _addr_len) == -1)
{
const int saved_errno = GET_ERRNO ();
close ();
SET_ERRNO (saved_errno);
system_printf ("failed to bind UNIX domain socket `%s': %s",
_addr.sun_path, strerror (GET_ERRNO ()));
return -1;
}
_is_listening_endpoint = true; // i.e. this really means "have bound".
if (cygwin_listen (_fd, SOMAXCONN) == -1)
{
const int saved_errno = GET_ERRNO ();
close ();
SET_ERRNO (saved_errno);
system_printf ("failed to listen on UNIX domain socket `%s': %s",
_addr.sun_path, strerror (GET_ERRNO ()));
return -1;
}
debug_printf ("0 = listen () [this = %p, fd = %d]", this, _fd);
return 0;
}
class transport_layer_sockets *
transport_layer_sockets::accept (bool *const recoverable)
{
assert (_fd != -1);
assert (!_is_accepted_endpoint);
assert (_is_listening_endpoint);
debug_printf ("accept () [this = %p, fd = %d]", this, _fd);
struct sockaddr_un client_addr;
socklen_t client_addr_len = sizeof (client_addr);
const int accept_fd =
cygwin_accept (_fd, (struct sockaddr *) &client_addr, &client_addr_len);
if (accept_fd == -1)
{
system_printf ("failed to accept connection: %s", strerror (GET_ERRNO ()));
switch (GET_ERRNO ())
{
case ECONNABORTED:
case EINTR:
case EMFILE:
case ENFILE:
case ENOBUFS:
case ENOMEM:
*recoverable = true;
break;
default:
*recoverable = false;
break;
}
return NULL;
}
debug_printf ("%d = accept () [this = %p, fd = %d]", accept_fd, this, _fd);
return new transport_layer_sockets (accept_fd);
}
#endif /* !__INSIDE_CYGWIN__ */
void
transport_layer_sockets::close ()
{
debug_printf ("close () [this = %p, fd = %d]", this, _fd);
if (_is_listening_endpoint)
(void) unlink (_addr.sun_path);
if (_fd != -1)
{
(void) cygwin_shutdown (_fd, SHUT_WR);
if (!_is_listening_endpoint)
{
char buf[BUFSIZ];
while (::read (_fd, buf, sizeof (buf)) > 0)
{}
}
(void) ::close (_fd);
_fd = -1;
}
_is_listening_endpoint = false;
}
ssize_t
transport_layer_sockets::read (void *const buf, const size_t buf_len)
{
assert (_fd != -1);
assert (!_is_listening_endpoint);
assert (buf);
assert (buf_len > 0);
// verbose: debug_printf ("read (buf = %p, len = %u) [this = %p, fd = %d]",
// buf, buf_len, this, _fd);
char *read_buf = static_cast<char *> (buf);
size_t read_buf_len = buf_len;
ssize_t res = 0;
while (read_buf_len != 0
&& (res = ::read (_fd, read_buf, read_buf_len)) > 0)
{
read_buf += res;
read_buf_len -= res;
assert (read_buf_len >= 0);
}
if (res != -1)
{
if (res == 0)
SET_ERRNO (EIO); // FIXME?
res = buf_len - read_buf_len;
}
if (res != static_cast<ssize_t> (buf_len))
debug_printf ("%d = read (buf = %p, len = %u) [this = %p, fd = %d]: %s",
res, buf, buf_len, this, _fd,
(res == -1 ? strerror (GET_ERRNO ()) : "EOF"));
else
{
// verbose: debug_printf ("%d = read (buf = %p, len = %u) [this = %p, fd = %d]",
// res, buf, buf_len, this, _fd);
}
return res;
}
ssize_t
transport_layer_sockets::write (void *const buf, const size_t buf_len)
{
assert (_fd != -1);
assert (!_is_listening_endpoint);
assert (buf);
assert (buf_len > 0);
// verbose: debug_printf ("write (buf = %p, len = %u) [this = %p, fd = %d]",
// buf, buf_len, this, _fd);
char *write_buf = static_cast<char *> (buf);
size_t write_buf_len = buf_len;
ssize_t res = 0;
while (write_buf_len != 0
&& (res = ::write (_fd, write_buf, write_buf_len)) > 0)
{
write_buf += res;
write_buf_len -= res;
assert (write_buf_len >= 0);
}
if (res != -1)
{
if (res == 0)
SET_ERRNO (EIO); // FIXME?
res = buf_len - write_buf_len;
}
if (res != static_cast<ssize_t> (buf_len))
debug_printf ("%d = write (buf = %p, len = %u) [this = %p, fd = %d]: %s",
res, buf, buf_len, this, _fd,
(res == -1 ? strerror (GET_ERRNO ()) : "EOF"));
else
{
// verbose: debug_printf ("%d = write (buf = %p, len = %u) [this = %p, fd = %d]",
// res, buf, buf_len, this, _fd);
}
return res;
}
int
transport_layer_sockets::connect ()
{
assert (_fd == -1);
assert (!_is_accepted_endpoint);
assert (!_is_listening_endpoint);
static bool assume_cygserver = false;
debug_printf ("connect () [this = %p]", this);
for (int retries = 0; retries != MAX_CONNECT_RETRY; retries++)
{
_fd = cygwin_socket (AF_UNIX, SOCK_STREAM, 0);
if (_fd == -1)
{
system_printf ("failed to create UNIX domain socket: %s",
strerror (GET_ERRNO ()));
return -1;
}
if (cygwin_connect (_fd, (struct sockaddr *) &_addr, _addr_len) == 0)
{
assume_cygserver = true;
debug_printf ("0 = connect () [this = %p, fd = %d]", this, _fd);
return 0;
}
if (!assume_cygserver || GET_ERRNO () != ECONNREFUSED)
{
debug_printf ("failed to connect to server: %s", strerror (GET_ERRNO ()));
(void) ::close (_fd);
_fd = -1;
return -1;
}
(void) ::close (_fd);
_fd = -1;
Sleep (0); // Give the server a chance.
}
debug_printf ("failed to connect to server: %s", strerror (GET_ERRNO ()));
return -1;
}