From 1a8c19e05a7e01f88d353662afff2c15868f03f2 Mon Sep 17 00:00:00 2001
From: "Jonathan M. Henson" <jonathan.michael.henson@gmail.com>
Date: Fri, 15 Mar 2019 19:37:11 -0700
Subject: [PATCH] =?UTF-8?q?Fixed=20fragmentation=20bugs=20for=20secure=20c?=
 =?UTF-8?q?hannel,=20alpn=20no=20longer=20fails=20the=E2=80=A6=20(#111)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Fixed fragmentaiton bugs for secure channel, alpn no longer fails the negotiation if one of the endpoints didn't negotiate.

* Fixed bug in iocp socket callback, logging stuff it shouldn't and more fragmentation fixes for securechannel.

* Addressed PR feedback.

* Fixed warning on windows about append_failed
---
 source/windows/iocp/socket.c                |   6 +-
 source/windows/secure_channel_tls_handler.c | 239 +++++++++++---------
 2 files changed, 129 insertions(+), 116 deletions(-)

diff --git a/source/windows/iocp/socket.c b/source/windows/iocp/socket.c
index 6ba427b2a..0fe23a939 100644
--- a/source/windows/iocp/socket.c
+++ b/source/windows/iocp/socket.c
@@ -835,8 +835,8 @@ void s_socket_connection_completion(
             AWS_LOGF_ERROR(
                 AWS_LS_IO_SOCKET,
                 "id=%p handle=%p: connect completion triggered with error %d",
-                (void *)socket_args->socket,
-                (void *)socket_args->socket->io_handle.data.handle,
+                (void *)socket,
+                (void *)socket->io_handle.data.handle,
                 status_code);
             int error = s_determine_socket_error(status_code);
             socket_impl->vtable->connection_error(socket, error);
@@ -958,7 +958,7 @@ static inline int s_tcp_connect(
     }
 
     AWS_LOGF_TRACE(
-        AWS_LS_IO_TLS,
+        AWS_LS_IO_SOCKET,
         "id=%p handle=%p: connection pending, scheduling timeout task",
         (void *)socket,
         (void *)socket->io_handle.data.handle);
diff --git a/source/windows/secure_channel_tls_handler.c b/source/windows/secure_channel_tls_handler.c
index a65dfd720..58caf7673 100644
--- a/source/windows/secure_channel_tls_handler.c
+++ b/source/windows/secure_channel_tls_handler.c
@@ -87,7 +87,10 @@ struct secure_channel_handler {
     struct aws_byte_buf server_name;
     TimeStamp sspi_timestamp;
     int (*s_connection_state_fn)(struct aws_channel_handler *handler);
-    uint8_t buffered_read_in_data[READ_IN_SIZE];
+    /*
+     * Give a little bit of extra head room, for split records.
+     */
+    uint8_t buffered_read_in_data[READ_IN_SIZE + KB_1];
     struct aws_byte_buf buffered_read_in_data_buf;
     size_t estimated_incomplete_size;
     size_t read_extra;
@@ -315,7 +318,8 @@ static int s_fillin_alpn_data(
     struct aws_byte_cursor alpn_buffer_array[4];
     aws_array_list_init_static(&alpn_buffers, alpn_buffer_array, 4, sizeof(struct aws_byte_cursor));
 
-    AWS_LOGF_DEBUG(AWS_LS_IO_TLS, "Setting ALPN extension with string %s.", (const char *)sc_handler->alpn_list);
+    AWS_LOGF_DEBUG(
+        AWS_LS_IO_TLS, "Setting ALPN extension with string %s.", (const char *)aws_string_bytes(sc_handler->alpn_list));
     struct aws_byte_cursor alpn_str_cur = aws_byte_cursor_from_string(sc_handler->alpn_list);
     if (aws_byte_cursor_split_on_char(&alpn_str_cur, ';', &alpn_buffers)) {
         return AWS_OP_ERR;
@@ -454,11 +458,7 @@ static int s_do_server_side_negotiation_step_1(struct aws_channel_handler *handl
 
     size_t data_to_write_len = output_buffer.cbBuffer;
 
-    AWS_LOGF_TRACE(
-        AWS_LS_IO_TLS,
-        "id=%p: Sending ServerHello. Data size %llu",
-        (void *)handler,
-        (unsigned long long)data_to_write_len);
+    AWS_LOGF_TRACE(AWS_LS_IO_TLS, "id=%p: Sending ServerHello. Data size %zu", (void *)handler, data_to_write_len);
     /* send the server hello. */
     struct aws_io_message *outgoing_message = aws_channel_acquire_message_from_pool(
         sc_handler->slot->channel, AWS_IO_MESSAGE_APPLICATION_DATA, data_to_write_len);
@@ -623,15 +623,13 @@ static int s_do_server_side_negotiation_step_2(struct aws_channel_handler *handl
                 AWS_LOGF_DEBUG(
                     AWS_LS_IO_TLS, "id=%p: negotiated protocol %s", handler, (char *)sc_handler->protocol.buffer);
             } else {
-                AWS_LOGF_ERROR(
+                AWS_LOGF_WARN(
                     AWS_LS_IO_TLS,
                     "id=%p: Error retrieving negotiated protocol. SECURITY_STATUS is %d",
                     handler,
                     (int)status);
                 int aws_error = s_determine_sspi_error(status);
                 aws_raise_error(aws_error);
-                s_invoke_negotiation_error(handler, aws_error);
-                return AWS_OP_ERR;
             }
         }
 #endif
@@ -668,7 +666,11 @@ static int s_do_client_side_negotiation_step_1(struct aws_channel_handler *handl
     /* add alpn data to the client hello if it's supported. */
 #ifdef SECBUFFER_APPLICATION_PROTOCOLS
     if (sc_handler->alpn_list && aws_tls_is_alpn_available()) {
-        AWS_LOGF_DEBUG(AWS_LS_IO_TLS, "id=%p: Setting ALPN data as %s", handler, sc_handler->alpn_list);
+        AWS_LOGF_DEBUG(
+            AWS_LS_IO_TLS,
+            "id=%p: Setting ALPN data as %s",
+            handler,
+            (const char *)aws_string_bytes(sc_handler->alpn_list));
         size_t extension_length = 0;
         if (s_fillin_alpn_data(handler, alpn_buffer_data, sizeof(alpn_buffer_data), &extension_length)) {
             s_invoke_negotiation_error(handler, aws_last_error());
@@ -730,10 +732,7 @@ static int s_do_client_side_negotiation_step_1(struct aws_channel_handler *handl
 
     size_t data_to_write_len = output_buffer.cbBuffer;
     AWS_LOGF_TRACE(
-        AWS_LS_IO_TLS,
-        "id=%p: Sending client handshake data of size %llu",
-        (void *)handler,
-        (unsigned long long)data_to_write_len);
+        AWS_LS_IO_TLS, "id=%p: Sending client handshake data of size %zu", (void *)handler, data_to_write_len);
 
     struct aws_io_message *outgoing_message = aws_channel_acquire_message_from_pool(
         sc_handler->slot->channel, AWS_IO_MESSAGE_APPLICATION_DATA, data_to_write_len);
@@ -836,9 +835,9 @@ static int s_do_client_side_negotiation_step_2(struct aws_channel_handler *handl
         sc_handler->estimated_incomplete_size = input_buffers[1].cbBuffer;
         AWS_LOGF_TRACE(
             AWS_LS_IO_TLS,
-            "id=%p: Incomplete buffer recieved. Incomplete size is %llu. Waiting for more data.",
+            "id=%p: Incomplete buffer recieved. Incomplete size is %zu. Waiting for more data.",
             (void *)handler,
-            (unsigned long long)sc_handler->estimated_incomplete_size);
+            sc_handler->estimated_incomplete_size);
         return aws_raise_error(AWS_IO_READ_WOULD_BLOCK);
     }
 
@@ -908,15 +907,13 @@ static int s_do_client_side_negotiation_step_2(struct aws_channel_handler *handl
                 AWS_LOGF_DEBUG(
                     AWS_LS_IO_TLS, "id=%p: Negotiated protocol %s", handler, (char *)sc_handler->protocol.buffer);
             } else {
-                AWS_LOGF_DEBUG(
+                AWS_LOGF_WARN(
                     AWS_LS_IO_TLS,
                     "id=%p: Error retrieving negotiated protocol. SECURITY_STATUS is %d",
                     handler,
                     (int)status);
                 int aws_error = s_determine_sspi_error(status);
                 aws_raise_error(aws_error);
-                s_invoke_negotiation_error(handler, aws_error);
-                return AWS_OP_ERR;
             }
         }
 #endif
@@ -931,73 +928,94 @@ static int s_do_client_side_negotiation_step_2(struct aws_channel_handler *handl
 static int s_do_application_data_decrypt(struct aws_channel_handler *handler) {
     struct secure_channel_handler *sc_handler = handler->impl;
 
-    /* 4 buffers are needed, only one is input, the others get zeroed out for the output operation. */
-    SecBuffer input_buffers[4];
-    AWS_ZERO_ARRAY(input_buffers);
-    input_buffers[0] = (SecBuffer){
-        .cbBuffer = (unsigned long)sc_handler->buffered_read_in_data_buf.len,
-        .pvBuffer = sc_handler->buffered_read_in_data_buf.buffer,
-        .BufferType = SECBUFFER_DATA,
-    };
-
-    SecBufferDesc buffer_desc = {
-        .ulVersion = SECBUFFER_VERSION,
-        .cBuffers = 4,
-        .pBuffers = input_buffers,
-    };
+    /* I know this is an unncessary initialization, it's initialized here to make linters happy.*/
+    int error = AWS_OP_ERR;
+    /* when we get an Extra buffer we have to move the pointer and replay the buffer, so we loop until we don't have
+       any extra buffers left over, in the last phase, we then go ahead and send the output. This state function will
+       always say BLOCKED_ON_READ or SUCCESS. There will never be left over reads.*/
+    do {
+        error = AWS_OP_ERR;
+        /* 4 buffers are needed, only one is input, the others get zeroed out for the output operation. */
+        SecBuffer input_buffers[4];
+        AWS_ZERO_ARRAY(input_buffers);
+
+        size_t read_len = sc_handler->read_extra ? sc_handler->read_extra : sc_handler->buffered_read_in_data_buf.len;
+        size_t offset = sc_handler->read_extra ? sc_handler->buffered_read_in_data_buf.len - sc_handler->read_extra : 0;
+        sc_handler->read_extra = 0;
+
+        input_buffers[0] = (SecBuffer){
+            .cbBuffer = (unsigned long)(read_len),
+            .pvBuffer = sc_handler->buffered_read_in_data_buf.buffer + offset,
+            .BufferType = SECBUFFER_DATA,
+        };
 
-    SECURITY_STATUS status = DecryptMessage(&sc_handler->sec_handle, &buffer_desc, 0, NULL);
+        SecBufferDesc buffer_desc = {
+            .ulVersion = SECBUFFER_VERSION,
+            .cBuffers = 4,
+            .pBuffers = input_buffers,
+        };
 
-    if (status == SEC_E_OK) {
-        /* if SECBUFFER_DATA is the buffer type of the second buffer, we have decrypted data to process.
-           If SECBUFFER_DATA is the type for the fourth buffer we need to keep track of it so we can shift
-           everything before doing another decrypt operation.
-           As far as I can tell, we don't care what's in the third buffer for TLS usage.*/
-        if (input_buffers[1].BufferType == SECBUFFER_DATA) {
-            size_t decrypted_length = input_buffers[1].cbBuffer;
-            AWS_LOGF_TRACE(
-                AWS_LS_IO_TLS,
-                "id=%p: Decrypted message with length %llu.",
-                (void *)handler,
-                (unsigned long long)decrypted_length);
+        SECURITY_STATUS status = DecryptMessage(&sc_handler->sec_handle, &buffer_desc, 0, NULL);
 
-            if (input_buffers[3].BufferType == SECBUFFER_EXTRA) {
-                sc_handler->read_extra = input_buffers[3].cbBuffer;
+        if (status == SEC_E_OK) {
+            /* if SECBUFFER_DATA is the buffer type of the second buffer, we have decrypted data to process.
+               If SECBUFFER_DATA is the type for the fourth buffer we need to keep track of it so we can shift
+               everything before doing another decrypt operation.
+               We don't care what's in the third buffer for TLS usage.*/
+            if (input_buffers[1].BufferType == SECBUFFER_DATA) {
+                size_t decrypted_length = input_buffers[1].cbBuffer;
                 AWS_LOGF_TRACE(
-                    AWS_LS_IO_TLS,
-                    "id=%p: Extra (incomplete) message received with length %llu.",
-                    (void *)handler,
-                    (unsigned long long)sc_handler->read_extra);
+                    AWS_LS_IO_TLS, "id=%p: Decrypted message with length %zu.", (void *)handler, decrypted_length);
+
+                struct aws_byte_cursor to_append =
+                    aws_byte_cursor_from_array(input_buffers[1].pvBuffer, decrypted_length);
+                int append_failed = aws_byte_buf_append(&sc_handler->buffered_read_out_data_buf, &to_append);
+                assert(!append_failed);
+                (void)append_failed;
+
+                /* if we have extra we have to move the pointer and do another Decrypt operation. */
+                if (input_buffers[3].BufferType == SECBUFFER_EXTRA) {
+                    sc_handler->read_extra = input_buffers[3].cbBuffer;
+                    AWS_LOGF_TRACE(
+                        AWS_LS_IO_TLS,
+                        "id=%p: Extra (incomplete) message received with length %zu.",
+                        (void *)handler,
+                        sc_handler->read_extra);
+                } else {
+                    error = AWS_OP_SUCCESS;
+                    /* this means we processed everything in the buffer. */
+                    sc_handler->buffered_read_in_data_buf.len = 0;
+                    AWS_LOGF_TRACE(
+                        AWS_LS_IO_TLS,
+                        "id=%p: Decrypt ended exactly on the end of the record, resetting buffer.",
+                        (void *)handler);
+                }
             }
-
-            assert(
-                decrypted_length <=
-                sc_handler->buffered_read_out_data_buf.capacity - sc_handler->buffered_read_out_data_buf.len);
-            memcpy(
-                sc_handler->buffered_read_out_data_buf.buffer + sc_handler->buffered_read_out_data_buf.len,
-                (uint8_t *)input_buffers[1].pvBuffer,
-                decrypted_length);
-            sc_handler->buffered_read_out_data_buf.len += decrypted_length;
         }
-
-        return AWS_OP_SUCCESS;
         /* SEC_E_INCOMPLETE_MESSAGE means the message we tried to decrypt isn't a full record and we need to
            append our next read to it and try again. */
-    } else if (status == SEC_E_INCOMPLETE_MESSAGE) {
-        sc_handler->estimated_incomplete_size = input_buffers[1].cbBuffer;
-        AWS_LOGF_TRACE(
-            AWS_LS_IO_TLS,
-            "id=%p: (incomplete) message received. Expecting remaining portion of size %llu.",
-            (void *)handler,
-            (unsigned long long)sc_handler->estimated_incomplete_size);
-        return aws_raise_error(AWS_IO_READ_WOULD_BLOCK);
-    } else {
-        AWS_LOGF_ERROR(
-            AWS_LS_IO_TLS, "id=%p: Error decypting message. SECURITY_STATUS is %d.", (void *)handler, (int)status);
-        int aws_error = s_determine_sspi_error(status);
-        aws_raise_error(aws_error);
-        return AWS_OP_ERR;
-    }
+        else if (status == SEC_E_INCOMPLETE_MESSAGE) {
+            sc_handler->estimated_incomplete_size = input_buffers[1].cbBuffer;
+            AWS_LOGF_TRACE(
+                AWS_LS_IO_TLS,
+                "id=%p: (incomplete) message received. Expecting remaining portion of size %zu.",
+                (void *)handler,
+                sc_handler->estimated_incomplete_size);
+            memmove(
+                sc_handler->buffered_read_in_data_buf.buffer,
+                sc_handler->buffered_read_in_data_buf.buffer + offset,
+                read_len);
+            sc_handler->buffered_read_in_data_buf.len = read_len;
+            aws_raise_error(AWS_IO_READ_WOULD_BLOCK);
+        } else {
+            AWS_LOGF_ERROR(
+                AWS_LS_IO_TLS, "id=%p: Error decypting message. SECURITY_STATUS is %d.", (void *)handler, (int)status);
+            int aws_error = s_determine_sspi_error(status);
+            aws_raise_error(aws_error);
+        }
+    } while (sc_handler->read_extra);
+
+    return error;
 }
 
 static int s_process_pending_output_messages(struct aws_channel_handler *handler) {
@@ -1011,18 +1029,14 @@ static int s_process_pending_output_messages(struct aws_channel_handler *handler
 
     AWS_LOGF_TRACE(
         AWS_LS_IO_TLS,
-        "id=%p: Processing incomming messages. Downstream window is %llu",
+        "id=%p: Processing incomming messages. Downstream window is %zu",
         (void *)handler,
-        (unsigned long long)downstream_window);
+        downstream_window);
     while (sc_handler->buffered_read_out_data_buf.len && downstream_window) {
         size_t requested_message_size = sc_handler->buffered_read_out_data_buf.len > downstream_window
                                             ? downstream_window
                                             : sc_handler->buffered_read_out_data_buf.len;
-        AWS_LOGF_TRACE(
-            AWS_LS_IO_TLS,
-            "id=%p: Requested message size is %llu",
-            (void *)handler,
-            (unsigned long long)requested_message_size);
+        AWS_LOGF_TRACE(AWS_LS_IO_TLS, "id=%p: Requested message size is %zu", (void *)handler, requested_message_size);
 
         if (sc_handler->slot->adj_right) {
             struct aws_io_message *read_out_msg = aws_channel_acquire_message_from_pool(
@@ -1054,11 +1068,7 @@ static int s_process_pending_output_messages(struct aws_channel_handler *handler
             }
 
             downstream_window = aws_channel_slot_downstream_read_window(sc_handler->slot);
-            AWS_LOGF_TRACE(
-                AWS_LS_IO_TLS,
-                "id=%p: Downstream window is %llu",
-                (void *)handler,
-                (unsigned long long)downstream_window);
+            AWS_LOGF_TRACE(AWS_LS_IO_TLS, "id=%p: Downstream window is %zu", (void *)handler, downstream_window);
         } else {
             if (sc_handler->on_data_read) {
                 sc_handler->on_data_read(
@@ -1075,6 +1085,7 @@ static void s_process_pending_output_task(struct aws_channel_task *task, void *a
     (void)task;
     struct aws_channel_handler *handler = arg;
 
+    aws_channel_task_init(task, NULL, NULL);
     if (status == AWS_TASK_STATUS_RUN_READY) {
         if (s_process_pending_output_messages(handler)) {
             struct secure_channel_handler *sc_handler = arg;
@@ -1094,9 +1105,9 @@ static int s_process_read_message(
         /* note, most of these functions log internally, so the log messages in this function are sparse. */
         AWS_LOGF_TRACE(
             AWS_LS_IO_TLS,
-            "id=%p: processing incoming message of size %llu",
+            "id=%p: processing incoming message of size %zu",
             (void *)handler,
-            (unsigned long long)message->message_data.len);
+            message->message_data.len);
 
         struct aws_byte_cursor message_cursor = aws_byte_cursor_from_buf(&message->message_data);
 
@@ -1128,6 +1139,12 @@ static int s_process_read_message(
                     /* throw this one as a protocol error. */
                     aws_raise_error(AWS_IO_TLS_ERROR_WRITE_FAILURE);
                 } else {
+                    if (sc_handler->buffered_read_out_data_buf.len) {
+                        err = s_process_pending_output_messages(handler);
+                        if (err) {
+                            break;
+                        }
+                    }
                     /* prevent a deadlock due to downstream handlers wanting more data, but we have an incomplete
                        record, and the amount they're requesting is less than the size of a tls record. */
                     size_t window_size = slot->window_size;
@@ -1153,6 +1170,7 @@ static int s_process_read_message(
                     sc_handler->buffered_read_in_data_buf.buffer + move_pos,
                     sc_handler->read_extra);
                 sc_handler->buffered_read_in_data_buf.len = sc_handler->read_extra;
+                sc_handler->read_extra = 0;
             } else {
                 sc_handler->buffered_read_in_data_buf.len = 0;
             }
@@ -1196,19 +1214,13 @@ static int s_process_write_message(
 
     if (message) {
         AWS_LOGF_TRACE(
-            AWS_LS_IO_TLS,
-            "id=%p: processing ougoing message of size %llu",
-            (void *)handler,
-            (unsigned long long)message->message_data.len);
+            AWS_LS_IO_TLS, "id=%p: processing ougoing message of size %zu", (void *)handler, message->message_data.len);
 
         struct aws_byte_cursor message_cursor = aws_byte_cursor_from_buf(&message->message_data);
 
         while (message_cursor.len) {
             AWS_LOGF_TRACE(
-                AWS_LS_IO_TLS,
-                "id=%p: processing message fragment of size %llu",
-                (void *)handler,
-                (unsigned long long)message_cursor.len);
+                AWS_LS_IO_TLS, "id=%p: processing message fragment of size %zu", (void *)handler, message_cursor.len);
             /* message size will be the lesser of either payload + record overhead or the max TLS record size.*/
             size_t upstream_overhead = aws_channel_slot_upstream_message_overhead(sc_handler->slot);
             size_t requested_length =
@@ -1279,9 +1291,9 @@ static int s_process_write_message(
                                                      sc_handler->stream_sizes.cbTrailer;
                 AWS_LOGF_TRACE(
                     AWS_LS_IO_TLS,
-                    "id=%p:message fragment encrypted successfully: size is %llu",
+                    "id=%p:message fragment encrypted successfully: size is %zu",
                     (void *)handler,
-                    (unsigned long long)outgoing_message->message_data.len);
+                    outgoing_message->message_data.len);
 
                 if (aws_channel_slot_send_message(slot, outgoing_message, AWS_CHANNEL_DIR_WRITE)) {
                     aws_mem_release(outgoing_message->allocator, outgoing_message);
@@ -1308,8 +1320,7 @@ static int s_process_write_message(
 static int s_increment_read_window(struct aws_channel_handler *handler, struct aws_channel_slot *slot, size_t size) {
     (void)size;
     struct secure_channel_handler *sc_handler = handler->impl;
-    AWS_LOGF_TRACE(
-        AWS_LS_IO_TLS, "id=%p: Increment read window message received %llu", (void *)handler, (unsigned long long)size);
+    AWS_LOGF_TRACE(AWS_LS_IO_TLS, "id=%p: Increment read window message received %zu", (void *)handler, size);
 
     if (AWS_UNLIKELY(!sc_handler->stream_sizes.cbMaximumMessage)) {
         SECURITY_STATUS status =
@@ -1333,14 +1344,11 @@ static int s_increment_read_window(struct aws_channel_handler *handler, struct a
     if (total_desired_size > current_window_size) {
         size_t window_update_size = total_desired_size - current_window_size;
         AWS_LOGF_TRACE(
-            AWS_LS_IO_TLS,
-            "id=%p: Propagating read window increment of size %llu",
-            (void *)handler,
-            (unsigned long long)window_update_size);
+            AWS_LS_IO_TLS, "id=%p: Propagating read window increment of size %zu", (void *)handler, window_update_size);
         aws_channel_slot_increment_read_window(slot, window_update_size);
     }
 
-    if (sc_handler->negotiation_finished) {
+    if (sc_handler->negotiation_finished && !sc_handler->sequential_task_storage.task_fn) {
         aws_channel_task_init(&sc_handler->sequential_task_storage, s_process_pending_output_task, handler);
         aws_channel_schedule_task_now(slot->channel, &sc_handler->sequential_task_storage);
     }
@@ -1575,11 +1583,16 @@ static struct aws_channel_handler *s_tls_handler_new(
 
     if (!options->alpn_list && sc_ctx->alpn_list) {
         sc_handler->alpn_list = aws_string_new_from_string(alloc, sc_ctx->alpn_list);
+        if (!sc_handler->alpn_list) {
+            aws_mem_release(alloc, sc_handler);
+            return NULL;
+        }
     } else if (options->alpn_list) {
         sc_handler->alpn_list = aws_string_new_from_string(alloc, options->alpn_list);
-    }
-
-    if (sc_handler->alpn_list) {
+        if (!sc_handler->alpn_list) {
+            aws_mem_release(alloc, sc_handler);
+            return NULL;
+        }
     }
 
     if (options->server_name) {