dns: implement tcp fallback in __res_msend query core

tcp fallback was originally deemed unwanted and unnecessary, since we
aim to return a bounded-size result from getaddrinfo anyway and
normally plenty of address records fit in the 512-byte udp dns limit.
however, this turned out to have several problems:

- some recursive nameservers truncate by omitting all the answers,
  rather than sending as many as can fit.

- a pathological worst-case CNAME for a worst-case name can fill the
  entire 512-byte space with just the two names, leaving no room for
  any addresses.

- the res_* family of interfaces allow querying of non-address records
  such as TLSA (DANE), TXT, etc. which can be very large. for many of
  these, it's critical that the caller see the whole RRset. also,
  res_send/res_query are specified to return the complete, untruncated
  length so that the caller can retry with an appropriately-sized
  buffer. determining this is not possible without tcp.

so, it's time to add tcp fallback.

the fallback strategy implemented here uses one tcp socket per
question (1 or 2 questions), initiated via tcp fastopen when possible.
the connection is made to the nameserver that issued the truncated
answer. right now, fallback happens unconditionally when truncation is
seen. this can, and may later be, relaxed for queries made by the
getaddrinfo system, since it will only use a bounded number of results
anyway.

retry is not attempted again after failure over tcp. the logic could
easily be adapted to do that, but it's of questionable value, since
the tcp stack automatically handles retransmission and the successs
answer with TC=1 over udp strongly suggests that the nameserver has
the full answer ready to give. further retry is likely just "take
longer to fail".
This commit is contained in:
Rich Felker 2022-09-22 14:17:05 -04:00
parent e2e9517607
commit 51d4669fb9
1 changed files with 117 additions and 2 deletions

View File

