From 4540428cd0adf039bcf5a8a27f2d5cdf09191513 Mon Sep 17 00:00:00 2001 From: "djm@openbsd.org" Date: Sat, 24 Jun 2017 05:37:44 +0000 Subject: [PATCH] upstream commit switch sshconnect.c from (slightly abused) select() to poll(); ok deraadt@ a while back Upstream-ID: efc1937fc591bbe70ac9e9542bb984f354c8c175 --- sshconnect.c | 158 +++++++++++++++++++++------------------------------ 1 file changed, 65 insertions(+), 93 deletions(-) diff --git a/sshconnect.c b/sshconnect.c index 4100cdc8c..8f527aa43 100644 --- a/sshconnect.c +++ b/sshconnect.c @@ -1,4 +1,4 @@ -/* $OpenBSD: sshconnect.c,v 1.281 2017/06/24 05:35:05 djm Exp $ */ +/* $OpenBSD: sshconnect.c,v 1.282 2017/06/24 05:37:44 djm Exp $ */ /* * Author: Tatu Ylonen * Copyright (c) 1995 Tatu Ylonen , Espoo, Finland @@ -34,6 +34,9 @@ #include #endif #include +#ifdef HAVE_POLL_H +#include +#endif #include #include #include @@ -328,87 +331,71 @@ ssh_create_socket(int privileged, struct addrinfo *ai) return sock; } +/* + * Wait up to *timeoutp milliseconds for fd to be readable. Updates + * *timeoutp with time remaining. + * Returns 0 if fd ready or -1 on timeout or error (see errno). + */ +static int +waitrfd(int fd, int *timeoutp) +{ + struct pollfd pfd; + struct timeval t_start; + int oerrno, r; + + gettimeofday(&t_start, NULL); + pfd.fd = fd; + pfd.events = POLLIN; + for (; *timeoutp >= 0;) { + r = poll(&pfd, 1, *timeoutp); + oerrno = errno; + ms_subtract_diff(&t_start, timeoutp); + errno = oerrno; + if (r > 0) + return 0; + else if (r == -1 && errno != EAGAIN) + return -1; + else if (r == 0) + break; + } + /* timeout */ + errno = ETIMEDOUT; + return -1; +} + static int timeout_connect(int sockfd, const struct sockaddr *serv_addr, socklen_t addrlen, int *timeoutp) { - fd_set *fdset; - struct timeval tv, t_start; - socklen_t optlen; - int optval, rc, result = -1; + int optval = 0; + socklen_t optlen = sizeof(optval); - gettimeofday(&t_start, NULL); - - if (*timeoutp <= 0) { - result = connect(sockfd, serv_addr, addrlen); - goto done; - } + /* No timeout: just do a blocking connect() */ + if (*timeoutp <= 0) + return connect(sockfd, serv_addr, addrlen); set_nonblock(sockfd); - rc = connect(sockfd, serv_addr, addrlen); - if (rc == 0) { + if (connect(sockfd, serv_addr, addrlen) == 0) { + /* Succeeded already? */ unset_nonblock(sockfd); - result = 0; - goto done; + return 0; + } else if (errno != EINPROGRESS) + return -1; + + if (waitrfd(sockfd, timeoutp) == -1) + return -1; + + /* Completed or failed */ + if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &optval, &optlen) == -1) { + debug("getsockopt: %s", strerror(errno)); + return -1; } - if (errno != EINPROGRESS) { - result = -1; - goto done; + if (optval != 0) { + errno = optval; + return -1; } - - fdset = xcalloc(howmany(sockfd + 1, NFDBITS), - sizeof(fd_mask)); - FD_SET(sockfd, fdset); - ms_to_timeval(&tv, *timeoutp); - - for (;;) { - rc = select(sockfd + 1, NULL, fdset, NULL, &tv); - if (rc != -1 || errno != EINTR) - break; - } - - switch (rc) { - case 0: - /* Timed out */ - errno = ETIMEDOUT; - break; - case -1: - /* Select error */ - debug("select: %s", strerror(errno)); - break; - case 1: - /* Completed or failed */ - optval = 0; - optlen = sizeof(optval); - if (getsockopt(sockfd, SOL_SOCKET, SO_ERROR, &optval, - &optlen) == -1) { - debug("getsockopt: %s", strerror(errno)); - break; - } - if (optval != 0) { - errno = optval; - break; - } - result = 0; - unset_nonblock(sockfd); - break; - default: - /* Should not occur */ - fatal("Bogus return (%d) from select()", rc); - } - - free(fdset); - - done: - if (result == 0 && *timeoutp > 0) { - ms_subtract_diff(&t_start, timeoutp); - if (*timeoutp <= 0) { - errno = ETIMEDOUT; - result = -1; - } - } - - return (result); + unset_nonblock(sockfd); + return 0; } /* @@ -546,39 +533,25 @@ ssh_exchange_identification(int timeout_ms) int connection_out = packet_get_connection_out(); u_int i, n; size_t len; - int fdsetsz, remaining, rc; - struct timeval t_start, t_remaining; - fd_set *fdset; - - fdsetsz = howmany(connection_in + 1, NFDBITS) * sizeof(fd_mask); - fdset = xcalloc(1, fdsetsz); + int rc; send_client_banner(connection_out, 0); /* Read other side's version identification. */ - remaining = timeout_ms; for (n = 0;;) { for (i = 0; i < sizeof(buf) - 1; i++) { if (timeout_ms > 0) { - gettimeofday(&t_start, NULL); - ms_to_timeval(&t_remaining, remaining); - FD_SET(connection_in, fdset); - rc = select(connection_in + 1, fdset, NULL, - fdset, &t_remaining); - ms_subtract_diff(&t_start, &remaining); - if (rc == 0 || remaining <= 0) + rc = waitrfd(connection_in, &timeout_ms); + if (rc == -1 && errno == ETIMEDOUT) { fatal("Connection timed out during " "banner exchange"); - if (rc == -1) { - if (errno == EINTR) - continue; - fatal("ssh_exchange_identification: " - "select: %s", strerror(errno)); + } else if (rc == -1) { + fatal("%s: %s", + __func__, strerror(errno)); } } len = atomicio(read, connection_in, &buf[i], 1); - if (len != 1 && errno == EPIPE) fatal("ssh_exchange_identification: " "Connection closed by remote host"); @@ -604,7 +577,6 @@ ssh_exchange_identification(int timeout_ms) debug("ssh_exchange_identification: %s", buf); } server_version_string = xstrdup(buf); - free(fdset); /* * Check that the versions match. In future this might accept