diff --git a/darkhttpd.c b/darkhttpd.c index 76859d8..7a4b411 100644 --- a/darkhttpd.c +++ b/darkhttpd.c @@ -235,7 +235,7 @@ struct connection { size_t request_length; /* request fields */ - char *method, *url, *referer, *user_agent; + char *method, *url, *referer, *user_agent, *authorization; off_t range_begin, range_end; off_t range_begin_given, range_end_given; @@ -301,6 +301,7 @@ static char *pidfile_name = NULL; /* NULL = no pidfile */ static int want_chroot = 0, want_daemon = 0, want_accf = 0, want_keepalive = 1, want_server_id = 1; static char *server_hdr = NULL; +static char *auth_key = NULL; static uint64_t num_requests = 0, total_in = 0, total_out = 0; static int accepting = 1; /* set to 0 to stop accept()ing */ @@ -933,6 +934,8 @@ static void usage(const char *argv0) { "\t\tIf a connection is idle for more than this many seconds,\n" "\t\tit will be closed. Set to zero to disable timeouts.\n\n", timeout_secs); + printf("\t--auth username:password\n" + "\t\tEnable basic authentication.\n\n"); #ifdef HAVE_INET6 printf("\t--ipv6\n" "\t\tListen on IPv6 address.\n\n"); @@ -941,6 +944,45 @@ static void usage(const char *argv0) { #endif } +static char *base64_encode(char *str) { + const char base64_table[] = { + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', + 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', + 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', + 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', + 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', + 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', + 'w', 'x', 'y', 'z', '0', '1', '2', '3', + '4', '5', '6', '7', '8', '9', '+', '/'}; + + int input_length = strlen(str); + int output_length = 4 * ((input_length + 2) / 3); + + char *encoded_data = malloc(output_length+1); + if (encoded_data == NULL) return NULL; + + for (int i = 0, j = 0; i < input_length;) { + + uint32_t octet_a = i < input_length ? (unsigned char)str[i++] : 0; + uint32_t octet_b = i < input_length ? (unsigned char)str[i++] : 0; + uint32_t octet_c = i < input_length ? (unsigned char)str[i++] : 0; + + uint32_t triple = (octet_a << 0x10) + (octet_b << 0x08) + octet_c; + + encoded_data[j++] = base64_table[(triple >> 3 * 6) & 0x3F]; + encoded_data[j++] = base64_table[(triple >> 2 * 6) & 0x3F]; + encoded_data[j++] = base64_table[(triple >> 1 * 6) & 0x3F]; + encoded_data[j++] = base64_table[(triple >> 0 * 6) & 0x3F]; + } + + const int mod_table[] = {0, 2, 1}; + for (int i = 0; i < mod_table[input_length % 3]; i++) + encoded_data[output_length - 1 - i] = '='; + encoded_data[output_length] = '\0'; + + return encoded_data; +} + /* Returns 1 if string is a number, 0 otherwise. Set num to NULL if * disinterested in value. */ @@ -1095,6 +1137,14 @@ static void parse_commandline(const int argc, char *argv[]) { errx(1, "missing number after --timeout"); timeout_secs = (int)xstr_to_num(argv[i]); } + else if (strcmp(argv[i], "--auth") == 0) { + if (++i >= argc || strchr(argv[i], ':') == NULL) + errx(1, "missing 'user:pass' after --auth"); + + char *key = base64_encode(argv[i]); + xasprintf(&auth_key, "Basic %s", key); + free(key); + } #ifdef HAVE_INET6 else if (strcmp(argv[i], "--ipv6") == 0) { inet6 = 1; @@ -1118,6 +1168,7 @@ static struct connection *new_connection(void) { conn->url = NULL; conn->referer = NULL; conn->user_agent = NULL; + conn->authorization = NULL; conn->range_begin = 0; conn->range_end = 0; conn->range_begin_given = 0; @@ -1288,6 +1339,7 @@ static void free_connection(struct connection *conn) { if (conn->url != NULL) free(conn->url); if (conn->referer != NULL) free(conn->referer); if (conn->user_agent != NULL) free(conn->user_agent); + if (conn->authorization != NULL) free(conn->authorization); if (conn->header != NULL && !conn->header_dont_free) free(conn->header); if (conn->reply != NULL && !conn->reply_dont_free) free(conn->reply); if (conn->reply_fd != -1) xclose(conn->reply_fd); @@ -1311,6 +1363,7 @@ static void recycle_connection(struct connection *conn) { conn->url = NULL; conn->referer = NULL; conn->user_agent = NULL; + conn->authorization = NULL; conn->range_begin = 0; conn->range_end = 0; conn->range_begin_given = 0; @@ -1442,6 +1495,9 @@ static void default_reply(struct connection *conn, errcode, errname, errname, reason, generated_on(date)); free(reason); + const char *auth_header = + "WWW-Authenticate: Basic realm=\"User Visible Realm\""; + conn->header_length = xasprintf(&(conn->header), "HTTP/1.1 %d %s\r\n" "Date: %s\r\n" @@ -1450,9 +1506,11 @@ static void default_reply(struct connection *conn, "%s" /* keep-alive */ "Content-Length: %llu\r\n" "Content-Type: text/html; charset=UTF-8\r\n" + "%s\r\n" "\r\n", errcode, errname, date, server_hdr, keep_alive(conn), - llu(conn->reply_length)); + llu(conn->reply_length), + (auth_key != NULL ? auth_header : "")); conn->reply_type = REPLY_GENERATED; conn->http_code = errcode; @@ -1653,6 +1711,7 @@ static int parse_request(struct connection *conn) { /* parse important fields */ conn->referer = parse_field(conn, "Referer: "); conn->user_agent = parse_field(conn, "User-Agent: "); + conn->authorization = parse_field(conn, "Authorization: "); parse_range_field(conn); return 1; } @@ -2085,10 +2144,19 @@ static void process_get(struct connection *conn) { /* Process a request: build the header and reply, advance state. */ static void process_request(struct connection *conn) { num_requests++; + if (!parse_request(conn)) { default_reply(conn, 400, "Bad Request", "You sent a request that the server couldn't understand."); } + // fail if: (auth_enabled) AND (client supplied invalid credentials) + if (auth_key != NULL && + (conn->authorization == NULL || + strcmp(conn->authorization, auth_key))) + { + default_reply(conn, 401, "Unauthorized", + "Access denied due to invalid credentials."); + } else if (strcmp(conn->method, "GET") == 0) { process_get(conn); } @@ -2701,6 +2769,7 @@ int main(int argc, char **argv) { free(keep_alive_field); free(wwwroot); free(server_hdr); + free(auth_key); } /* usage stats */