From 8f0b1adb19d164bc8220c8d3ed036156caed3d5d Mon Sep 17 00:00:00 2001
From: John Safranek <john@wolfssl.com>
Date: Mon, 3 Jun 2024 14:43:57 -0700
Subject: [PATCH] wolfSSHd Connection Closure

1. Initialize all the fds to -1.
2. Add flags for peerConnected and stdoutEmpty.
3. Remove the idle counter.
4. When the socket would block on write, set a flag to check the socket for
   writing later to call the worker which will send pending data.
5. When reading the pipes, a 0 returns means the pipe is closed. Deal
   with that.
6. If the ssh write fails, interrupt the subordinate process.
7. When waiting for the peer to close its channel and shutdown, sleep
   for 100ms, rather than 1us. It takes a little while to tear down.
8. Shutdown the peer socket. Spin on receiving the peer socket until it
   closes.
---
 apps/wolfsshd/wolfsshd.c | 88 +++++++++++++++++++++++++++++-----------
 1 file changed, 64 insertions(+), 24 deletions(-)

diff --git a/apps/wolfsshd/wolfsshd.c b/apps/wolfsshd/wolfsshd.c
index 3fbcbf736..ef62b56a7 100644
--- a/apps/wolfsshd/wolfsshd.c
+++ b/apps/wolfsshd/wolfsshd.c
@@ -1158,7 +1158,14 @@ static int SHELL_Subsystem(WOLFSSHD_CONNECTION* conn, WOLFSSH* ssh,
     byte channelBuffer[EXAMPLE_BUFFER_SZ];
     char* forcedCmd;
     int   windowFull = 0;
-    int   idle = 0;
+    int   peerConnected = 1;
+    int   stdoutEmpty = 0;
+
+    childFd = -1;
+    stdoutPipe[0] = -1;
+    stdoutPipe[1] = -1;
+    stderrPipe[0] = -1;
+    stderrPipe[1] = -1;
 
     forcedCmd = wolfSSHD_ConfigGetForcedCmd(usrConf);
 
@@ -1216,6 +1223,8 @@ static int SHELL_Subsystem(WOLFSSHD_CONNECTION* conn, WOLFSSH* ssh,
         if (forcedCmd) {
             close(stdoutPipe[0]);
             close(stderrPipe[0]);
+            stdoutPipe[0] = -1;
+            stderrPipe[0] = -1;
             if (dup2(stdoutPipe[1], STDOUT_FILENO) == -1) {
                 wolfSSH_Log(WS_LOG_ERROR,
                     "[SSHD] Error redirecting stdout pipe");
@@ -1288,9 +1297,7 @@ static int SHELL_Subsystem(WOLFSSHD_CONNECTION* conn, WOLFSSH* ssh,
         setenv("SHELL", pPasswd->pw_shell, 1);
 
         if (pPasswd->pw_shell) {
-            word32 shellSz = (word32)WSTRLEN(pPasswd->pw_shell);
-
-            if (shellSz < sizeof(shell)) {
+            if (WSTRLEN(pPasswd->pw_shell) < sizeof(shell)) {
                 char* cursor;
                 char* start;
 
@@ -1313,12 +1320,11 @@ static int SHELL_Subsystem(WOLFSSHD_CONNECTION* conn, WOLFSSH* ssh,
         }
 
         /* default to /bin/sh if user shell is not set */
-        WMEMSET(cmd, 0, sizeof(cmd));
-        if (XSTRLEN(pPasswd->pw_shell) == 0) {
-            XSNPRINTF(cmd, sizeof(cmd), "%s", "/bin/sh");
+        if (pPasswd->pw_shell && XSTRLEN(pPasswd->pw_shell)) {
+            XSNPRINTF(cmd, sizeof(cmd), "%s", pPasswd->pw_shell);
         }
         else {
-            XSNPRINTF(cmd, sizeof(cmd),"%s", pPasswd->pw_shell);
+            XSNPRINTF(cmd, sizeof(cmd), "%s", "/bin/sh");
         }
 
         errno = 0;
@@ -1387,20 +1393,24 @@ static int SHELL_Subsystem(WOLFSSHD_CONNECTION* conn, WOLFSSH* ssh,
         close(stderrPipe[1]);
     }
 
-    while (idle < MAX_IDLE_COUNT) {
+    while (ChildRunning || windowFull || !stdoutEmpty || peerConnected) {
         byte tmp[2];
         fd_set readFds;
+        fd_set writeFds;
         WS_SOCKET_T maxFd;
         int cnt_r;
         int cnt_w;
         int pending = 0;
 
-        idle++; /* increment idle count, gets reset if not idle */
-
         FD_ZERO(&readFds);
         FD_SET(sshFd, &readFds);
         maxFd = sshFd;
 
+        FD_ZERO(&writeFds);
+        if (windowFull) {
+            FD_SET(sshFd, &writeFds);
+        }
+
         /* select on stdout/stderr pipes with forced commands */
         if (forcedCmd) {
             FD_SET(stdoutPipe[0], &readFds);
@@ -1418,18 +1428,18 @@ static int SHELL_Subsystem(WOLFSSHD_CONNECTION* conn, WOLFSSH* ssh,
         }
 
         if (wolfSSH_stream_peek(ssh, tmp, 1) <= 0) {
-            rc = select((int)maxFd + 1, &readFds, NULL, NULL, NULL);
+            rc = select((int)maxFd + 1, &readFds, &writeFds, NULL, NULL);
             if (rc == -1)
                 break;
         }
         else {
             pending = 1; /* found some pending SSH data */
-            idle    = 0;
         }
 
         if (windowFull || pending || FD_ISSET(sshFd, &readFds)) {
             word32 lastChannel = 0;
 
+            windowFull = 0;
             /* The following tries to read from the first channel inside
                the stream. If the pending data in the socket is for
                another channel, this will return an error with id
@@ -1440,7 +1450,6 @@ static int SHELL_Subsystem(WOLFSSHD_CONNECTION* conn, WOLFSSH* ssh,
             if (cnt_r < 0) {
                 rc = wolfSSH_get_error(ssh);
                 if (rc == WS_CHAN_RXD) {
-                    idle = 0;
                     if (lastChannel == shellChannelId) {
                         cnt_r = wolfSSH_ChannelIdRead(ssh, shellChannelId,
                                 channelBuffer,
@@ -1454,6 +1463,11 @@ static int SHELL_Subsystem(WOLFSSHD_CONNECTION* conn, WOLFSSH* ssh,
                     }
                 }
                 else if (rc == WS_CHANNEL_CLOSED) {
+                    peerConnected = 0;
+                    continue;
+                }
+                else if (rc == WS_WANT_WRITE) {
+                    windowFull = 1;
                     continue;
                 }
                 else if (rc != WS_WANT_READ) {
@@ -1468,7 +1482,10 @@ static int SHELL_Subsystem(WOLFSSHD_CONNECTION* conn, WOLFSSH* ssh,
                     shellBuffer, cnt_r);
             if (cnt_w == WS_WINDOW_FULL) {
                 windowFull = 1;
-                idle = 0;
+                continue;
+            }
+            else if (cnt_w == WS_WANT_WRITE) {
+                windowFull = 1;
                 continue;
             }
             else {
@@ -1489,13 +1506,16 @@ static int SHELL_Subsystem(WOLFSSHD_CONNECTION* conn, WOLFSSH* ssh,
                 }
                 else {
                     if (cnt_r > 0) {
-                        idle = 0;
                         cnt_w = wolfSSH_extended_data_send(ssh, shellBuffer,
                             cnt_r);
                         if (cnt_w == WS_WINDOW_FULL) {
                             windowFull = 1;
                             continue;
                         }
+                        else if (cnt_w == WS_WANT_WRITE) {
+                            windowFull = 1;
+                            continue;
+                        }
                         else if (cnt_w < 0)
                             break;
                     }
@@ -1507,23 +1527,31 @@ static int SHELL_Subsystem(WOLFSSHD_CONNECTION* conn, WOLFSSH* ssh,
                 cnt_r = (int)read(stdoutPipe[0], shellBuffer,
                     sizeof shellBuffer);
                 /* This read will return 0 on EOF */
-                if (cnt_r <= 0) {
+                if (cnt_r < 0) {
                     int err = errno;
                     if (err != EAGAIN && err != 0) {
                         break;
                     }
                 }
+                else if (cnt_r == 0) {
+                    stdoutEmpty = 1;
+                }
                 else {
                     if (cnt_r > 0) {
-                        idle = 0;
                         cnt_w = wolfSSH_ChannelIdSend(ssh, shellChannelId,
                                 shellBuffer, cnt_r);
                         if (cnt_w == WS_WINDOW_FULL) {
                             windowFull = 1;
                             continue;
                         }
-                        else if (cnt_w < 0)
+                        else if (cnt_w == WS_WANT_WRITE) {
+                            windowFull = 1;
+                            continue;
+                        }
+                        else if (cnt_w < 0) {
+                            kill(childPid, SIGINT);
                             break;
+                        }
                     }
                 }
             }
@@ -1540,22 +1568,27 @@ static int SHELL_Subsystem(WOLFSSHD_CONNECTION* conn, WOLFSSH* ssh,
                 }
                 else {
                     if (cnt_r > 0) {
-                        idle = 0;
                         cnt_w = wolfSSH_ChannelIdSend(ssh, shellChannelId,
                                 shellBuffer, cnt_r);
                         if (cnt_w == WS_WINDOW_FULL) {
                             windowFull = 1;
                             continue;
                         }
-                        else if (cnt_w < 0)
+                        else if (cnt_w == WS_WANT_WRITE) {
+                            windowFull = 1;
+                            continue;
+                        }
+                        else if (cnt_w < 0) {
+                            kill(childPid, SIGINT);
                             break;
+                        }
                     }
                 }
             }
         }
 
-        if (ChildRunning && idle) {
-            idle = 0; /* waiting on child process */
+        if (!ChildRunning && peerConnected && stdoutEmpty && !windowFull) {
+            peerConnected = 0;
         }
     }
 
@@ -1868,7 +1901,7 @@ static void* HandleConnection(void* arg)
             #ifdef _WIN32
                 Sleep(1);
             #else
-                usleep(1);
+                usleep(100000);
             #endif
             }
 
@@ -1882,6 +1915,13 @@ static void* HandleConnection(void* arg)
     /* check if there is a response to the shutdown */
     wolfSSH_free(ssh);
     if (conn != NULL) {
+        byte sc[1024];
+        shutdown(conn->fd, 1);
+        /* Spin until socket closes. */
+        do {
+            ret = (int)recv(conn->fd, sc, 1024, 0);
+        } while (ret != 0);
+
         WCLOSESOCKET(conn->fd);
     }
     wolfSSH_Log(WS_LOG_INFO, "[SSHD] Return from closing connection = %d", ret);