/* sbase/tftp.c (310 lines, 5.7 KiB, C) */

/* See LICENSE file for copyright and license details. */
#include <sys/time.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netdb.h>
#include <netinet/in.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include "util.h"
/* Transfer geometry and timing. */
enum {
	BLKSIZE = 512,			/* payload bytes in a full DATA packet */
	HDRSIZE = 4,			/* 2-byte opcode + 2-byte block/error number */
	PKTSIZE = BLKSIZE + HDRSIZE,	/* largest packet this client handles */
	TIMEOUT_SEC = 5,		/* per-receive timeout (SO_RCVTIMEO) */
	/*
	 * Consecutive failed receives before giving up, so a silent peer
	 * aborts the transfer after roughly NRETRIES * TIMEOUT_SEC seconds.
	 */
	NRETRIES = 5
};

/*
 * Packet opcodes; the state machines also reuse them as state values.
 * WWQ is the write request (named WRQ in the TFTP specification).
 */
enum {
	RRQ = 1,
	WWQ = 2,
	DATA = 3,
	ACK = 4,
	ERR = 5
};
/*
 * Messages for the error codes an ERR packet may carry, indexed by the
 * 16-bit code decoded by unpackerrc().  Read-only: string literals must
 * not be modified, so both the pointers and the characters are const.
 */
static const char *const errtext[] = {
	"Undefined",
	"File not found",
	"Access violation",
	"Disk full or allocation exceeded",
	"Illegal TFTP operation",
	"Unknown transfer ID",
	"File already exists",
	"No such user"
};

/*
 * Current peer address and its length.  main() sets them to the
 * resolved server address; readpkt() replaces them with the sender of
 * each received packet.
 */
static struct sockaddr_storage to;
static socklen_t tolen;
/* Consecutive failed receives; readpkt() resets it on success. */
static int timeout;
/* Current state of the transfer state machine (an opcode value). */
static int state;
/* UDP socket used for the whole transfer. */
static int s;
/*
 * Encode a request packet into buf.
 *
 * Layout: 2-byte big-endian opcode, then path with its NUL terminator,
 * then mode with its NUL terminator.  A path longer than 255 characters
 * is rejected via eprintf() (assumed not to return).  The mode string
 * is not length-checked.  Callers are presumed to pass a buffer large
 * enough for both strings; the callers here pass a PKTSIZE buffer with
 * mode "octet".
 *
 * Returns the number of bytes written to buf.
 */
static int
packreq(unsigned char *buf, int op, char *path, char *mode)
{
	size_t pathsz = strlen(path) + 1;
	size_t modesz = strlen(mode) + 1;
	size_t off = 0;

	buf[off++] = op >> 8;
	buf[off++] = op & 0xff;
	if (pathsz > 256)
		eprintf("filename too long\n");

	memcpy(buf + off, path, pathsz);
	off += pathsz;
	memcpy(buf + off, mode, modesz);
	off += modesz;

	return off;
}
/*
 * Encode an ACK packet for block blkno into buf: the opcode, then the
 * block number, both 16-bit big-endian.  Returns the packet length (4).
 */
static int
packack(unsigned char *buf, int blkno)
{
	unsigned char *p = buf;

	*p++ = ACK >> 8;
	*p++ = ACK & 0xff;
	*p++ = blkno >> 8;
	*p++ = blkno & 0xff;

	return p - buf;
}
/*
 * Encode the 4-byte header of a DATA packet for block blkno into buf.
 * The caller places the payload right after the header.  Returns the
 * header length (4).
 */
static int
packdata(unsigned char *buf, int blkno)
{
	unsigned char *p = buf;

	*p++ = DATA >> 8;
	*p++ = DATA & 0xff;
	*p++ = blkno >> 8;
	*p++ = blkno & 0xff;

	return p - buf;
}
/* Decode the big-endian 16-bit opcode held in bytes 0-1 of buf. */
static int
unpackop(unsigned char *buf)
{
	int hi = buf[0];
	int lo = buf[1];

	return hi * 256 + lo;
}
/* Decode the big-endian 16-bit block number held in bytes 2-3 of buf. */
static int
unpackblkno(unsigned char *buf)
{
	int hi = buf[2];
	int lo = buf[3];

	return hi * 256 + lo;
}
static int
unpackerrc(unsigned char *buf)
{
int errc;
errc = (buf[2] << 8) | (buf[3] & 0xff);
if (errc < 0 || errc >= LEN(errtext))
eprintf("bad error code: %d\n", errc);
return errc;
}
/*
 * Send len bytes of buf to the current peer (to/tolen) over socket s.
 * A send interrupted by a signal (EINTR) is returned as-is.  Any other
 * failure is reported via eprintf().  Returns sendto()'s result.
 */
