/*
 * Copyright (C) 2010-2022 Willy Tarreau
 *
 * Permission is hereby granted, free of charge, to any person obtaining
 * a copy of this software and associated documentation files (the
 * "Software"), to deal in the Software without restriction, including
 * without limitation the rights to use, copy, modify, merge, publish,
 * distribute, sublicense, and/or sell copies of the Software, and to
 * permit persons to whom the Software is furnished to do so, subject to
 * the following conditions:
 *
 * The above copyright notice and this permission notice shall be
 * included in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
 * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
 * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
 * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
 * OTHER DEALINGS IN THE SOFTWARE.
 */

/* Minimal UDP relay: datagrams received on the listening address are
 * forwarded to the server address through a connected UDP socket, and the
 * server's replies are sent back from the listening socket to the client
 * address currently associated with that backend socket. Only MAXCONN
 * (currently 1) association is tracked; a new client replaces the old one.
 */

#include <sys/resource.h>
#include <sys/select.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <sys/ioctl.h>
#include <sys/wait.h>
#include <arpa/inet.h>
#include <netinet/in.h>
#include <netinet/tcp.h>

#include <ctype.h>
#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <poll.h>
#include <signal.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>

#define MAXCONN 1

const int zero = 0;
const int one = 1;

/* One client <-> server association. */
struct conn {
	struct sockaddr_storage cli_addr; /* where server replies are sent */
	int fd_bck;                       /* connected socket to the server, or -1 */
};

/* Caller-provided buffer used to report error messages. */
struct errmsg {
	char *msg;  /* buffer of <size> bytes */
	int size;
	int len;    /* return value of the last snprintf() into <msg> */
};

struct sockaddr_storage frt_addr; // listen address
struct sockaddr_storage srv_addr; // server address

#define MAXPKTSIZE 16384
char trash[MAXPKTSIZE];           /* shared datagram buffer */

struct conn conns[MAXCONN]; // sole connection for now
int fd_frt;
int nbfd = 0;
int nbconn = MAXCONN;

/* Returns the size of the sockaddr structure matching <ss>'s family
 * (sockaddr_in6 for AF_INET6, sockaddr_in otherwise).
 */
static socklen_t ss_len(const struct sockaddr_storage *ss)
{
	return ss->ss_family == AF_INET6 ? sizeof(struct sockaddr_in6)
	                                 : sizeof(struct sockaddr_in);
}

/* Returns non-zero if <err> only means "nothing to do right now" on a
 * non-blocking socket, so the caller should just wait for the next event.
 */
static int sock_err_is_transient(int err)
{
	return err == EAGAIN || err == EWOULDBLOCK || err == EINTR;
}

/* display the message and exit with the code */
__attribute__((noreturn, format(printf, 2, 3)))
void die(int code, const char *format, ...)
{
	va_list args;

	va_start(args, format);
	vfprintf(stderr, format, args);
	va_end(args);
	exit(code);
}

/* Converts <str> in the form [<ipv4>|<ipv6>|<hostname>:]<port> to a struct
 * sockaddr_storage in <ss>. An empty address or one starting with '*' means
 * INADDR_ANY. An address containing ':' is parsed as IPv6; anything else is
 * tried as a dotted IPv4 address, then resolved with gethostbyname() (IPv4
 * only). <str> is modified in place. Returns 0 on success, or < 0 with <err>
 * set in case of error.
 */
int addr_to_ss(char *str, struct sockaddr_storage *ss, struct errmsg *err)
{
	struct sockaddr_in *sin;
	char *port_str;
	char *end;
	long port;

	/* look for the addr/port delimiter, it's the last colon. */
	port_str = strrchr(str, ':');
	if (port_str == NULL)
		port_str = str;
	else
		*port_str++ = 0;

	/* strict parsing: reject empty ports and trailing garbage */
	port = strtol(port_str, &end, 10);
	if (end == port_str || *end != '\0' || port <= 0 || port > 65535) {
		err->len = snprintf(err->msg, err->size, "Missing/invalid port number: '%s'\n", port_str);
		return -1;
	}

	/* When there was no colon, port_str == str: this empties the address. */
	*port_str = 0; // present an empty address if none was set

	memset(ss, 0, sizeof(*ss));

	if (strchr(str, ':') != NULL) {
		/* IPv6 address contains ':' */
		struct sockaddr_in6 *sin6 = (struct sockaddr_in6 *)ss;

		ss->ss_family = AF_INET6;
		sin6->sin6_port = htons((uint16_t)port);
		if (inet_pton(AF_INET6, str, &sin6->sin6_addr) != 1) {
			err->len = snprintf(err->msg, err->size, "Invalid IPv6 server address: '%s'", str);
			return -1;
		}
		return 0;
	}

	sin = (struct sockaddr_in *)ss;
	ss->ss_family = AF_INET;
	sin->sin_port = htons((uint16_t)port);

	if (*str == '*' || *str == '\0') {
		/* INADDR_ANY */
		sin->sin_addr.s_addr = htonl(INADDR_ANY);
		return 0;
	}

	if (inet_pton(AF_INET, str, &sin->sin_addr) != 1) {
		struct hostent *he = gethostbyname(str);

		/* only accept an IPv4 result that carries at least one address */
		if (he == NULL || he->h_addrtype != AF_INET || he->h_addr_list[0] == NULL) {
			err->len = snprintf(err->msg, err->size, "Invalid IPv4 server name: '%s'", str);
			return -1;
		}
		memcpy(&sin->sin_addr, he->h_addr_list[0], sizeof(sin->sin_addr));
	}
	return 0;
}

