/* $Cambridge: hermes/src/mailchk/ssl.c,v 1.1 2003/08/10 22:27:44 dpc22 Exp $ */

#include "mailchk.h"

#include <limits.h>

/* Header files for OpenSSL */

#include <openssl/lhash.h>
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <openssl/rand.h>

/* ====================================================================== */

/* ssl_is_available() ****************************************************
 *
 * Report whether SSL can be used. This OpenSSL-backed implementation
 * always answers T.
 ************************************************************************/

BOOL ssl_is_available()
{
    BOOL available = T;

    return available;
}

/* ====================================================================== */

/* Assorted bits stolen straight from Stunnel (ssl.c) that we might need */

/* Global SSL context shared by all client connections. Created by
 * ssl_context_init() and released by ssl_context_free(). */
static SSL_CTX *client_ctx;

/* Session ID context string; ssl_start_client() passes it (length from
 * strlen(), so without the trailing NUL) to SSL_set_session_id_context()
 * for each new connection. */
static unsigned char *sid_ctx = (unsigned char *) "Mailchk SID";

/* Bit length of the temporary RSA key generated by ssl_make_rsakey() */
#define SSL_RSA_KEYLENGTH   (1024)

/* OpenSSL cipher-list string handed to SSL_CTX_set_cipher_list():
 * the "ALL" set with the "LOW" strength group removed */
#define SSLCIPHERLIST "ALL:!LOW"

/* ====================================================================== */

/* PRNG stuff for SSL */

/* prng_seeded() *********************************************************
 *
 * Ask OpenSSL whether its PRNG currently holds enough entropy.
 *   bytes - running total of seed bytes gathered by the caller; it is
 *           not consulted, the decision rests entirely on RAND_status().
 *
 * Returns: 1 if RAND_status() reports the PRNG is seeded, 0 otherwise.
 ************************************************************************/

static int prng_seeded(int bytes)
{
    (void) bytes;

    if (!RAND_status())
        return (0);             /* assume we don't have enough */

    log_debug("RAND_status claims sufficient entropy for the PRNG\n");
    return (1);
}

/* add_rand_file() *******************************************************
 *
 * Mix up to 2048 bytes read from a file (or device) into the OpenSSL PRNG.
 *   filename - path to read seed data from
 *
 * Returns: number of bytes loaded into the PRNG; 0 if the file does not
 *          exist or nothing could be read from it. Never negative, so the
 *          result can safely be added to a running total.
 ************************************************************************/

static int add_rand_file(const char *filename)
{
    struct stat sb;
    int readbytes;

    if (stat(filename, &sb) != 0)
        return (0);

    readbytes = RAND_load_file(filename, 2048);

    /* RAND_load_file() signals failure with 0 or (in some OpenSSL
     * versions) -1; treat both as "nothing read" rather than letting a
     * negative count reach the log or the caller's byte total. */
    if (readbytes <= 0) {
        log_misc("Unable to retrieve any random data from %s\n", filename);
        return (0);
    }

    log_debug("Snagged %lu random bytes from %s\n",
              (unsigned long) readbytes, filename);
    return (readbytes);
}

/* os_initialize_prng() **************************************************
 *
 * Seed the OpenSSL PRNG from /dev/urandom. If RAND_status() still
 * reports insufficient entropy afterwards, log a fatal error and exit(1).
 ************************************************************************/

static void os_initialize_prng()
{
    int seed_total = add_rand_file("/dev/urandom");

    if (!prng_seeded(seed_total)) {
        log_fatal("PRNG seeded with %lu bytes total (insufficent)\n",
                  (unsigned long) seed_total);
        exit(1);
    }

    log_debug("PRNG seeded successfully\n");
}

/* ====================================================================== */

/* ====================================================================== */

/* Temporary RSA key returned to OpenSSL by rsa_callback() and replaced
 * by ssl_make_rsakey(). NIL until the first key is generated. */
static RSA *rsa_tmp = NIL;      /* temporary RSA key    */
/* Absolute time after which the key should be regenerated; set by
 * ssl_make_rsakey() to generation time + 30 minutes, initially 0. */
static time_t rsa_timeout = (time_t) 0; /* Timeout for this key */

/* ssl_make_rsakey() *****************************************************
 *
 * Set up RSAkey
 ************************************************************************/

static void ssl_make_rsakey()
{
    time_t now = time(NIL);

    log_debug("Generating fresh RSA key");

    if (rsa_tmp)
        RSA_free(rsa_tmp);

    if (!
        (rsa_tmp =
         RSA_generate_key(SSL_RSA_KEYLENGTH, RSA_F4, NULL, NULL)))
        log_fatal("tmp_rsa_cb");

    log_debug("Generated fresh RSA key");

    rsa_timeout = now + (30*60);   /* 30 minutes */
}

