Skip to content

Commit

Permalink
Merge pull request #239 from cachix/ping-pong
Browse files Browse the repository at this point in the history
Add ping-pong implementation that handles stale connections.
  • Loading branch information
domenkozar authored Dec 27, 2023
2 parents f238132 + 24d7f56 commit 680a5e0
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 59 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# CHANGELOG

- 0.13.0.0 (xxx)
* Introduce `Network.WebSockets.Connection.PingPong` to
handle ping pong for any Connection, be it Client or Server.
* Remove `serverRequirePong` option in favor of the new implementation.

- 0.12.7.3 (2021-10-26)
* Bump `attoparsec` dependency upper bound to 0.15

Expand Down
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@ server and client in Haskell.
## Features

- Provides Server/Client implementations of the websocket protocol
- Ping/Pong building blocks for stale connection checking
- `withPingPong` helper for stale connection checking
- TLS support via [wuss](https://hackage.haskell.org/package/wuss) package

## Caveats

- [Doesn't implement client ping/pong](https://github.com/jaspervdj/websockets/issues/159)
- [Send doesn't support streaming](https://github.com/jaspervdj/websockets/issues/119)
- [`send` doesn't support streaming](https://github.com/jaspervdj/websockets/issues/119)
- [Requires careful handling of exceptions](https://github.com/jaspervdj/websockets/issues/48)
- [DeflateCompression isn't thread-safe](https://github.com/jaspervdj/websockets/issues/208)

Expand Down
4 changes: 4 additions & 0 deletions src/Network/WebSockets.hs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ module Network.WebSockets
, newClientConnection

-- * Utilities
, PingPongOptions(..)
, defaultPingPongOptions
, withPingPong
, withPingThread
, forkPingThread
) where
Expand All @@ -91,6 +94,7 @@ module Network.WebSockets
--------------------------------------------------------------------------------
import Network.WebSockets.Client
import Network.WebSockets.Connection
import Network.WebSockets.Connection.PingPong
import Network.WebSockets.Http
import Network.WebSockets.Server
import Network.WebSockets.Types
3 changes: 3 additions & 0 deletions src/Network/WebSockets/Client.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ module Network.WebSockets.Client
--------------------------------------------------------------------------------
import qualified Data.ByteString.Builder as Builder
import Control.Exception (bracket, finally, throwIO)
import Control.Concurrent.MVar (newEmptyMVar)
import Control.Monad (void)
import Data.IORef (newIORef)
import qualified Data.Text as T
Expand Down Expand Up @@ -157,12 +158,14 @@ streamToClientConnection stream opts = do
(connectionMessageDataSizeLimit opts) stream
write <- encodeMessages protocol ClientConnection stream
sentRef <- newIORef False
heartbeat <- newEmptyMVar
return $ Connection
{ connectionOptions = opts
, connectionType = ClientConnection
, connectionProtocol = protocol
, connectionParse = parse
, connectionWrite = write
, connectionHeartbeat = heartbeat
, connectionSentClose = sentRef
}
where
Expand Down
15 changes: 12 additions & 3 deletions src/Network/WebSockets/Connection.hs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ import Control.Applicative ((<$>))
import Control.Concurrent (forkIO,
threadDelay)
import qualified Control.Concurrent.Async as Async
import Control.Concurrent.MVar (MVar, newEmptyMVar, tryPutMVar)
import Control.Exception (AsyncException,
fromException,
handle,
Expand Down Expand Up @@ -179,13 +180,15 @@ acceptRequestWith pc ar = case find (flip compatible request) protocols of
write <- foldM (\x ext -> extWrite ext x) writeRaw exts
parse <- foldM (\x ext -> extParse ext x) parseRaw exts

sentRef <- newIORef False
sentRef <- newIORef False
heartbeat <- newEmptyMVar
let connection = Connection
{ connectionOptions = options
, connectionType = ServerConnection
, connectionProtocol = protocol
, connectionParse = parse
, connectionWrite = write
, connectionHeartbeat = heartbeat
, connectionSentClose = sentRef
}

Expand Down Expand Up @@ -252,6 +255,9 @@ data Connection = Connection
{ connectionOptions :: !ConnectionOptions
, connectionType :: !ConnectionType
, connectionProtocol :: !Protocol
, connectionHeartbeat :: !(MVar ())
-- ^ This MVar is filled whenever a pong is received. This is used by
-- 'withPingPong' to timeout the connection if a pong is not received.
, connectionParse :: !(IO (Maybe Message))
, connectionWrite :: !([Message] -> IO ())
, connectionSentClose :: !(IORef Bool)
Expand Down Expand Up @@ -294,6 +300,7 @@ receiveDataMessage conn = do
unless hasSentClose $ send conn msg
throwIO $ CloseRequest i closeMsg
Pong _ -> do
_ <- tryPutMVar (connectionHeartbeat conn) ()
connectionOnPong (connectionOptions conn)
receiveDataMessage conn
Ping pl -> do
Expand Down Expand Up @@ -401,6 +408,9 @@ sendPong conn = send conn . ControlMessage . Pong . toLazyByteString
-- This is useful to keep idle connections open through proxies and whatnot.
-- Many (but not all) proxies have a 60 second default timeout, so based on that
-- sending a ping every 30 seconds is a good idea.
--
-- Note that usually you want to use 'Network.WebSockets.Connection.PingPong.withPingPong'
-- to timeout the connection if a pong is not received.
withPingThread
:: Connection
-> Int -- ^ Second interval in which pings should be sent.
Expand All @@ -410,7 +420,6 @@ withPingThread
withPingThread conn n action app =
Async.withAsync (pingThread conn n action) (\_ -> app)


--------------------------------------------------------------------------------
-- | DEPRECATED: Use 'withPingThread' instead.
--
Expand Down Expand Up @@ -445,4 +454,4 @@ pingThread conn n action

ignore e = case fromException e of
Just async -> throwIO (async :: AsyncException)
Nothing -> return ()
Nothing -> return ()
62 changes: 62 additions & 0 deletions src/Network/WebSockets/Connection/PingPong.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
module Network.WebSockets.Connection.PingPong
( withPingPong
, PingPongOptions(..)
, PongTimeout(..)
, defaultPingPongOptions
) where

import Control.Concurrent.Async as Async
import Control.Exception
import Control.Monad (void)
import Network.WebSockets.Connection (Connection, connectionHeartbeat, pingThread)
import Control.Concurrent.MVar (takeMVar)
import System.Timeout (timeout)


-- | Exception type used to kill connections if there
-- is a pong timeout.
data PongTimeout = PongTimeout deriving Show

instance Exception PongTimeout


-- | Options for ping-pong
--
-- Make sure that the ping interval is less than the pong timeout,
-- for example N/2.
data PingPongOptions = PingPongOptions {
pingInterval :: Int, -- ^ Interval in seconds
pongTimeout :: Int, -- ^ Timeout in seconds
pingAction :: IO () -- ^ Action to perform after sending a ping
}

-- | Default options for ping-pong
--
-- Ping every 15 seconds, timeout after 30 seconds
defaultPingPongOptions :: PingPongOptions
defaultPingPongOptions = PingPongOptions {
pingInterval = 15,
pongTimeout = 30,
pingAction = return ()
}

-- | Run an application with ping-pong enabled. Raises PongTimeout if a pong is not received.
--
-- Can used with Client and Server connections.
withPingPong :: PingPongOptions -> Connection -> (Connection -> IO ()) -> IO ()
withPingPong options connection app = void $
withAsync (app connection) $ \appAsync -> do
withAsync (pingThread connection (pingInterval options) (pingAction options)) $ \pingAsync -> do
withAsync (heartbeat >> throwIO PongTimeout) $ \heartbeatAsync -> do
waitAnyCancel [appAsync, pingAsync, heartbeatAsync]
where
heartbeat = whileJust $ timeout (pongTimeout options * 1000 * 1000)
$ takeMVar (connectionHeartbeat connection)

-- Loop until action returns Nothing
whileJust :: IO (Maybe a) -> IO ()
whileJust action = do
result <- action
case result of
Nothing -> return ()
Just _ -> whileJust action
62 changes: 9 additions & 53 deletions src/Network/WebSockets/Server.hs
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,17 @@ module Network.WebSockets.Server


--------------------------------------------------------------------------------
import Control.Concurrent (takeMVar, tryPutMVar,
newEmptyMVar)
import qualified Control.Concurrent.Async as Async
import Control.Exception (Exception, bracket,
import Control.Exception (bracket,
bracketOnError, finally, mask_,
throwIO)
import Network.Socket (Socket)
import qualified Network.Socket as S
import System.Timeout (timeout)


--------------------------------------------------------------------------------
import Network.WebSockets.Connection
import Network.WebSockets.Connection.PingPong (PongTimeout(..))
import Network.WebSockets.Http
import qualified Network.WebSockets.Stream as Stream
import Network.WebSockets.Types
Expand Down Expand Up @@ -83,10 +81,6 @@ data ServerOptions = ServerOptions
{ serverHost :: String
, serverPort :: Int
, serverConnectionOptions :: ConnectionOptions
-- | Require a pong from the client every N seconds; otherwise kill the
-- connection. If you use this, you should also use 'withPingThread' to
-- send a ping at a smaller interval; for example N/2.
, serverRequirePong :: Maybe Int
}


Expand All @@ -96,7 +90,6 @@ defaultServerOptions = ServerOptions
{ serverHost = "127.0.0.1"
, serverPort = 8080
, serverConnectionOptions = defaultConnectionOptions
, serverRequirePong = Nothing
}


Expand All @@ -109,43 +102,16 @@ runServerWithOptions :: ServerOptions -> ServerApp -> IO a
runServerWithOptions opts app = S.withSocketsDo $
bracket
(makeListenSocket (serverHost opts) (serverPort opts))
S.close $ \sock -> do
let connOpts = serverConnectionOptions opts

connThread conn = case serverRequirePong opts of
Nothing -> runApp conn connOpts app
Just grace -> do
heartbeat <- newEmptyMVar

let -- Update the connection options to perform a heartbeat
-- whenever a pong is received.
connOpts' = connOpts
{ connectionOnPong = do
_ <- tryPutMVar heartbeat ()
connectionOnPong connOpts
}

whileJust io = do
result <- io
case result of
Nothing -> return ()
Just _ -> whileJust io

-- Runs until a pong was not received within the grace
-- period.
heart = whileJust $ timeout (grace * 1000000) (takeMVar heartbeat)

Async.race_
(runApp conn connOpts' app)
(heart >> throwIO PongTimeout)

S.close
(\sock ->
let
mainThread = do
(conn, _) <- S.accept sock
Async.withAsyncWithUnmask
(\unmask -> unmask (connThread conn) `finally` S.close conn)
(\unmask -> unmask (runApp conn (serverConnectionOptions opts) app) `finally` S.close conn)
(\_ -> mainThread)

mask_ mainThread
in mask_ mainThread
)


--------------------------------------------------------------------------------
Expand Down Expand Up @@ -205,14 +171,4 @@ makePendingConnectionFromStream stream opts = do
, pendingRequest = request
, pendingOnAccept = \_ -> return ()
, pendingStream = stream
}


--------------------------------------------------------------------------------
-- | Internally used exception type used to kill connections if there
-- is a pong timeout.
data PongTimeout = PongTimeout deriving Show


--------------------------------------------------------------------------------
instance Exception PongTimeout
}
2 changes: 2 additions & 0 deletions websockets.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ Library
Network.WebSockets
Network.WebSockets.Client
Network.WebSockets.Connection
Network.WebSockets.Connection.PingPong
Network.WebSockets.Extensions
Network.WebSockets.Stream
-- Network.WebSockets.Util.PubSub TODO
Expand Down Expand Up @@ -108,6 +109,7 @@ Test-suite websockets-tests
Network.WebSockets.Client
Network.WebSockets.Connection
Network.WebSockets.Connection.Options
Network.WebSockets.Connection.PingPong
Network.WebSockets.Extensions
Network.WebSockets.Extensions.Description
Network.WebSockets.Extensions.PermessageDeflate
Expand Down

0 comments on commit 680a5e0

Please sign in to comment.