/* Creates a non-blocking UDP socket bound to <addr>, with SO_REUSEADDR (and
 * SO_REUSEPORT where defined) enabled. Returns the FD, or < 0 with <err> set
 * in case of error.
 */
int create_udp_listener(struct sockaddr_storage *addr, struct errmsg *err)
{
	int fd;

	fd = socket(addr->ss_family, SOCK_DGRAM, 0);
	if (fd == -1) {
		err->len = snprintf(err->msg, err->size, "socket(): '%s'", strerror(errno));
		goto fail;
	}

	if (fcntl(fd, F_SETFL, O_NONBLOCK) == -1) {
		err->len = snprintf(err->msg, err->size, "fcntl(O_NONBLOCK): '%s'", strerror(errno));
		goto fail;
	}

	if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, (char *) &one, sizeof(one)) == -1) {
		err->len = snprintf(err->msg, err->size, "setsockopt(SO_REUSEADDR): '%s'", strerror(errno));
		goto fail;
	}

#ifdef SO_REUSEPORT
	if (setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, (char *) &one, sizeof(one)) == -1) {
		err->len = snprintf(err->msg, err->size, "setsockopt(SO_REUSEPORT): '%s'", strerror(errno));
		goto fail;
	}
#endif

	/* bind to the requested address, not to a global one */
	if (bind(fd, (struct sockaddr *)addr, ss_len(addr)) == -1) {
		err->len = snprintf(err->msg, err->size, "bind(): '%s'", strerror(errno));
		goto fail;
	}

	/* the socket is ready */
	return fd;

 fail:
	if (fd > -1)
		close(fd);
	return -1;
}

/* recompute pollfds using frt_fd and scanning nbconn connections: the front
 * FD always goes first, followed by every valid backend FD, all polled for
 * POLLIN. <pfd> must have room for 1 + <nbconn> entries.
 * Returns the number of FDs in the set.
 */
int update_pfd(struct pollfd *pfd, int frt_fd, struct conn *conns, int nbconn)
{
	int nbfd = 0;
	int i;

	pfd[nbfd].fd = frt_fd;
	pfd[nbfd].events = POLLIN;
	nbfd++;

	for (i = 0; i < nbconn; i++) {
		if (conns[i].fd_bck < 0)
			continue;
		pfd[nbfd].fd = conns[i].fd_bck;
		pfd[nbfd].events = POLLIN;
		nbfd++;
	}
	return nbfd;
}

/* searches a connection using fd as back connection, returns it if found
 * otherwise NULL.
 */
struct conn *conn_bck_lookup(struct conn *conns, int nbconn, int fd)
{
	int i;

	for (i = 0; i < nbconn; i++) {
		if (conns[i].fd_bck < 0)
			continue;
		if (conns[i].fd_bck == fd)
			return &conns[i];
	}
	return NULL;
}

/* Creates a non-blocking UDP socket connected to <ss>. For UDP, connect()
 * only sets the default peer, so send() can be used afterwards.
 * Returns the fd or -1 in case of error.
 */
int add_connection(struct sockaddr_storage *ss)
{
	int fd;

	fd = socket(ss->ss_family, SOCK_DGRAM, 0);
	if (fd < 0)
		goto fail;

	if (fcntl(fd, F_SETFL, O_NONBLOCK) == -1)
		goto fail;

	if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) == -1)
		goto fail;

	if (connect(fd, (struct sockaddr *)ss, ss_len(ss)) == -1) {
		if (errno != EINPROGRESS)
			goto fail;
	}

	return fd;
 fail:
	if (fd > -1)
		close(fd);
	return -1;
}

/* Handles a read event on the front FD <fd>. Reuses the connection whose
 * client address matches the datagram's source, otherwise takes over
 * conns[0] (replacing its client address) and creates its backend socket
 * if it has none, rebuilding <pfd> in that case. Returns 0 when nothing was
 * forwarded, < 0 on a non-transient receive error, otherwise the result of
 * send() on the backend socket.
 */