/* ssl_init_rsakey() *****************************************************
 *
 * Prepare the temporary RSA key at startup by generating the first key
 * through ssl_make_rsakey().
 ************************************************************************/

static void ssl_init_rsakey()
{
    /* No setup beyond producing the initial key */
    ssl_make_rsakey();
}


/* ssl_check_rsakey() *****************************************************
 *
 * Generate a fresh RSA key if there is no key yet, or if the existing
 * key's recorded timeout (set by ssl_make_rsakey()) lies in the past.
 *************************************************************************/

void ssl_check_rsakey()
{
    time_t now = time(NIL);

    /* Regenerate only when there is no key, or when an expiry time has
     * been recorded AND has passed. Using || here would be true for every
     * key with a non-zero timeout, forcing a (slow) key generation on
     * every call. */
    if (!rsa_tmp || ((rsa_timeout != (time_t) 0) && (rsa_timeout < now)))
        ssl_make_rsakey();
}

/* ====================================================================== */

/* A pair of OpenSSL callbacks */

/* rsa_callback() ********************************************************
 *
 * Temporary-RSA callback installed with SSL_CTX_set_tmp_rsa_callback()
 * in ssl_context_init().
 *   s      - connection requesting the key (unused)
 *   export - export flag supplied by OpenSSL (unused)
 *   keylen - requested key length in bits (only logged)
 *
 * Returns: the shared temporary key rsa_tmp, after ssl_check_rsakey() has
 *          had the chance to refresh it. The requested keylen is not
 *          honoured; the key is always SSL_RSA_KEYLENGTH bits.
 ************************************************************************/

static RSA *rsa_callback(SSL * s, int export, int keylen)
{
    (void) s;
    (void) export;

    ssl_check_rsakey();

    /* keylen is an int: cast it so the argument matches %lu */
    log_debug("rsa_callback(): Requested %lu bit key",
              (unsigned long) keylen);
    return rsa_tmp;
}

/* info_callback() *******************************************************
 *
 * OpenSSL info callback registered in ssl_context_init(). Currently a
 * no-op: every argument is ignored.
 ************************************************************************/

static void info_callback(const SSL * s, int where, int ret)
{
    (void) s;
    (void) where;
    (void) ret;
}

/* ====================================================================== */

/* ssl_context_init() ****************************************************
 *
 * Initialise SSL "context"es: one for server size activity and one for
 * client side activity.
 ************************************************************************/

void ssl_context_init()
{
    /* Set up random number generator */
    os_initialize_prng();

    SSLeay_add_ssl_algorithms();
    SSL_load_error_strings();

    /* Set up client context */
    client_ctx = SSL_CTX_new(SSLv3_client_method());
    SSL_CTX_set_session_cache_mode(client_ctx, SSL_SESS_CACHE_BOTH);
    SSL_CTX_set_info_callback(client_ctx, info_callback);
    SSL_CTX_set_mode(client_ctx, SSL_MODE_AUTO_RETRY);

    if (SSL_CTX_need_tmp_RSA(client_ctx)) {
        SSL_CTX_set_tmp_rsa_callback(client_ctx, rsa_callback);
    }

    /* Don't bother with session cache for client side: not enough
     * connections to worry about caching */
    SSL_CTX_set_session_cache_mode(client_ctx, SSL_SESS_CACHE_OFF);
    SSL_CTX_set_timeout(client_ctx, 0);

    /* Set cipherlist */
    if (!SSL_CTX_set_cipher_list(client_ctx, SSLCIPHERLIST))
        log_fatal("SSL_CTX_set_cipher_list");

#if 1
    /* Happens on demand here */
    /* Initialise RSA temporary key (will take a couple of secs to complete) */
    ssl_init_rsakey();
#endif
}

void ssl_context_free()
{
    SSL_CTX_free(client_ctx);
}

void ssl_shutdown(void *ssl)
{
    SSL_shutdown((SSL *) ssl);
}

int ssl_get_error(void *ssl, int code)
{
    return (SSL_get_error((SSL *) ssl, code));
}

void ssl_free(void *ssl)
{
    SSL_free((SSL *) ssl);
    ERR_remove_state(0);
}

/* ====================================================================== */

