From e6bb39a2d3470239b2e35de4a7c3c412700d09f2 Mon Sep 17 00:00:00 2001
From: Sergey Lyubka <valenok@gmail.com>
Date: Sat, 28 Jul 2012 18:57:04 +0100
Subject: [PATCH] Using client-side SSL context for mg_connect()

---
 mongoose.c | 75 ++++++++++++++++++++++++++++++++----------------------
 mongoose.h |  4 +++
 2 files changed, 49 insertions(+), 30 deletions(-)

diff --git a/mongoose.c b/mongoose.c
index ece174d..30eac15 100644
--- a/mongoose.c
+++ b/mongoose.c
@@ -280,6 +280,7 @@ extern int SSL_set_fd(SSL *, int);
 extern SSL *SSL_new(SSL_CTX *);
 extern SSL_CTX *SSL_CTX_new(SSL_METHOD *);
 extern SSL_METHOD *SSLv23_server_method(void);
+extern SSL_METHOD *SSLv23_client_method(void);
 extern int SSL_library_init(void);
 extern void SSL_load_error_strings(void);
 extern int SSL_CTX_use_PrivateKey_file(SSL_CTX *, const char *, int);
@@ -320,6 +321,7 @@ struct ssl_func {
 #define SSL_load_error_strings (* (void (*)(void)) ssl_sw[15].ptr)
 #define SSL_CTX_use_certificate_chain_file \
   (* (int (*)(SSL_CTX *, const char *)) ssl_sw[16].ptr)
+#define SSLv23_client_method (* (SSL_METHOD * (*)(void)) ssl_sw[17].ptr)
 
 #define CRYPTO_num_locks (* (int (*)(void)) crypto_sw[0].ptr)
 #define CRYPTO_set_locking_callback \
@@ -351,6 +353,7 @@ static struct ssl_func ssl_sw[] = {
   {"SSL_CTX_free",  NULL},
   {"SSL_load_error_strings", NULL},
   {"SSL_CTX_use_certificate_chain_file", NULL},
+  {"SSLv23_client_method", NULL},
   {NULL,    NULL}
 };
 
@@ -447,6 +450,7 @@ static const char *config_options[] = {
 struct mg_context {
   volatile int stop_flag;       // Should we stop event loop
   SSL_CTX *ssl_ctx;             // SSL context
+  SSL_CTX *client_ssl_ctx;      // Client SSL context
   char *config[NUM_OPTIONS];    // Mongoose configuration parameters
   mg_callback_t user_callback;  // User-defined callback function
   void *user_data;              // User-defined data
@@ -1608,8 +1612,8 @@ static int convert_uri_to_file_name(struct mg_connection *conn, char *buf,
   return stat_result;
 }
 
-static int sslize(struct mg_connection *conn, int (*func)(SSL *)) {
-  return (conn->ssl = SSL_new(conn->ctx->ssl_ctx)) != NULL &&
+static int sslize(struct mg_connection *conn, SSL_CTX *s, int (*func)(SSL *)) {
+  return (conn->ssl = SSL_new(s)) != NULL &&
     SSL_set_fd(conn->ssl, conn->client.sock) == 1 &&
     func(conn->ssl) == 1;
 }
@@ -3485,7 +3489,8 @@ static int set_ports_option(struct mg_context *ctx) {
       cry(fc(ctx), "%s: %.*s: invalid port spec. Expecting list of: %s",
           __func__, vec.len, vec.ptr, "[IP_ADDRESS:]PORT[s|p]");
       success = 0;
-    } else if (so.is_ssl && ctx->ssl_ctx == NULL) {
+    } else if (so.is_ssl &&
+               (ctx->ssl_ctx == NULL || ctx->config[SSL_CERTIFICATE] == NULL)) {
       cry(fc(ctx), "Cannot add SSL socket, is -ssl_certificate option set?");
       success = 0;
     } else if ((sock = socket(so.lsa.sa.sa_family, SOCK_STREAM, 6)) ==
@@ -3720,15 +3725,10 @@ static int load_dll(struct mg_context *ctx, const char *dll_name,
 // Dynamically load SSL library. Set up ctx->ssl_ctx pointer.
 static int set_ssl_option(struct mg_context *ctx) {
   struct mg_request_info request_info;
-  SSL_CTX *CTX;
   int i, size;
   const char *pem = ctx->config[SSL_CERTIFICATE];
   const char *chain = ctx->config[SSL_CHAIN_FILE];
 
-  if (pem == NULL) {
-    return 1;
-  }
-
 #if !defined(NO_SSL_DL)
   if (!load_dll(ctx, SSL_LIB, ssl_sw) ||
       !load_dll(ctx, CRYPTO_LIB, crypto_sw)) {
@@ -3740,27 +3740,31 @@ static int set_ssl_option(struct mg_context *ctx) {
   SSL_library_init();
   SSL_load_error_strings();
 
-  if ((CTX = SSL_CTX_new(SSLv23_server_method())) == NULL) {
+  if ((ctx->client_ssl_ctx = SSL_CTX_new(SSLv23_client_method())) == NULL) {
+    cry(fc(ctx), "SSL_CTX_new error: %s", ssl_error());
+  }
+
+  if ((ctx->ssl_ctx = SSL_CTX_new(SSLv23_server_method())) == NULL) {
     cry(fc(ctx), "SSL_CTX_new error: %s", ssl_error());
   } else if (ctx->user_callback != NULL) {
     memset(&request_info, 0, sizeof(request_info));
     request_info.user_data = ctx->user_data;
-    ctx->user_callback(MG_INIT_SSL, (struct mg_connection *) CTX,
+    ctx->user_callback(MG_INIT_SSL, (struct mg_connection *) ctx->ssl_ctx,
                        &request_info);
   }
 
-  if (CTX != NULL && SSL_CTX_use_certificate_file(CTX, pem,
-        SSL_FILETYPE_PEM) == 0) {
+  if (ctx->ssl_ctx != NULL && pem != NULL &&
+      SSL_CTX_use_certificate_file(ctx->ssl_ctx, pem, SSL_FILETYPE_PEM) == 0) {
     cry(fc(ctx), "%s: cannot open %s: %s", __func__, pem, ssl_error());
     return 0;
-  } else if (CTX != NULL && SSL_CTX_use_PrivateKey_file(CTX, pem,
-        SSL_FILETYPE_PEM) == 0) {
+  }
+  if (ctx->ssl_ctx != NULL && pem != NULL &&
+      SSL_CTX_use_PrivateKey_file(ctx->ssl_ctx, pem, SSL_FILETYPE_PEM) == 0) {
     cry(fc(ctx), "%s: cannot open %s: %s", NULL, pem, ssl_error());
     return 0;
   }
-
-  if (CTX != NULL && chain != NULL &&
-      SSL_CTX_use_certificate_chain_file(CTX, chain) == 0) {
+  if (ctx->ssl_ctx != NULL && chain != NULL &&
+      SSL_CTX_use_certificate_chain_file(ctx->ssl_ctx, chain) == 0) {
     cry(fc(ctx), "%s: cannot open %s: %s", NULL, chain, ssl_error());
     return 0;
   }
@@ -3780,9 +3784,6 @@ static int set_ssl_option(struct mg_context *ctx) {
   CRYPTO_set_locking_callback(&ssl_locking_callback);
   CRYPTO_set_id_callback(&ssl_id_callback);
 
-  // Done with everything. Save the context.
-  ctx->ssl_ctx = CTX;
-
   return 1;
 }
 
@@ -3857,6 +3858,11 @@ static void close_connection(struct mg_connection *conn) {
   }
 }
 
+void mg_close_connection(struct mg_connection *conn) {
+  close_connection(conn);
+  free(conn);
+}
+
 struct mg_connection *mg_connect(struct mg_context *ctx,
                                  const char *host, int port, int use_ssl) {
   struct mg_connection *newconn = NULL;
@@ -3864,7 +3870,7 @@ struct mg_connection *mg_connect(struct mg_context *ctx,
   struct hostent *he;
   int sock;
 
-  if (ctx->ssl_ctx == NULL && use_ssl) {
+  if (ctx->client_ssl_ctx == NULL && use_ssl) {
     cry(fc(ctx), "%s: SSL is not initialized", __func__);
   } else if ((he = gethostbyname(host)) == NULL) {
     cry(fc(ctx), "%s: gethostbyname(%s): %s", __func__, host, strerror(ERRNO));
@@ -3883,11 +3889,12 @@ struct mg_connection *mg_connect(struct mg_context *ctx,
       cry(fc(ctx), "%s: calloc: %s", __func__, strerror(ERRNO));
       closesocket(sock);
     } else {
+      newconn->ctx = ctx;
       newconn->client.sock = sock;
       newconn->client.rsa.sin = sin;
       newconn->client.is_ssl = use_ssl;
       if (use_ssl) {
-        sslize(newconn, SSL_connect);
+        sslize(newconn, ctx->client_ssl_ctx, SSL_connect);
       }
     }
   }
@@ -3898,18 +3905,24 @@ struct mg_connection *mg_connect(struct mg_context *ctx,
 FILE *mg_fetch(struct mg_context *ctx, const char *url, const char *path,
                struct mg_request_info *ri) {
   struct mg_connection *newconn;
-  int n, req_length, data_length = 0, port = 80;
+  int n, req_length, data_length, port;
   char host[1025], proto[10], buf[16384];
   FILE *fp = NULL;
 
-  if (sscanf(url, "%9[htps]://%1024[^:]:%d/%n", proto, host, &port, &n) != 3 &&
-      sscanf(url, "%9[htps]://%1024[^/]/%n", proto, host, &n) != 2) {
+  if (sscanf(url, "%9[htps]://%1024[^:]:%d/%n", proto, host, &port, &n) == 3) {
+  } else if (sscanf(url, "%9[htps]://%1024[^/]/%n", proto, host, &n) == 2) {
+    port = mg_strcasecmp(proto, "https") == 0 ? 443 : 80;
+  } else {
     cry(fc(ctx), "%s: invalid URL: [%s]", __func__, url);
-  } else if ((newconn = mg_connect(ctx, host, port,
-                                   !strcmp(proto, "https"))) == NULL) {
+    return NULL;
+  }
+
+  if ((newconn = mg_connect(ctx, host, port,
+                            !strcmp(proto, "https"))) == NULL) {
     cry(fc(ctx), "%s: mg_connect(%s): %s", __func__, url, strerror(ERRNO));
   } else {
     mg_printf(newconn, "GET /%s HTTP/1.0\r\n\r\n", url + n);
+    data_length = 0;
     req_length = read_request(NULL, newconn->client.sock,
                               newconn->ssl, buf, sizeof(buf), &data_length);
     if (req_length <= 0) {
@@ -3930,8 +3943,7 @@ FILE *mg_fetch(struct mg_context *ctx, const char *url, const char *path,
         data_length = mg_read(newconn, buf, sizeof(buf));
       } while (data_length > 0);
     }
-    close_connection(newconn);
-    free(newconn);
+    mg_close_connection(newconn);
   }
 
   return fp;
@@ -4074,7 +4086,7 @@ static void worker_thread(struct mg_context *ctx) {
     conn->request_info.is_ssl = conn->client.is_ssl;
 
     if (!conn->client.is_ssl ||
-        (conn->client.is_ssl && sslize(conn, SSL_accept))) {
+        (conn->client.is_ssl && sslize(conn, conn->ctx->ssl_ctx, SSL_accept))) {
       process_new_connection(conn);
     }
 
@@ -4226,6 +4238,9 @@ static void free_context(struct mg_context *ctx) {
   if (ctx->ssl_ctx != NULL) {
     SSL_CTX_free(ctx->ssl_ctx);
   }
+  if (ctx->client_ssl_ctx != NULL) {
+    SSL_CTX_free(ctx->client_ssl_ctx);
+  }
 #ifndef NO_SSL
   if (ssl_mutexes != NULL) {
     free(ssl_mutexes);
diff --git a/mongoose.h b/mongoose.h
index 3a4fdfe..c25753a 100644
--- a/mongoose.h
+++ b/mongoose.h
@@ -226,6 +226,10 @@ struct mg_connection *mg_connect(struct mg_context *ctx,
                                  const char *host, int port, int use_ssl);
 
 
+// Close the connection opened by mg_connect().
+void mg_close_connection(struct mg_connection *conn);
+
+
 // Download given URL to a given file.
 //   url: URL to download
 //   path: file name where to save the data
-- 
2.18.1