int handle_frt(int fd, struct pollfd *pfd, struct conn *conns, int nbconn)
{
	struct sockaddr_storage addr;
	socklen_t addrlen = sizeof(addr); /* value-result: input is the buffer size */
	struct conn *conn = NULL;
	ssize_t ret;
	int i;

	ret = recvfrom(fd, trash, sizeof(trash), MSG_DONTWAIT | MSG_NOSIGNAL,
	               (struct sockaddr *)&addr, &addrlen);
	if (ret == 0)
		return 0;
	if (ret < 0)
		return sock_err_is_transient(errno) ? 0 : -1;

	for (i = 0; i < nbconn; i++) {
		if (addr.ss_family != conns[i].cli_addr.ss_family)
			continue;
		if (memcmp(&conns[i].cli_addr, &addr, ss_len(&addr)) != 0)
			continue;
		conn = &conns[i];
		break;
	}

	if (!conn) {
		/* address not found, create a new conn or replace the oldest
		 * one. For now we support a single one.
		 */
		conn = &conns[0];
		memcpy(&conn->cli_addr, &addr, ss_len(&addr));
		if (conn->fd_bck < 0) {
			/* try to create a new connection */
			conn->fd_bck = add_connection(&srv_addr);
			nbfd = update_pfd(pfd, fd, conns, nbconn); // FIXME: MAXCONN instead ?
		}
	}

	if (conn->fd_bck < 0)
		return 0;

	ret = send(conn->fd_bck, trash, (size_t)ret, MSG_DONTWAIT | MSG_NOSIGNAL);
	return (int)ret;
}

/* Handles a read event on backend FD <fd>: reads one datagram from the
 * server and sends it from the front socket to the client address of the
 * connection owning <fd>. The socket is never closed here. Returns 0 when
 * nothing was read, the datagram was empty or no connection owns <fd>; < 0
 * on a non-transient receive error; otherwise the result of sendto().
 * <pfd> is currently unused.
 */
int handle_bck(int fd, struct pollfd *pfd, struct conn *conns, int nbconn)
{
	struct conn *conn;
	ssize_t ret;

	(void)pfd;

	/* the socket is connected, so the source address is not needed */
	ret = recv(fd, trash, sizeof(trash), MSG_DONTWAIT | MSG_NOSIGNAL);
	if (ret == 0)
		return 0;
	if (ret < 0)
		return sock_err_is_transient(errno) ? 0 : -1;

	conn = conn_bck_lookup(conns, nbconn, fd);
	if (!conn)
		return 0;

	ret = sendto(fd_frt, trash, (size_t)ret, MSG_DONTWAIT | MSG_NOSIGNAL,
	             (struct sockaddr *)&conn->cli_addr, ss_len(&conn->cli_addr));
	return (int)ret;
}

int main(int argc, char **argv)
{
	struct errmsg err;
	struct pollfd *pfd;
	int i;

	err.len = 0;
	err.size = 100;
	err.msg = malloc(err.size);
	if (!err.msg)
		die(1, "out of memory\n");

	if (argc < 3)
		die(1, "Usage: %s [<laddr>:]<lport> [<saddr>:]<sport>\n", argv[0]);

	if (addr_to_ss(argv[1], &frt_addr, &err) < 0)
		die(1, "parsing listen address: %s\n", err.msg);

	if (addr_to_ss(argv[2], &srv_addr, &err) < 0)
		die(1, "parsing server address: %s\n", err.msg);

	/* room for the front FD plus one backend FD per connection */
	pfd = calloc(MAXCONN + 1, sizeof(*pfd));
	if (!pfd)
		die(1, "out of memory\n");

	fd_frt = create_udp_listener(&frt_addr, &err);
	if (fd_frt < 0)
		die(1, "binding listener: %s\n", err.msg);

	for (i = 0; i < MAXCONN; i++)
		conns[i].fd_bck = -1;

	nbfd = update_pfd(pfd, fd_frt, conns, MAXCONN);

	while (1) {
		/* listen for incoming packets */
		int ret;

		ret = poll(pfd, nbfd, 1000);
		if (ret <= 0)
			continue;

		/* <ret> counts the entries reporting events. Also bound the scan
		 * by <nbfd> so it can never run past the set, which handle_frt()
		 * may rebuild.
		 */
		for (i = 0; ret > 0 && i < nbfd; i++) {
			if (!pfd[i].revents)
				continue;
			ret--;

			if (pfd[i].fd == fd_frt)
				handle_frt(pfd[i].fd, pfd, conns, nbconn);
			else
				handle_bck(pfd[i].fd, pfd, conns, nbconn);
		}
	}
}