static int
writepkt(unsigned char *buf, int len)
{
	int sent = sendto(s, buf, len, 0, (struct sockaddr *)&to, tolen);

	if (sent < 0 && errno != EINTR)
		eprintf("sendto:");
	return sent;
}
/*
 * Receive one packet (at most len bytes) into buf.
 *
 * On success, the sender's address becomes the peer for later
 * writepkt() calls.  This is how the client follows the server's
 * transfer port.  The consecutive-failure counter is then reset.
 *
 * A receive that times out (SO_RCVTIMEO, reported as EAGAIN or
 * EWOULDBLOCK) or is interrupted (EINTR) counts as one retry.
 * NRETRIES in a row abort the transfer.  Any other error is fatal.
 *
 * Returns recvfrom()'s result; it is negative on timeout or interrupt.
 */
static int
readpkt(unsigned char *buf, int len)
{
	struct sockaddr_storage from;
	/* value-result argument: must start as the capacity of `from` */
	socklen_t fromlen = sizeof(from);
	int n;

	n = recvfrom(s, buf, len, 0, (struct sockaddr *)&from, &fromlen);
	if (n < 0) {
		/* POSIX lets a receive timeout report either errno value */
		if (errno != EINTR && errno != EAGAIN && errno != EWOULDBLOCK)
			eprintf("recvfrom:");
		if (++timeout == NRETRIES)
			eprintf("transfer timed out\n");
		return n;
	}

	/* only a successful receive may change the peer address */
	to = from;
	tolen = fromlen;
	timeout = 0;
	return n;
}
/*
 * Write exactly len bytes from buf to fd.  Short writes are continued
 * and EINTR is retried; any other failure is reported via eprintf()
 * (assumed not to return).
 */
static void
writeblk(int fd, const unsigned char *buf, size_t len)
{
	ssize_t n;

	while (len > 0) {
		n = write(fd, buf, len);
		if (n < 0) {
			if (errno == EINTR)
				continue;
			eprintf("write:");
		}
		buf += n;
		len -= n;
	}
}

/*
 * Download `file` from the server (octet mode) and copy it to standard
 * output.
 *
 * State machine (`state` holds an opcode value):
 *   RRQ  - (re)send the read request and wait for the first reply;
 *   DATA - write the payload if it is the expected block;
 *   ACK  - acknowledge the last received block, then wait for the next;
 *   ERR  - report the server's error and exit.
 * A new DATA block whose payload is shorter than BLKSIZE is the last
 * one.  After it is acknowledged, the function returns.
 */
static void
getfile(char *file)
{
	unsigned char buf[PKTSIZE];
	int n = 0, op, blkno = 0, nextblkno = 1, done = 0;

	state = RRQ;
	for (;;) {
		switch (state) {
		case RRQ:
			n = packreq(buf, RRQ, file, "octet");
			writepkt(buf, n);
			n = readpkt(buf, sizeof(buf));
			/* on timeout state is unchanged, so the request is resent */
			if (n > 0) {
				op = unpackop(buf);
				if (op != DATA && op != ERR)
					eprintf("bad opcode: %d\n", op);
				state = op;
			}
			break;
		case DATA:
			n -= HDRSIZE;
			if (n < 0)
				eprintf("truncated packet\n");
			blkno = unpackblkno(buf);
			if (blkno == nextblkno) {
				/* block numbers are 16 bits wide and wrap to 0 */
				nextblkno = (nextblkno + 1) & 0xffff;
				writeblk(1, &buf[HDRSIZE], n);
				/* only a short *new* block ends the transfer */
				if (n < BLKSIZE)
					done = 1;
			}
			/* duplicates are re-acknowledged but not written again */
			state = ACK;
			break;
		case ACK:
			n = packack(buf, blkno);
			writepkt(buf, n);
			if (done)
				return;
			n = readpkt(buf, sizeof(buf));
			/* on timeout state stays ACK, so the ACK is resent */
			if (n > 0) {
				op = unpackop(buf);
				if (op != DATA && op != ERR)
					eprintf("bad opcode: %d\n", op);
				state = op;
			}
			break;
		case ERR:
			/* the error code lives in bytes 2-3; don't decode stale data */
			if (n < HDRSIZE)
				eprintf("truncated packet\n");
			eprintf("error: %s\n", errtext[unpackerrc(buf)]);
		}
	}
}
/*
 * Read up to len bytes from fd into buf, stopping early only at end of
 * file.  A single read() on a pipe or terminal may return less than
 * requested, so keep reading.  EINTR is retried; any other failure is
 * reported via eprintf() (assumed not to return).
 *
 * Returns the number of bytes read; it is less than len only at EOF.
 */
static size_t
readblk(int fd, unsigned char *buf, size_t len)
{
	size_t got = 0;
	ssize_t n;

	while (got < len) {
		n = read(fd, buf + got, len - got);
		if (n < 0) {
			if (errno == EINTR)
				continue;
			eprintf("read:");
		}
		if (n == 0)
			break;
		got += n;
	}
	return got;
}