@ -1,5 +1,6 @@
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <netdb.h>
#include <arpa/inet.h>
#include <stdint.h>
@ -29,6 +30,51 @@ static unsigned long mtime()
+ ts.tv_nsec / 1000000;
}
static int start_tcp(struct pollfd *pfd, int family, const void *sa, socklen_t sl, const unsigned char *q, int ql)
{
struct msghdr mh = {
.msg_name = (void *)sa,
.msg_namelen = sl,
.msg_iovlen = 2,
.msg_iov = (struct iovec [2]){
{ .iov_base = (uint8_t[]){ ql>>8, ql }, .iov_len = 2 },
{ .iov_base = (void *)q, .iov_len = ql } }
};
int r;
int fd = socket(family, SOCK_STREAM|SOCK_CLOEXEC|SOCK_NONBLOCK, 0);
pfd->fd = fd;
pfd->events = POLLOUT;
if (!setsockopt(fd, IPPROTO_TCP, TCP_FASTOPEN_CONNECT,
&(int){1}, sizeof(int))) {
r = sendmsg(fd, &mh, MSG_FASTOPEN|MSG_NOSIGNAL);
if (r == ql+2) pfd->events = POLLIN;
if (r >= 0) return r;
if (errno == EINPROGRESS) return 0;
}
r = connect(fd, sa, sl);
if (!r || errno == EINPROGRESS) return 0;
close(fd);
pfd->fd = -1;
return -1;
}
static void step_mh(struct msghdr *mh, size_t n)
{
/* Adjust iovec in msghdr to skip first n bytes. */
while (mh->msg_iovlen && n >= mh->msg_iov->iov_len) {
n -= mh->msg_iov->iov_len;
mh->msg_iov++;
mh->msg_iovlen--;
}
if (!mh->msg_iovlen) return;
mh->msg_iov->iov_base = (char *)mh->msg_iov->iov_base + n;
mh->msg_iov->iov_len -= n;
}
/* Internal contract for __res_msend[_rc]: asize must be >=512, nqueries
* must be sufficiently small to be safe as VLA size. In practice it's
* either 1 or 2, anyway. */
int __res_msend_rc(int nqueries, const unsigned char *const *queries,
const int *qlens, unsigned char *const *answers, int *alens, int asize,
const struct resolvconf *conf)
@ -47,6 +93,9 @@ int __res_msend_rc(int nqueries, const unsigned char *const *queries,
int i, j;
int cs;
struct pollfd pfd[nqueries+2];
int qpos[nqueries], apos[nqueries];
unsigned char alen_buf[nqueries][2];
int r;
unsigned long t0, t1, t2;
pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &cs);
@ -125,6 +174,11 @@ int __res_msend_rc(int nqueries, const unsigned char *const *queries,
t1 = t2 - retry_interval;
for (; t2-t0 < timeout; t2=mtime()) {
/* This is the loop exit condition: that all queries
* have an accepted answer. */
for (i=0; i<nqueries && alens[i]>0; i++);
if (i==nqueries) break;
if (t2-t1 >= retry_interval) {
/* Query all configured namservers in parallel */
for (i=0; i<nqueries; i++)
@ -140,7 +194,8 @@ int __res_msend_rc(int nqueries, const unsigned char *const *queries,
/* Wait for a response, or until time to retry */
if (poll(pfd, nqueries+1, t1+retry_interval-t2) <= 0) continue;
while ((rlen = recvfrom(fd, answers[next], asize, 0,
while (next < nqueries &&
(rlen = recvfrom(fd, answers[next], asize, 0,
(void *)&sa, (socklen_t[1]){sl})) >= 0) {
/* Ignore non-identifiable packets */
@ -181,12 +236,72 @@ int __res_msend_rc(int nqueries, const unsigned char *const *queries,
else
memcpy(answers[i], answers[next], rlen);
if (next == nqueries) goto out;
/* Ignore further UDP if all slots full or TCP-mode */
if (next == nqueries) pfd[nqueries].events = 0;
/* If answer is truncated (TC bit), fallback to TCP */
if (answers[i][2] & 2) {
alens[i] = -1;
pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, 0);
r = start_tcp(pfd+i, family, ns+j, sl, queries[i], qlens[i]);
pthread_setcancelstate(cs, 0);
if (r >= 0) {
qpos[i] = r;
apos[i] = 0;
}
continue;
}
}
for (i=0; i<nqueries; i++) if (pfd[i].revents & POLLOUT) {
struct msghdr mh = {
.msg_iovlen = 2,
.msg_iov = (struct iovec [2]){
{ .iov_base = (uint8_t[]){ qlens[i]>>8, qlens[i] }, .iov_len = 2 },
{ .iov_base = (void *)queries[i], .iov_len = qlens[i] } }
};
step_mh(&mh, qpos[i]);
r = sendmsg(pfd[i].fd, &mh, MSG_NOSIGNAL);
if (r < 0) goto out;
qpos[i] += r;
if (qpos[i] == qlens[i]+2)
pfd[i].events = POLLIN;
}
for (i=0; i<nqueries; i++) if (pfd[i].revents & POLLIN) {
struct msghdr mh = {
.msg_iovlen = 2,
.msg_iov = (struct iovec [2]){
{ .iov_base = alen_buf[i], .iov_len = 2 },
{ .iov_base = answers[i], .iov_len = asize } }
};
step_mh(&mh, apos[i]);
r = recvmsg(pfd[i].fd, &mh, 0);
if (r < 0) goto out;
apos[i] += r;
if (apos[i] < 2) continue;
int alen = alen_buf[i][0]*256 + alen_buf[i][1];
if (alen < 13) goto out;
if (apos[i] < alen+2 && apos[i] < asize+2)
continue;
int rcode = answers[i][3] & 15;
if (rcode != 0 && rcode != 3)
goto out;
/* Storing the length here commits the accepted answer.
* Immediately close TCP socket so as not to consume
* resources we no longer need. */
alens[i] = alen;
__syscall(SYS_close, pfd[i].fd);
pfd[i].fd = -1;
}
}
out:
pthread_cleanup_pop(1);
/* Disregard any incomplete TCP results */
for (i=0; i<nqueries; i++) if (alens[i]<0) alens[i] = 0;
return 0;
}