/* $Id$ */
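/*
 * Look up a host among the hashed entries of a known_hosts-style file
 * ("knownhosts" in the current directory) and print each matching line.
 */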

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <openssl/hmac.h>
#include <openssl/sha.h>

/*
 * Program name used as the prefix of diagnostics.  Defined outside this
 * file -- presumably by the C runtime (BSD/macOS libc, glibc); verify on
 * the target platform.
 */
extern char *__progname;

/* Local shorthand for unsigned int, used for field and buffer lengths. */
typedef unsigned int u_int;

/* known_hosts file to scan; a relative path, so opened from the cwd. */
#define SSH_KNOWNHOSTS	"knownhosts"
/* Prefix of a hashed host field: HASH_MAGIC <b64 salt> HASH_DELIM <b64 hash>. */
#define HASH_MAGIC	"|1|"
/* Separator between the base64 salt and the base64 hash. */
#define HASH_DELIM	'|'

/*
 * Decode the base64 salt embedded in a hashed host field.
 *
 * A hashed field looks like HASH_MAGIC <base64 salt> HASH_DELIM <base64 hash>;
 * only the salt portion is decoded here.
 *
 * s        - start of the host field (need not be NUL-terminated)
 * l        - length of the host field in bytes
 * salt     - output buffer receiving the decoded salt bytes
 * salt_len - size of the salt buffer in bytes
 *
 * Returns 0 when exactly SHA_DIGEST_LENGTH salt bytes were decoded into
 * salt, or -1 after printing a diagnostic to stderr otherwise.
 */
static int
extract_salt(const char *s, u_int l, char *salt, size_t salt_len)
{
	const char *p;
	char *b64salt;
	u_int b64len;
	int ret;

	if (l < sizeof(HASH_MAGIC) - 1) {
		fprintf(stderr, "%s: extract_salt: string too short\n", __progname);
		return (-1);
	}
	if (strncmp(s, HASH_MAGIC, sizeof(HASH_MAGIC) - 1) != 0) {
		fprintf(stderr, "%s: extract_salt: invalid magic identifier\n", __progname);
		return (-1);
	}
	s += sizeof(HASH_MAGIC) - 1;
	l -= sizeof(HASH_MAGIC) - 1;
	if ((p = memchr(s, HASH_DELIM, l)) == NULL) {
		fprintf(stderr, "%s: extract_salt: missing salt termination character\n", __progname);
		return (-1);
	}

	b64len = (u_int)(p - s);
	/* Sanity check: reject an empty salt and implausibly long encodings. */
	if (b64len == 0 || b64len > 1024) {
		fprintf(stderr, "%s: extract_salt: bad encoded salt length %u\n",
		    __progname, b64len);
		return (-1);
	}

	/* The decoder wants a NUL-terminated string; the field is not one. */
	if ((b64salt = malloc(1 + b64len)) == NULL) {
		fprintf(stderr, "%s: extract_salt: out of memory\n", __progname);
		return (-1);
	}
	memcpy(b64salt, s, b64len);
	b64salt[b64len] = '\0';

	/*
	 * NOTE(review): res_9_b64_pton is the macOS libresolv name for
	 * b64_pton; no prototype for it is included in this file.
	 */
	ret = res_9_b64_pton(b64salt, salt, salt_len);
	free(b64salt);
	if (ret == -1) {
		fprintf(stderr, "%s: extract_salt: salt decode error\n", __progname);
		return (-1);
	}
	/* The HMAC key is used with exactly the SHA-1 digest length. */
	if (ret != SHA_DIGEST_LENGTH) {
		fprintf(stderr, "%s: extract_salt: expected salt len %d, got %d\n",
		    __progname, SHA_DIGEST_LENGTH, ret);
		return (-1);
	}

	return (0);
}

/*
 * Compute the hashed host-field representation of host, reusing the salt
 * from an existing hashed field.
 *
 * host               - NUL-terminated host name to hash
 * name_from_hostfile - hashed host field from the known_hosts file; its
 *                      salt is extracted (need not be NUL-terminated)
 * src_len            - length of that field in bytes
 *
 * Returns a pointer to a static buffer holding
 *   HASH_MAGIC <base64 salt> HASH_DELIM <base64 HMAC-SHA1(salt, host)>
 * or NULL if the salt could not be extracted.  The buffer is overwritten on
 * every call, so the function is not reentrant.  The process exits if the
 * HMAC or base64 encoding fails.
 */
char *
host_hash(const char *host, const char *name_from_hostfile, u_int src_len)
{
	const EVP_MD *md = EVP_sha1();
	char salt[256], uu_salt[512], uu_result[512];
	unsigned char result[EVP_MAX_MD_SIZE];
	static char encoded[1024];
	u_int len;

	/* SHA-1 digest size; extract_salt guarantees a salt of this length. */
	len = EVP_MD_size(md);

	if (extract_salt(name_from_hostfile, src_len, salt, sizeof(salt)) == -1)
		return (NULL);

	/*
	 * One-shot HMAC() is available in every OpenSSL release, whereas a
	 * stack-allocated HMAC_CTX does not compile since OpenSSL 1.1.0
	 * (the type became opaque).
	 */
	if (HMAC(md, salt, (int)len, (const unsigned char *)host, strlen(host),
	    result, NULL) == NULL) {
		fprintf(stderr, "%s: host_hash: HMAC failed\n", __progname);
		exit(-1);
	}

	if (res_9_b64_ntop(salt, len, uu_salt, sizeof(uu_salt)) == -1 ||
	    res_9_b64_ntop(result, len, uu_result, sizeof(uu_result)) == -1) {
		fprintf(stderr, "%s: host_hash: __b64_ntop failed\n", __progname);
		exit(-1);
	}

	snprintf(encoded, sizeof(encoded), "%s%s%c%s", HASH_MAGIC, uu_salt,
	    HASH_DELIM, uu_result);

	return (encoded);
}

int
check_host_in_hostfile(const char *host)
{
	FILE *f;
	char line[8192];
	int linenum = 0;
	char *cp, *cp2, *hashed_host;

	/* Open the knownhosts file */
	f = fopen(SSH_KNOWNHOSTS, "r");
	if (!f) {
		fprintf(stderr, "%s: open: %s\n", __progname, SSH_KNOWNHOSTS);
		return(-1);
	}

	while(fgets(line, sizeof(line), f)) {
		cp = line;
		linenum++;

		/* Skip any leading whitespace, comments and empty lines. */
		for(; *cp == ' ' || *cp == '\t'; cp++)
			;
		if (!*cp || *cp == '#' || *cp == '\n')
			continue;

		/* Find the end of the host name portion. */
		for (cp2 = cp; *cp2 && *cp2 != ' ' && *cp2 != '\t'; cp2++)
			;

		hashed_host = host_hash(host, cp, (u_int) (cp2 - cp));
		if (hashed_host == NULL) {
			fprintf(stderr, "Invalid hashed host line %d of %s\n",
				linenum, SSH_KNOWNHOSTS);
			continue;
		}
		if (strncmp(hashed_host, cp, (u_int) (cp2 - cp)) == 0) {
			printf("Found at %d (%s)\n", linenum, hashed_host);
		}
	}

	fclose(f);
	return 0;
}

/*
 * Usage: <progname> <host>
 *
 * Looks up <host> among the hashed entries of SSH_KNOWNHOSTS.  Exits with
 * EXIT_FAILURE on a usage error or when the file cannot be opened.
 * Otherwise exits with 0, whether or not a match was found.
 */
int
main(int argc, char *argv[])
{
	if (argc != 2) {
		fprintf(stderr, "Usage: %s <host>\n", __progname);
		return (EXIT_FAILURE);
	}

	/* Propagate failure to open the file instead of reporting success. */
	if (check_host_in_hostfile(argv[1]) != 0)
		return (EXIT_FAILURE);

	return 0;
}