/*
 * Upload standard input to the server as `file` (octet mode).
 *
 * State machine (`state` holds an opcode value):
 *   WWQ  - (re)send the write request and wait for the server's reply;
 *   DATA - if the last ACK matched, read the next block from stdin;
 *          then (re)send the current DATA packet and wait for a reply;
 *   ACK  - record the acknowledged block; return once the final (short)
 *          block has been acknowledged;
 *   ERR  - report the server's error and exit.
 * nextblkno is the number of the most recently sent block; it is 0
 * while the request itself is outstanding.
 */
static void
putfile(char *file)
{
	unsigned char inbuf[PKTSIZE], outbuf[PKTSIZE];
	int inb = 0, outb = 0, op, blkno = 0, nextblkno = 0, done = 0;

	state = WWQ;
	for (;;) {
		switch (state) {
		case WWQ:
			outb = packreq(outbuf, WWQ, file, "octet");
			writepkt(outbuf, outb);
			inb = readpkt(inbuf, sizeof(inbuf));
			/* on timeout state is unchanged, so the request is resent */
			if (inb > 0) {
				op = unpackop(inbuf);
				if (op != ACK && op != ERR)
					eprintf("bad opcode: %d\n", op);
				state = op;
			}
			break;
		case DATA:
			if (blkno == nextblkno) {
				/* block numbers are 16 bits wide and wrap to 0 */
				nextblkno = (nextblkno + 1) & 0xffff;
				packdata(outbuf, nextblkno);
				outb = readblk(0, &outbuf[HDRSIZE], BLKSIZE);
				/* short (possibly empty) block marks end of file */
				if (outb < BLKSIZE)
					done = 1;
			}
			/* otherwise retransmit the packet still in outbuf */
			writepkt(outbuf, outb + HDRSIZE);
			inb = readpkt(inbuf, sizeof(inbuf));
			if (inb > 0) {
				op = unpackop(inbuf);
				if (op != ACK && op != ERR)
					eprintf("bad opcode: %d\n", op);
				state = op;
			}
			break;
		case ACK:
			if (inb < HDRSIZE)
				eprintf("truncated packet\n");
			blkno = unpackblkno(inbuf);
			if (blkno == nextblkno && done)
				return;
			state = DATA;
			break;
		case ERR:
			/* the error code lives in bytes 2-3; don't decode stale data */
			if (inb < HDRSIZE)
				eprintf("truncated packet\n");
			eprintf("error: %s\n", errtext[unpackerrc(inbuf)]);
		}
	}
}
/*
 * Print the command-line synopsis via eprintf() (assumed not to return).
 * -x (the default) fetches `file` from the server to standard output;
 * -c sends standard input to the server as `file`.
 */
static void
usage(void)
{
eprintf("usage: %s -h host [-p port] [-x | -c] file\n", argv0);
}
/*
 * Parse options, resolve the server to a UDP endpoint, open a socket
 * with a receive timeout, and run the selected transfer on argv[0].
 */
int
main(int argc, char *argv[])
{
	struct addrinfo hints, *res, *ai;
	struct timeval tv;
	char *host = NULL, *port = "tftp";
	void (*xfer)(char *) = getfile;
	int gaierr;

	ARGBEGIN {
	case 'h':
		host = EARGF(usage());
		break;
	case 'p':
		port = EARGF(usage());
		break;
	case 'x':
		xfer = getfile;
		break;
	case 'c':
		xfer = putfile;
		break;
	default:
		usage();
	} ARGEND

	if (host == NULL || argc == 0)
		usage();

	/* resolve host/port to UDP endpoints of any address family */
	memset(&hints, 0, sizeof(hints));
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_DGRAM;
	hints.ai_protocol = IPPROTO_UDP;
	if ((gaierr = getaddrinfo(host, port, &hints, &res)) != 0)
		eprintf("getaddrinfo: %s\n", gai_strerror(gaierr));

	/* take the first IPv4/IPv6 result for which a socket can be opened */
	for (ai = res; ai != NULL; ai = ai->ai_next) {
		if (ai->ai_family != AF_INET && ai->ai_family != AF_INET6)
			continue;
		if ((s = socket(ai->ai_family, ai->ai_socktype,
		                ai->ai_protocol)) >= 0)
			break;
	}
	if (ai == NULL)
		eprintf("cannot create socket\n");

	/* the initial peer is the resolved server address */
	tolen = ai->ai_addrlen;
	memcpy(&to, ai->ai_addr, tolen);
	freeaddrinfo(res);

	/* bound each receive so readpkt() can count retries */
	tv.tv_sec = TIMEOUT_SEC;
	tv.tv_usec = 0;
	if (setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, &tv, sizeof(tv)) < 0)
		eprintf("setsockopt:");

	xfer(argv[0]);
	return 0;
}