/* os_socket_nonblocking() ***********************************************
 *
 * Put a file descriptor into non-blocking mode.
 *   sockfd - descriptor to modify
 *
 * Returns: T on success; 0 if either fcntl() call failed (the failure is
 *          also reported through log_fatal()).
 ************************************************************************/

int
os_socket_nonblocking(int sockfd)
{
    int mode = fcntl(sockfd, F_GETFL, 0);

    /* On failure F_GETFL yields -1; OR-ing a flag into that and writing
     * it back would set every status flag, so stop here instead. */
    if (mode == -1) {
        log_fatal("[os_socket_nonblocking()] fcntl() failed: %s",
                  strerror(errno));
        return (0);
    }

    /* O_NONBLOCK is the POSIX flag. O_NDELAY is only an alias on some
     * systems; on others a read with no data returns 0, which looks like EOF.
     * POSIX only promises F_SETFL returns "not -1" on success. */
    if (fcntl(sockfd, F_SETFL, mode | O_NONBLOCK) == -1) {
        log_fatal("[os_socket_nonblocking()] fcntl() failed: %s",
                  strerror(errno));
        return (0);
    }
    return (T);
}

/* ssl_start_client() ****************************************************
 *
 * Start client side SSL
 ************************************************************************/

void *ssl_start_client(int fd, unsigned long timeout)
{
    SSL *ssl;
    SSL_CIPHER *c;
    char *ver;
    int bits;

    if (!(ssl = (void *) SSL_new(client_ctx)))
        return (NIL);

    SSL_set_session_id_context((SSL *) ssl, sid_ctx,
                               strlen((char *) sid_ctx));

    SSL_set_fd((SSL *) ssl, fd);
    SSL_set_connect_state((SSL *) ssl);

    if (SSL_connect((SSL *) ssl) <= 0)
        return (NIL);

    /* Verify certificate here? Need local context to play with? */

    switch (((SSL *) ssl)->session->ssl_version) {
    case SSL2_VERSION:
        ver = "SSLv2";
        break;
    case SSL3_VERSION:
        ver = "SSLv3";
        break;
    case TLS1_VERSION:
        ver = "TLSv1";
        break;
    default:
        ver = "UNKNOWN";
    }
    c = SSL_get_current_cipher((SSL *) ssl);
    SSL_CIPHER_get_bits(c, &bits);
    log_debug("Opened client connection with %s, cipher %s (%lu bits)\n",
              ver, SSL_CIPHER_get_name(c), (unsigned long) bits);

    /* Put underlying socket in non-blocking mode: stops occasional
     * deadlocks where select() timeout preferred */
    os_socket_nonblocking(fd);

    return ((void *) ssl);
}

/* ====================================================================== */

/* ssl_read() ************************************************************
 *
 * read() from SSL pipe:
 *    ssl     - SSL abstraction
 *  buffer    - Buffer to read into
 *  blocksize - Size of buffer
 *
 * Returns: Numbers of bytes read. 0 => EOF, 
 *          -1 => error (SSL_MAILCHK_RETRY or SSL_MAILCHK_ERROR)
 ************************************************************************/

int ssl_read(void *ssl, unsigned char *buffer, unsigned long blocksize)
{
    int rc = SSL_read((SSL *) ssl, (char *) buffer, blocksize);

    switch (SSL_get_error((SSL *) ssl, rc)) {
    case SSL_ERROR_NONE:
        return (rc);
    case SSL_ERROR_ZERO_RETURN:
        return (0);
    case SSL_ERROR_WANT_READ:
        return (SSL_MAILCHK_RETRY);
    default:
        return (SSL_MAILCHK_ERROR);
    }
}

/* ssl_write() ***********************************************************
 *
 * write() to SSL pipe:
 *    ssl  - SSL abstraction
 *  buffer - Buffer to write from
 *  bytes  - Number of bytes to write
 *
 * Returns: Numbers of bytes written. -1 => error
 ************************************************************************/

int ssl_write(void *ssl, unsigned char *buffer, unsigned long bytes)
{
    int rc = SSL_write((SSL *) ssl, (char *) buffer, bytes);

    switch (SSL_get_error((SSL *) ssl, rc)) {
    case SSL_ERROR_NONE:
        return (rc);
    case SSL_ERROR_ZERO_RETURN:
        return (0);
    case SSL_ERROR_WANT_WRITE:
        return (SSL_MAILCHK_RETRY);
    default:
        return (SSL_MAILCHK_ERROR);
    }
}

/* ssl_pending()**********************************************************
 *
 * Check for pending input on SSL pipe.
 ************************************************************************/

int ssl_pending(void *ssl)
{
    return (SSL_pending((SSL *) ssl));
}
