From 7c82e350afd9c6e3bfe1612f38564d6a3c3ce6ff Mon Sep 17 00:00:00 2001 From: Diogo Biazus Date: Wed, 21 Apr 2021 15:59:21 -0400 Subject: [PATCH 1/8] Remove src from Multiplexer type since with the supervisor we might need to replace the producer thread. Also that thread id was not in use. --- src/PostgresWebsockets/Broadcast.hs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/PostgresWebsockets/Broadcast.hs b/src/PostgresWebsockets/Broadcast.hs index e0db012..299593c 100644 --- a/src/PostgresWebsockets/Broadcast.hs +++ b/src/PostgresWebsockets/Broadcast.hs @@ -9,7 +9,7 @@ The multiplexer contains a map of channels and a producer thread. This module avoids any database implementation details, it is used by HasqlBroadcast where the database logic is combined. -} -module PostgresWebsockets.Broadcast ( Multiplexer (src) +module PostgresWebsockets.Broadcast ( Multiplexer , Message (..) , newMultiplexer , onMessage @@ -34,7 +34,6 @@ data Message = Message { channel :: ByteString } deriving (Eq, Show) data Multiplexer = Multiplexer { channels :: M.Map ByteString Channel - , src :: ThreadId , messages :: TQueue Message } @@ -64,8 +63,8 @@ newMultiplexer :: (TQueue Message -> IO a) -> IO Multiplexer newMultiplexer openProducer closeProducer = do msgs <- newTQueueIO - m <- liftA2 Multiplexer M.newIO (forkFinally (openProducer msgs) closeProducer) - return $ m msgs + void $ forkFinally (openProducer msgs) closeProducer + Multiplexer <$> M.newIO <*> pure msgs openChannel :: Multiplexer -> ByteString -> STM Channel openChannel multi chan = do From bda8e2b07860d932a8ae2523f6f327860b523509 Mon Sep 17 00:00:00 2001 From: Diogo Biazus Date: Tue, 11 May 2021 22:04:12 -0400 Subject: [PATCH 2/8] Add check interval to config and pass it to newHasqlBroadcasterForChannel so we can use that configuration to configure a supervisor in the multiplexer producer thread --- src/PostgresWebsockets/Config.hs | 2 ++ src/PostgresWebsockets/Context.hs | 2 +- src/PostgresWebsockets/HasqlBroadcast.hs | 12 ++++++------ 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/PostgresWebsockets/Config.hs b/src/PostgresWebsockets/Config.hs index 36704e8..7731b91 100644 --- a/src/PostgresWebsockets/Config.hs +++ b/src/PostgresWebsockets/Config.hs @@ -37,6 +37,7 @@ data AppConfig = AppConfig { , configJwtSecretIsBase64 :: Bool , configPool :: Int , configRetries :: Int + , configReconnectInterval :: Int } -- | User friendly version number @@ -75,6 +76,7 @@ readOptions = <*> var auto "PGWS_JWT_SECRET_BASE64" (def False <> helpDef show <> help "Indicate whether the JWT secret should be decoded from a base64 encoded string") <*> var auto "PGWS_POOL_SIZE" (def 10 <> helpDef show <> help "How many connection to the database should be used by the connection pool") <*> var auto "PGWS_RETRIES" (def 5 <> helpDef show <> help "How many times it should try to connect to the database on startup before exiting with an error") + <*> var auto "PGWS_CHECK_LISTENER_INTERVAL" (def 0 <> helpDef show <> help "Interval for supervisor thread to check if listener connection is alive. 0 to disable it.") loadDatabaseURIFile :: AppConfig -> IO AppConfig loadDatabaseURIFile conf@AppConfig{..} = diff --git a/src/PostgresWebsockets/Context.hs b/src/PostgresWebsockets/Context.hs index bda824d..48a83ef 100644 --- a/src/PostgresWebsockets/Context.hs +++ b/src/PostgresWebsockets/Context.hs @@ -32,7 +32,7 @@ mkContext :: AppConfig -> IO () -> IO Context mkContext conf@AppConfig{..} shutdown = do Context conf <$> P.acquire (configPool, 10, pgSettings) - <*> newHasqlBroadcaster shutdown (toS configListenChannel) configRetries pgSettings + <*> newHasqlBroadcaster shutdown (toS configListenChannel) configRetries configReconnectInterval pgSettings <*> mkGetTime where mkGetTime :: IO (IO UTCTime) diff --git a/src/PostgresWebsockets/HasqlBroadcast.hs b/src/PostgresWebsockets/HasqlBroadcast.hs index 116e5a8..f08b19e 100644 --- a/src/PostgresWebsockets/HasqlBroadcast.hs +++ b/src/PostgresWebsockets/HasqlBroadcast.hs @@ -32,10 +32,10 @@ import PostgresWebsockets.Broadcast {- | Returns a multiplexer from a connection URI, keeps trying to connect in case there is any error. This function also spawns a thread that keeps relaying the messages from the database to the multiplexer's listeners -} -newHasqlBroadcaster :: IO () -> Text -> Int -> ByteString -> IO Multiplexer -newHasqlBroadcaster onConnectionFailure ch maxRetries = newHasqlBroadcasterForConnection . tryUntilConnected maxRetries +newHasqlBroadcaster :: IO () -> Text -> Int -> Int -> ByteString -> IO Multiplexer +newHasqlBroadcaster onConnectionFailure ch maxRetries checkInterval = newHasqlBroadcasterForConnection . tryUntilConnected maxRetries where - newHasqlBroadcasterForConnection = newHasqlBroadcasterForChannel onConnectionFailure ch + newHasqlBroadcasterForConnection = newHasqlBroadcasterForChannel onConnectionFailure ch checkInterval {- | Returns a multiplexer from a connection URI or an error message on the left case This function also spawns a thread that keeps relaying the messages from the database to the multiplexer's listeners @@ -44,7 +44,7 @@ newHasqlBroadcasterOrError :: IO () -> Text -> ByteString -> IO (Either ByteStri newHasqlBroadcasterOrError onConnectionFailure ch = acquire >=> (sequence . mapBoth (toSL . show) (newHasqlBroadcasterForConnection . return)) where - newHasqlBroadcasterForConnection = newHasqlBroadcasterForChannel onConnectionFailure ch + newHasqlBroadcasterForConnection = newHasqlBroadcasterForChannel onConnectionFailure ch 0 tryUntilConnected :: Int -> ByteString -> IO Connection tryUntilConnected maxRetries = @@ -83,8 +83,8 @@ tryUntilConnected maxRetries = forever $ fmap print (atomically $ readTChan ch) @ -} -newHasqlBroadcasterForChannel :: IO () -> Text -> IO Connection -> IO Multiplexer -newHasqlBroadcasterForChannel onConnectionFailure ch getCon = do +newHasqlBroadcasterForChannel :: IO () -> Text -> Int -> IO Connection -> IO Multiplexer +newHasqlBroadcasterForChannel onConnectionFailure ch checkInterval getCon = do multi <- newMultiplexer openProducer $ const onConnectionFailure void $ relayMessagesForever multi return multi From ec937645c361ecce35249a21c9a29c2aa7c7628a Mon Sep 17 00:00:00 2001 From: Diogo Biazus Date: Sat, 15 May 2021 10:00:30 -0400 Subject: [PATCH 3/8] Use Text to represent both channel and payload in message since both will be eventually represented as JSON id does not make much sense to keep them as ByteString around. --- src/PostgresWebsockets/Broadcast.hs | 40 ++++++++++++++++-------- src/PostgresWebsockets/Claims.hs | 12 +++---- src/PostgresWebsockets/HasqlBroadcast.hs | 10 +++--- src/PostgresWebsockets/Middleware.hs | 13 ++++---- test/ClaimsSpec.hs | 10 +++--- test/ServerSpec.hs | 1 + 6 files changed, 51 insertions(+), 35 deletions(-) diff --git a/src/PostgresWebsockets/Broadcast.hs b/src/PostgresWebsockets/Broadcast.hs index 299593c..bab2af9 100644 --- a/src/PostgresWebsockets/Broadcast.hs +++ b/src/PostgresWebsockets/Broadcast.hs @@ -15,6 +15,7 @@ module PostgresWebsockets.Broadcast ( Multiplexer , onMessage , relayMessages , relayMessagesForever + , superviseMultiplexer -- * Re-exports , readTQueue , writeTQueue @@ -27,19 +28,15 @@ import qualified StmContainers.Map as M import Control.Concurrent.STM.TChan import Control.Concurrent.STM.TQueue -import GHC.Show - -data Message = Message { channel :: ByteString - , payload :: ByteString +data Message = Message { channel :: Text + , payload :: Text } deriving (Eq, Show) -data Multiplexer = Multiplexer { channels :: M.Map ByteString Channel +data Multiplexer = Multiplexer { channels :: M.Map Text Channel , messages :: TQueue Message + , producerThreadId :: MVar ThreadId + , reopenProducer :: IO ThreadId } - -instance Show Multiplexer where - show Multiplexer{} = "Multiplexer" - data Channel = Channel { broadcast :: TChan Message , listeners :: Integer } @@ -63,10 +60,27 @@ newMultiplexer :: (TQueue Message -> IO a) -> IO Multiplexer newMultiplexer openProducer closeProducer = do msgs <- newTQueueIO - void $ forkFinally (openProducer msgs) closeProducer - Multiplexer <$> M.newIO <*> pure msgs + let forkNewProducer = forkFinally (openProducer msgs) closeProducer + tid <- forkNewProducer + multiplexerMap <- M.newIO + producerThreadId <- newMVar tid + pure $ Multiplexer multiplexerMap msgs producerThreadId forkNewProducer + +{- | Given a multiplexer, a number of milliseconds and an IO computation that returns a boolean + Runs the IO computation at every interval of milliseconds interval and reopens the multiplexer producer + if the resulting boolean is true + Call this in case you want to ensure the producer thread is killed and restarted under a certain condition +-} +superviseMultiplexer :: Multiplexer -> Int -> IO Bool -> IO () +superviseMultiplexer multi msInterval shouldRestart = do + void $ forkIO $ forever $ do + threadDelay msInterval + sr <- shouldRestart + when sr $ do + void $ killThread <$> readMVar (producerThreadId multi) + void $ swapMVar (producerThreadId multi) <$> reopenProducer multi -openChannel :: Multiplexer -> ByteString -> STM Channel +openChannel :: Multiplexer -> Text -> STM Channel openChannel multi chan = do c <- newBroadcastTChan let newChannel = Channel{ broadcast = c @@ -81,7 +95,7 @@ openChannel multi chan = do The first listener will open the channel, when a listener dies it will check if there acquire any others and close the channel when that's the case. -} -onMessage :: Multiplexer -> ByteString -> (Message -> IO()) -> IO () +onMessage :: Multiplexer -> Text -> (Message -> IO()) -> IO () onMessage multi chan action = do listener <- atomically $ openChannelWhenNotFound >>= addListener void $ forkFinally (forever (atomically (readTChan listener) >>= action)) disposeListener diff --git a/src/PostgresWebsockets/Claims.hs b/src/PostgresWebsockets/Claims.hs index 9b6b4cf..4314e19 100644 --- a/src/PostgresWebsockets/Claims.hs +++ b/src/PostgresWebsockets/Claims.hs @@ -23,13 +23,13 @@ import qualified Data.Aeson as JSON type Claims = M.HashMap Text JSON.Value -type ConnectionInfo = ([ByteString], ByteString, Claims) +type ConnectionInfo = ([Text], Text, Claims) {-| Given a secret, a token and a timestamp it validates the claims and returns either an error message or a triple containing channel, mode and claims hashmap. -} validateClaims - :: Maybe ByteString + :: Maybe Text -> ByteString -> LByteString -> UTCTime @@ -58,16 +58,16 @@ validateClaims requestChannel secret jwtToken time = runExceptT $ do pure (validChannels, mode, cl') where - claimAsJSON :: Text -> Claims -> Maybe ByteString + claimAsJSON :: Text -> Claims -> Maybe Text claimAsJSON name cl = case M.lookup name cl of - Just (JSON.String s) -> Just $ encodeUtf8 s + Just (JSON.String s) -> Just s _ -> Nothing - claimAsJSONList :: Text -> Claims -> Maybe [ByteString] + claimAsJSONList :: Text -> Claims -> Maybe [Text] claimAsJSONList name cl = case M.lookup name cl of Just channelsJson -> case JSON.fromJSON channelsJson :: JSON.Result [Text] of - JSON.Success channelsList -> Just $ encodeUtf8 <$> channelsList + JSON.Success channelsList -> Just channelsList _ -> Nothing Nothing -> Nothing diff --git a/src/PostgresWebsockets/HasqlBroadcast.hs b/src/PostgresWebsockets/HasqlBroadcast.hs index f08b19e..b6e476e 100644 --- a/src/PostgresWebsockets/HasqlBroadcast.hs +++ b/src/PostgresWebsockets/HasqlBroadcast.hs @@ -89,23 +89,23 @@ newHasqlBroadcasterForChannel onConnectionFailure ch checkInterval getCon = do void $ relayMessagesForever multi return multi where - toMsg :: ByteString -> ByteString -> Message + toMsg :: Text -> Text -> Message toMsg c m = case decode (toS m) of Just v -> Message (channelDef c v) m Nothing -> Message c m - lookupStringDef :: Text -> ByteString -> Value -> ByteString + lookupStringDef :: Text -> Text -> Value -> Text lookupStringDef key d (Object obj) = case lookupDefault (String $ toS d) key obj of String s -> toS s - _ -> d - lookupStringDef _ d _ = d + _ -> toS d + lookupStringDef _ d _ = toS d channelDef = lookupStringDef "channel" openProducer msgQ = do con <- getCon listen con $ toPgIdentifier ch waitForNotifications - (\c m-> atomically $ writeTQueue msgQ $ toMsg c m) + (\c m-> atomically $ writeTQueue msgQ $ toMsg (toS c) (toS m)) con putErrLn :: Text -> IO () diff --git a/src/PostgresWebsockets/Middleware.hs b/src/PostgresWebsockets/Middleware.hs index 5da67f1..01e988e 100644 --- a/src/PostgresWebsockets/Middleware.hs +++ b/src/PostgresWebsockets/Middleware.hs @@ -23,6 +23,7 @@ import qualified Network.WebSockets as WS import qualified Data.Aeson as A import qualified Data.ByteString.Char8 as BS +import qualified Data.Text as T import qualified Data.ByteString.Lazy as BL import qualified Data.HashMap.Strict as M @@ -63,8 +64,8 @@ wsApp :: Context -> WS.ServerApp wsApp Context{..} pendingConn = ctxGetTime >>= validateClaims requestChannel (configJwtSecret ctxConfig) (toS jwtToken) >>= either rejectRequest forkSessions where - hasRead m = m == ("r" :: ByteString) || m == ("rw" :: ByteString) - hasWrite m = m == ("w" :: ByteString) || m == ("rw" :: ByteString) + hasRead m = m == ("r" :: Text) || m == ("rw" :: Text) + hasWrite m = m == ("w" :: Text) || m == ("rw" :: Text) rejectRequest :: Text -> IO () rejectRequest msg = do @@ -72,8 +73,8 @@ wsApp Context{..} pendingConn = WS.rejectRequest pendingConn (toS msg) -- the URI has one of the two formats - /:jwt or /:channel/:jwt - pathElements = BS.split '/' $ BS.drop 1 $ WS.requestPath $ WS.pendingRequest pendingConn - jwtToken = + pathElements = T.split (== '/') $ T.drop 1 $ (toSL . WS.requestPath) $ WS.pendingRequest pendingConn + jwtToken = case length pathElements `compare` 1 of GT -> headDef "" $ tailSafe pathElements _ -> headDef "" pathElements @@ -102,7 +103,7 @@ wsApp Context{..} pendingConn = case configMetaChannel ctxConfig of Nothing -> pure () - Just ch -> sendMessageWithTimestamp $ connectionOpenMessage (toS $ BS.intercalate "," chs) ch + Just ch -> sendMessageWithTimestamp $ connectionOpenMessage (toS $ T.intercalate "," chs) ch when (hasRead mode) $ forM_ chs $ flip (onMessage ctxMulti) $ WS.sendTextData conn . B.payload @@ -116,7 +117,7 @@ wsApp Context{..} pendingConn = -- Having both channel and claims as parameters seem redundant -- But it allows the function to ignore the claims structure and the source -- of the channel, so all claims decoding can be coded in the caller -notifySession :: WS.Connection -> (Text -> Text -> IO ()) -> [ByteString] -> IO () +notifySession :: WS.Connection -> (Text -> Text -> IO ()) -> [Text] -> IO () notifySession wsCon sendToChannel chs = withAsync (forever relayData) wait where diff --git a/test/ClaimsSpec.hs b/test/ClaimsSpec.hs index a112fcd..be5de5f 100644 --- a/test/ClaimsSpec.hs +++ b/test/ClaimsSpec.hs @@ -21,12 +21,12 @@ spec = `shouldReturn` Left "Token expired" it "request any channel from a token that does not have channels or channel claims should succeed" $ do time <- getCurrentTime - validateClaims (Just (encodeUtf8 "test")) secret + validateClaims (Just "test") secret "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciJ9.jL5SsRFegNUlbBm8_okhHSujqLcKKZdDglfdqNl1_rY" time `shouldReturn` Right (["test"], "r", M.fromList[("mode",String "r")]) it "requesting a channel that is set by and old style channel claim should work" $ do time <- getCurrentTime - validateClaims (Just (encodeUtf8 "test")) secret + validateClaims (Just "test") secret "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWwiOiJ0ZXN0In0.1d4s-at2kWj8OSabHZHTbNh1dENF7NWy_r0ED3Rwf58" time `shouldReturn` Right (["test"], "r", M.fromList[("mode",String "r"),("channel",String "test")]) it "no requesting channel should return all channels in the token" $ do @@ -37,18 +37,18 @@ spec = it "requesting a channel from the channels claim shoud return only the requested channel" $ do time <- getCurrentTime - validateClaims (Just (encodeUtf8 "test")) secret + validateClaims (Just "test") secret "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWxzIjpbInRlc3QiLCJ0ZXN0MiJdfQ.MumdJ5FpFX4Z6SJD3qsygVF0r9vqxfqhj5J30O32N0k" time `shouldReturn` Right (["test"], "r", M.fromList[("mode",String "r"),("channels", toJSON ["test"::Text, "test2"] )]) it "requesting a channel not from the channels claim shoud error" $ do time <- getCurrentTime - validateClaims (Just (encodeUtf8 "notAllowed")) secret + validateClaims (Just "notAllowed") secret "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWxzIjpbInRlc3QiLCJ0ZXN0MiJdfQ.MumdJ5FpFX4Z6SJD3qsygVF0r9vqxfqhj5J30O32N0k" time `shouldReturn` Left "No allowed channels" it "requesting a channel with no mode fails" $ do time <- getCurrentTime - validateClaims (Just (encodeUtf8 "test")) secret + validateClaims (Just "test") secret "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJjaGFubmVscyI6WyJ0ZXN0IiwidGVzdDIiXX0.akC1PEYk2DEZtLP2XjC6qXOGZJejmPx49qv-VeEtQYQ" time `shouldReturn` Left "Missing mode" diff --git a/test/ServerSpec.hs b/test/ServerSpec.hs index 920dba5..5b2fa04 100644 --- a/test/ServerSpec.hs +++ b/test/ServerSpec.hs @@ -24,6 +24,7 @@ testServerConfig = AppConfig , configJwtSecretIsBase64 = False , configPool = 10 , configRetries = 5 + , configReconnectInterval = 0 } startTestServer :: IO ThreadId From 92ca0acf5e926edb1dccca28654669061cdfdbab Mon Sep 17 00:00:00 2001 From: Diogo Biazus Date: Sat, 15 May 2021 10:33:53 -0400 Subject: [PATCH 4/8] Reformat module using haskell language server formatter --- src/PostgresWebsockets/Broadcast.hs | 146 +++++++++++++++------------- 1 file changed, 78 insertions(+), 68 deletions(-) diff --git a/src/PostgresWebsockets/Broadcast.hs b/src/PostgresWebsockets/Broadcast.hs index bab2af9..d91bba4 100644 --- a/src/PostgresWebsockets/Broadcast.hs +++ b/src/PostgresWebsockets/Broadcast.hs @@ -1,49 +1,56 @@ -{-| -Module : PostgresWebsockets.Broadcast -Description : Distribute messages from one producer to several consumers. +-- | +-- Module : PostgresWebsockets.Broadcast +-- Description : Distribute messages from one producer to several consumers. +-- +-- PostgresWebsockets functions to broadcast messages to several listening clients +-- This module provides a type called Multiplexer. +-- The multiplexer contains a map of channels and a producer thread. +-- +-- This module avoids any database implementation details, it is used by HasqlBroadcast where +-- the database logic is combined. +module PostgresWebsockets.Broadcast + ( Multiplexer, + Message (..), + newMultiplexer, + onMessage, + relayMessages, + relayMessagesForever, + superviseMultiplexer, -PostgresWebsockets functions to broadcast messages to several listening clients -This module provides a type called Multiplexer. -The multiplexer contains a map of channels and a producer thread. - -This module avoids any database implementation details, it is used by HasqlBroadcast where -the database logic is combined. --} -module PostgresWebsockets.Broadcast ( Multiplexer - , Message (..) - , newMultiplexer - , onMessage - , relayMessages - , relayMessagesForever - , superviseMultiplexer - -- * Re-exports - , readTQueue - , writeTQueue - , readTChan - ) where + -- * Re-exports + readTQueue, + writeTQueue, + readTChan, + ) +where +import Control.Concurrent.STM.TChan +import Control.Concurrent.STM.TQueue import Protolude hiding (toS) import Protolude.Conv import qualified StmContainers.Map as M -import Control.Concurrent.STM.TChan -import Control.Concurrent.STM.TQueue -data Message = Message { channel :: Text - , payload :: Text - } deriving (Eq, Show) +data Message = Message + { channel :: Text, + payload :: Text + } + deriving (Eq, Show) + +data Multiplexer = Multiplexer + { channels :: M.Map Text Channel, + messages :: TQueue Message, + producerThreadId :: MVar ThreadId, + reopenProducer :: IO ThreadId + } -data Multiplexer = Multiplexer { channels :: M.Map Text Channel - , messages :: TQueue Message - , producerThreadId :: MVar ThreadId - , reopenProducer :: IO ThreadId - } -data Channel = Channel { broadcast :: TChan Message - , listeners :: Integer - } +data Channel = Channel + { broadcast :: TChan Message, + listeners :: Integer + } -- | Opens a thread that relays messages from the producer thread to the channels forever relayMessagesForever :: Multiplexer -> IO ThreadId -relayMessagesForever = forkIO . forever . relayMessages +relayMessagesForever = forkIO . forever . relayMessages -- | Reads the messages from the producer and relays them to the active listeners in their respective channels. relayMessages :: Multiplexer -> IO () @@ -55,9 +62,10 @@ relayMessages multi = Nothing -> return () Just c -> writeTChan (broadcast c) m -newMultiplexer :: (TQueue Message -> IO a) - -> (Either SomeException a -> IO ()) - -> IO Multiplexer +newMultiplexer :: + (TQueue Message -> IO a) -> + (Either SomeException a -> IO ()) -> + IO Multiplexer newMultiplexer openProducer closeProducer = do msgs <- newTQueueIO let forkNewProducer = forkFinally (openProducer msgs) closeProducer @@ -66,36 +74,38 @@ newMultiplexer openProducer closeProducer = do producerThreadId <- newMVar tid pure $ Multiplexer multiplexerMap msgs producerThreadId forkNewProducer -{- | Given a multiplexer, a number of milliseconds and an IO computation that returns a boolean - Runs the IO computation at every interval of milliseconds interval and reopens the multiplexer producer - if the resulting boolean is true - Call this in case you want to ensure the producer thread is killed and restarted under a certain condition --} +-- | Given a multiplexer, a number of milliseconds and an IO computation that returns a boolean +-- Runs the IO computation at every interval of milliseconds interval and reopens the multiplexer producer +-- if the resulting boolean is true +-- Call this in case you want to ensure the producer thread is killed and restarted under a certain condition superviseMultiplexer :: Multiplexer -> Int -> IO Bool -> IO () superviseMultiplexer multi msInterval shouldRestart = do - void $ forkIO $ forever $ do - threadDelay msInterval - sr <- shouldRestart - when sr $ do - void $ killThread <$> readMVar (producerThreadId multi) - void $ swapMVar (producerThreadId multi) <$> reopenProducer multi + void $ + forkIO $ + forever $ do + threadDelay msInterval + sr <- shouldRestart + when sr $ do + void $ killThread <$> readMVar (producerThreadId multi) + void $ swapMVar (producerThreadId multi) <$> reopenProducer multi -openChannel :: Multiplexer -> Text -> STM Channel +openChannel :: Multiplexer -> Text -> STM Channel openChannel multi chan = do - c <- newBroadcastTChan - let newChannel = Channel{ broadcast = c - , listeners = 0 - } - M.insert newChannel chan (channels multi) - return newChannel + c <- newBroadcastTChan + let newChannel = + Channel + { broadcast = c, + listeners = 0 + } + M.insert newChannel chan (channels multi) + return newChannel -{- | Adds a listener to a certain multiplexer's channel. - The listener must be a function that takes a 'TChan Message' and perform any IO action. - All listeners run in their own thread. - The first listener will open the channel, when a listener dies it will check if there acquire - any others and close the channel when that's the case. --} -onMessage :: Multiplexer -> Text -> (Message -> IO()) -> IO () +-- | Adds a listener to a certain multiplexer's channel. +-- The listener must be a function that takes a 'TChan Message' and perform any IO action. +-- All listeners run in their own thread. +-- The first listener will open the channel, when a listener dies it will check if there acquire +-- any others and close the channel when that's the case. +onMessage :: Multiplexer -> Text -> (Message -> IO ()) -> IO () onMessage multi chan action = do listener <- atomically $ openChannelWhenNotFound >>= addListener void $ forkFinally (forever (atomically (readTChan listener) >>= action)) disposeListener @@ -105,13 +115,13 @@ onMessage multi chan action = do let c = fromMaybe (panic $ "trying to remove listener from non existing channel: " <> toS chan) mC M.delete chan (channels multi) when (listeners c - 1 > 0) $ - M.insert Channel{ broadcast = broadcast c, listeners = listeners c - 1 } chan (channels multi) + M.insert Channel {broadcast = broadcast c, listeners = listeners c - 1} chan (channels multi) openChannelWhenNotFound = M.lookup chan (channels multi) >>= \case - Nothing -> openChannel multi chan - Just ch -> return ch + Nothing -> openChannel multi chan + Just ch -> return ch addListener ch = do M.delete chan (channels multi) - let newChannel = Channel{ broadcast = broadcast ch, listeners = listeners ch + 1} + let newChannel = Channel {broadcast = broadcast ch, listeners = listeners ch + 1} M.insert newChannel chan (channels multi) dupTChan $ broadcast newChannel From 1b7f8ac865c80e82420515f332df66d5d77bf5a6 Mon Sep 17 00:00:00 2001 From: Diogo Biazus Date: Mon, 17 May 2021 10:52:55 -0400 Subject: [PATCH 5/8] WIP creating type to print multiplexer state --- src/PostgresWebsockets/Broadcast.hs | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/PostgresWebsockets/Broadcast.hs b/src/PostgresWebsockets/Broadcast.hs index d91bba4..7328cb4 100644 --- a/src/PostgresWebsockets/Broadcast.hs +++ b/src/PostgresWebsockets/Broadcast.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE DeriveGeneric #-} + -- | -- Module : PostgresWebsockets.Broadcast -- Description : Distribute messages from one producer to several consumers. @@ -26,8 +28,9 @@ where import Control.Concurrent.STM.TChan import Control.Concurrent.STM.TQueue +import qualified Data.Aeson as A import Protolude hiding (toS) -import Protolude.Conv +import Protolude.Conv (toS) import qualified StmContainers.Map as M data Message = Message @@ -43,11 +46,20 @@ data Multiplexer = Multiplexer reopenProducer :: IO ThreadId } +data MultiplexerSnapshot = MultiplexerSnapshot + { channelsSize :: Int, + messageQueueEmpty :: Bool, + producerId :: Text + } + deriving (Generic) + data Channel = Channel { broadcast :: TChan Message, listeners :: Integer } +instance A.ToJSON MultiplexerSnapshot + -- | Opens a thread that relays messages from the producer thread to the channels forever relayMessagesForever :: Multiplexer -> IO ThreadId relayMessagesForever = forkIO . forever . relayMessages From 5e333df2e69baae4f006cf1bbe86e9245115a89b Mon Sep 17 00:00:00 2001 From: Diogo Biazus Date: Thu, 27 May 2021 18:02:37 -0400 Subject: [PATCH 6/8] Respawn listener connection on error or shutdown the whole server depending on the value of PGWS_CHECK_LISTENER_INTERVAL --- sample-env | 4 + src/PostgresWebsockets/Broadcast.hs | 22 +++- src/PostgresWebsockets/Config.hs | 137 +++++++++++----------- src/PostgresWebsockets/Context.hs | 49 ++++---- src/PostgresWebsockets/HasqlBroadcast.hs | 138 +++++++++++++---------- 5 files changed, 199 insertions(+), 151 deletions(-) diff --git a/sample-env b/sample-env index eab6f47..51d28f5 100644 --- a/sample-env +++ b/sample-env @@ -21,3 +21,7 @@ export PGWS_PORT=3000 ## (use "@filename" to load from separate file) export PGWS_JWT_SECRET="auwhfdnskjhewfi34uwehdlaehsfkuaeiskjnfduierhfsiweskjcnzeiluwhskdewishdnpwe" export PGWS_JWT_SECRET_BASE64=False + +## Check database listener every 10 seconds +## comment it out to disable and shutdown the server on listener errors (can be useful when using external process supervisors) +export PGWS_CHECK_LISTENER_INTERVAL=10000 diff --git a/src/PostgresWebsockets/Broadcast.hs b/src/PostgresWebsockets/Broadcast.hs index 7328cb4..8fac2db 100644 --- a/src/PostgresWebsockets/Broadcast.hs +++ b/src/PostgresWebsockets/Broadcast.hs @@ -60,6 +60,15 @@ data Channel = Channel instance A.ToJSON MultiplexerSnapshot +-- | Given a multiplexer derive a type that can be printed for debugging or logging purposes +takeSnapshot :: Multiplexer -> IO MultiplexerSnapshot +takeSnapshot multi = + MultiplexerSnapshot <$> size <*> e <*> thread + where + size = atomically $ M.size $ channels multi + thread = show <$> readMVar (producerThreadId multi) + e = atomically $ isEmptyTQueue $ messages multi + -- | Opens a thread that relays messages from the producer thread to the channels forever relayMessagesForever :: Multiplexer -> IO ThreadId relayMessagesForever = forkIO . forever . relayMessages @@ -89,17 +98,26 @@ newMultiplexer openProducer closeProducer = do -- | Given a multiplexer, a number of milliseconds and an IO computation that returns a boolean -- Runs the IO computation at every interval of milliseconds interval and reopens the multiplexer producer -- if the resulting boolean is true +-- When interval is 0 this is NOOP, so the minimum interval is 1ms -- Call this in case you want to ensure the producer thread is killed and restarted under a certain condition superviseMultiplexer :: Multiplexer -> Int -> IO Bool -> IO () superviseMultiplexer multi msInterval shouldRestart = do void $ forkIO $ forever $ do - threadDelay msInterval + threadDelay $ msInterval * 1000 sr <- shouldRestart when sr $ do + snapBefore <- takeSnapshot multi void $ killThread <$> readMVar (producerThreadId multi) - void $ swapMVar (producerThreadId multi) <$> reopenProducer multi + new <- reopenProducer multi + void $ swapMVar (producerThreadId multi) new + snapAfter <- takeSnapshot multi + putStrLn $ + "Restarting producer. Multiplexer updated: " + <> A.encode snapBefore + <> " -> " + <> A.encode snapAfter openChannel :: Multiplexer -> Text -> STM Channel openChannel multi chan = do diff --git a/src/PostgresWebsockets/Config.hs b/src/PostgresWebsockets/Config.hs index 7731b91..8dc845a 100644 --- a/src/PostgresWebsockets/Config.hs +++ b/src/PostgresWebsockets/Config.hs @@ -1,43 +1,43 @@ -{-| -Module : PostgresWebsockets.Config -Description : Manages PostgresWebsockets configuration options. - -This module provides a helper function to read the command line -arguments using the AppConfig type to store -them. It also can be used to define other middleware configuration that -may be delegated to some sort of external configuration. --} -module PostgresWebsockets.Config - ( prettyVersion - , loadConfig - , warpSettings - , AppConfig (..) - ) where - -import Env -import Data.Text (intercalate, pack, replace, strip, stripPrefix) -import Data.Version (versionBranch) -import Paths_postgres_websockets (version) -import Protolude hiding (intercalate, (<>), optional, replace, toS) -import Protolude.Conv -import Data.String (IsString(..)) -import Network.Wai.Handler.Warp -import qualified Data.ByteString as BS -import qualified Data.ByteString.Base64 as B64 +-- | +-- Module : PostgresWebsockets.Config +-- Description : Manages PostgresWebsockets configuration options. +-- +-- This module provides a helper function to read the command line +-- arguments using the AppConfig type to store +-- them. It also can be used to define other middleware configuration that +-- may be delegated to some sort of external configuration. +module PostgresWebsockets.Config + ( prettyVersion, + loadConfig, + warpSettings, + AppConfig (..), + ) +where + +import qualified Data.ByteString as BS +import qualified Data.ByteString.Base64 as B64 +import Data.String (IsString (..)) +import Data.Text (intercalate, pack, replace, strip, stripPrefix) +import Data.Version (versionBranch) +import Env +import Network.Wai.Handler.Warp +import Paths_postgres_websockets (version) +import Protolude hiding (intercalate, optional, replace, toS, (<>)) +import Protolude.Conv -- | Config file settings for the server -data AppConfig = AppConfig { - configDatabase :: Text - , configPath :: Maybe Text - , configHost :: Text - , configPort :: Int - , configListenChannel :: Text - , configMetaChannel :: Maybe Text - , configJwtSecret :: ByteString - , configJwtSecretIsBase64 :: Bool - , configPool :: Int - , configRetries :: Int - , configReconnectInterval :: Int +data AppConfig = AppConfig + { configDatabase :: Text, + configPath :: Maybe Text, + configHost :: Text, + configPort :: Int, + configListenChannel :: Text, + configMetaChannel :: Maybe Text, + configJwtSecret :: ByteString, + configJwtSecretIsBase64 :: Bool, + configPool :: Int, + configRetries :: Int, + configReconnectInterval :: Maybe Int } -- | User friendly version number @@ -50,38 +50,37 @@ loadConfig = readOptions >>= loadSecretFile >>= loadDatabaseURIFile -- | Given a shutdown handler and an AppConfig builds a Warp Settings to start a stand-alone server warpSettings :: (IO () -> IO ()) -> AppConfig -> Settings -warpSettings waitForShutdown AppConfig{..} = - setHost (fromString $ toS configHost) - . setPort configPort - . setServerName (toS $ "postgres-websockets/" <> prettyVersion) - . setTimeout 3600 - . setInstallShutdownHandler waitForShutdown - . setGracefulShutdownTimeout (Just 5) - $ defaultSettings - +warpSettings waitForShutdown AppConfig {..} = + setHost (fromString $ toS configHost) + . setPort configPort + . setServerName (toS $ "postgres-websockets/" <> prettyVersion) + . setTimeout 3600 + . setInstallShutdownHandler waitForShutdown + . setGracefulShutdownTimeout (Just 5) + $ defaultSettings -- private -- | Function to read and parse options from the environment readOptions :: IO AppConfig readOptions = - Env.parse (header "You need to configure some environment variables to start the service.") $ - AppConfig <$> var (str <=< nonempty) "PGWS_DB_URI" (help "String to connect to PostgreSQL") - <*> optional (var str "PGWS_ROOT_PATH" (help "Root path to serve static files, unset to disable.")) - <*> var str "PGWS_HOST" (def "*4" <> helpDef show <> help "Address the server will listen for websocket connections") - <*> var auto "PGWS_PORT" (def 3000 <> helpDef show <> help "Port the server will listen for websocket connections") - <*> var str "PGWS_LISTEN_CHANNEL" (def "postgres-websockets-listener" <> helpDef show <> help "Master channel used in the database to send or read messages in any notification channel") - <*> optional (var str "PGWS_META_CHANNEL" (help "Websockets channel used to send events about the server state changes.")) - <*> var str "PGWS_JWT_SECRET" (help "Secret used to sign JWT tokens used to open communications channels") - <*> var auto "PGWS_JWT_SECRET_BASE64" (def False <> helpDef show <> help "Indicate whether the JWT secret should be decoded from a base64 encoded string") - <*> var auto "PGWS_POOL_SIZE" (def 10 <> helpDef show <> help "How many connection to the database should be used by the connection pool") - <*> var auto "PGWS_RETRIES" (def 5 <> helpDef show <> help "How many times it should try to connect to the database on startup before exiting with an error") - <*> var auto "PGWS_CHECK_LISTENER_INTERVAL" (def 0 <> helpDef show <> help "Interval for supervisor thread to check if listener connection is alive. 0 to disable it.") + Env.parse (header "You need to configure some environment variables to start the service.") $ + AppConfig <$> var (str <=< nonempty) "PGWS_DB_URI" (help "String to connect to PostgreSQL") + <*> optional (var str "PGWS_ROOT_PATH" (help "Root path to serve static files, unset to disable.")) + <*> var str "PGWS_HOST" (def "*4" <> helpDef show <> help "Address the server will listen for websocket connections") + <*> var auto "PGWS_PORT" (def 3000 <> helpDef show <> help "Port the server will listen for websocket connections") + <*> var str "PGWS_LISTEN_CHANNEL" (def "postgres-websockets-listener" <> helpDef show <> help "Master channel used in the database to send or read messages in any notification channel") + <*> optional (var str "PGWS_META_CHANNEL" (help "Websockets channel used to send events about the server state changes.")) + <*> var str "PGWS_JWT_SECRET" (help "Secret used to sign JWT tokens used to open communications channels") + <*> var auto "PGWS_JWT_SECRET_BASE64" (def False <> helpDef show <> help "Indicate whether the JWT secret should be decoded from a base64 encoded string") + <*> var auto "PGWS_POOL_SIZE" (def 10 <> helpDef show <> help "How many connection to the database should be used by the connection pool") + <*> var auto "PGWS_RETRIES" (def 5 <> helpDef show <> help "How many times it should try to connect to the database on startup before exiting with an error") + <*> optional (var auto "PGWS_CHECK_LISTENER_INTERVAL" (helpDef show <> help "Interval for supervisor thread to check if listener connection is alive. 0 to disable it.")) loadDatabaseURIFile :: AppConfig -> IO AppConfig -loadDatabaseURIFile conf@AppConfig{..} = +loadDatabaseURIFile conf@AppConfig {..} = case stripPrefix "@" configDatabase of - Nothing -> pure conf + Nothing -> pure conf Just filename -> setDatabase . strip <$> readFile (toS filename) where setDatabase uri = conf {configDatabase = uri} @@ -89,15 +88,16 @@ loadDatabaseURIFile conf@AppConfig{..} = loadSecretFile :: AppConfig -> IO AppConfig loadSecretFile conf = extractAndTransform secret where - secret = decodeUtf8 $ configJwtSecret conf - isB64 = configJwtSecretIsBase64 conf + secret = decodeUtf8 $ configJwtSecret conf + isB64 = configJwtSecretIsBase64 conf extractAndTransform :: Text -> IO AppConfig extractAndTransform s = - fmap setSecret $ transformString isB64 =<< - case stripPrefix "@" s of - Nothing -> return . encodeUtf8 $ s - Just filename -> chomp <$> BS.readFile (toS filename) + fmap setSecret $ + transformString isB64 + =<< case stripPrefix "@" s of + Nothing -> return . encodeUtf8 $ s + Just filename -> chomp <$> BS.readFile (toS filename) where chomp bs = fromMaybe bs (BS.stripSuffix "\n" bs) @@ -107,11 +107,10 @@ loadSecretFile conf = extractAndTransform secret transformString True t = case B64.decode $ encodeUtf8 $ strip $ replaceUrlChars $ decodeUtf8 t of Left errMsg -> panic $ pack errMsg - Right bs -> return bs + Right bs -> return bs setSecret bs = conf {configJwtSecret = bs} -- replace: Replace every occurrence of one substring with another replaceUrlChars = replace "_" "/" . replace "-" "+" . replace "." "=" - diff --git a/src/PostgresWebsockets/Context.hs b/src/PostgresWebsockets/Context.hs index 48a83ef..7343bff 100644 --- a/src/PostgresWebsockets/Context.hs +++ b/src/PostgresWebsockets/Context.hs @@ -1,40 +1,45 @@ -{-| -Module : PostgresWebsockets.Context -Description : Produce a context capable of running postgres-websockets sessions --} +-- | +-- Module : PostgresWebsockets.Context +-- Description : Produce a context capable of running postgres-websockets sessions module PostgresWebsockets.Context - ( Context (..) - , mkContext - ) where + ( Context (..), + mkContext, + ) +where -import Protolude hiding (toS) -import Protolude.Conv +import Control.AutoUpdate + ( defaultUpdateSettings, + mkAutoUpdate, + updateAction, + ) import Data.Time.Clock (UTCTime, getCurrentTime) -import Control.AutoUpdate ( defaultUpdateSettings - , mkAutoUpdate - , updateAction - ) import qualified Hasql.Pool as P - -import PostgresWebsockets.Config ( AppConfig(..) ) -import PostgresWebsockets.HasqlBroadcast (newHasqlBroadcaster) import PostgresWebsockets.Broadcast (Multiplexer) +import PostgresWebsockets.Config (AppConfig (..)) +import PostgresWebsockets.HasqlBroadcast (newHasqlBroadcaster) +import Protolude hiding (toS) +import Protolude.Conv -data Context = Context { - ctxConfig :: AppConfig - , ctxPool :: P.Pool - , ctxMulti :: Multiplexer - , ctxGetTime :: IO UTCTime +data Context = Context + { ctxConfig :: AppConfig, + ctxPool :: P.Pool, + ctxMulti :: Multiplexer, + ctxGetTime :: IO UTCTime } -- | Given a configuration and a shutdown action (performed when the Multiplexer's listen connection dies) produces the context necessary to run sessions mkContext :: AppConfig -> IO () -> IO Context -mkContext conf@AppConfig{..} shutdown = do +mkContext conf@AppConfig {..} shutdownServer = do Context conf <$> P.acquire (configPool, 10, pgSettings) <*> newHasqlBroadcaster shutdown (toS configListenChannel) configRetries configReconnectInterval pgSettings <*> mkGetTime where + shutdown = + maybe + shutdownServer + (const $ putText "Producer thread is dead") + configReconnectInterval mkGetTime :: IO (IO UTCTime) mkGetTime = mkAutoUpdate defaultUpdateSettings {updateAction = getCurrentTime} pgSettings = toS configDatabase diff --git a/src/PostgresWebsockets/HasqlBroadcast.hs b/src/PostgresWebsockets/HasqlBroadcast.hs index b6e476e..a88b3a6 100644 --- a/src/PostgresWebsockets/HasqlBroadcast.hs +++ b/src/PostgresWebsockets/HasqlBroadcast.hs @@ -1,50 +1,50 @@ -{-| -Module : PostgresWebsockets.Broadcast -Description : Build a Hasql.Notifications based producer 'Multiplexer'. - -Uses Broadcast module adding database as a source producer. -This module provides a function to produce a 'Multiplexer' from a Hasql 'Connection'. -The producer issues a LISTEN command upon Open commands and UNLISTEN upon Close. --} +-- | +-- Module : PostgresWebsockets.Broadcast +-- Description : Build a Hasql.Notifications based producer 'Multiplexer'. +-- +-- Uses Broadcast module adding database as a source producer. +-- This module provides a function to produce a 'Multiplexer' from a Hasql 'Connection'. +-- The producer issues a LISTEN command upon Open commands and UNLISTEN upon Close. module PostgresWebsockets.HasqlBroadcast - ( newHasqlBroadcaster - , newHasqlBroadcasterOrError - -- re-export - , acquire - , relayMessages - , relayMessagesForever - ) where + ( newHasqlBroadcaster, + newHasqlBroadcasterOrError, + -- re-export + acquire, + relayMessages, + relayMessagesForever, + ) +where -import Protolude hiding (putErrLn, toS, show) +import Control.Retry (RetryStatus (..), capDelay, exponentialBackoff, retrying) +import Data.Aeson (Value (..), decode) +import Data.Either.Combinators (mapBoth) +import Data.Function (id) +import Data.HashMap.Lazy (lookupDefault) import GHC.Show -import Protolude.Conv - import Hasql.Connection +import qualified Hasql.Decoders as HD +import qualified Hasql.Encoders as HE import Hasql.Notifications -import Data.Aeson (decode, Value(..)) -import Data.HashMap.Lazy (lookupDefault) -import Data.Either.Combinators (mapBoth) -import Data.Function (id) -import Control.Retry (RetryStatus(..), retrying, capDelay, exponentialBackoff) - +import qualified Hasql.Session as H +import qualified Hasql.Statement as H import PostgresWebsockets.Broadcast +import Protolude hiding (putErrLn, show, toS) +import Protolude.Conv -{- | Returns a multiplexer from a connection URI, keeps trying to connect in case there is any error. - This function also spawns a thread that keeps relaying the messages from the database to the multiplexer's listeners --} -newHasqlBroadcaster :: IO () -> Text -> Int -> Int -> ByteString -> IO Multiplexer +-- | Returns a multiplexer from a connection URI, keeps trying to connect in case there is any error. +-- This function also spawns a thread that keeps relaying the messages from the database to the multiplexer's listeners +newHasqlBroadcaster :: IO () -> Text -> Int -> Maybe Int -> ByteString -> IO Multiplexer newHasqlBroadcaster onConnectionFailure ch maxRetries checkInterval = newHasqlBroadcasterForConnection . tryUntilConnected maxRetries where newHasqlBroadcasterForConnection = newHasqlBroadcasterForChannel onConnectionFailure ch checkInterval -{- | Returns a multiplexer from a connection URI or an error message on the left case - This function also spawns a thread that keeps relaying the messages from the database to the multiplexer's listeners --} +-- | Returns a multiplexer from a connection URI or an error message on the left case +-- This function also spawns a thread that keeps relaying the messages from the database to the multiplexer's listeners newHasqlBroadcasterOrError :: IO () -> Text -> ByteString -> IO (Either ByteString Multiplexer) newHasqlBroadcasterOrError onConnectionFailure ch = acquire >=> (sequence . mapBoth (toSL . show) (newHasqlBroadcasterForConnection . return)) where - newHasqlBroadcasterForConnection = newHasqlBroadcasterForChannel onConnectionFailure ch 0 + newHasqlBroadcasterForConnection = newHasqlBroadcasterForChannel onConnectionFailure ch Nothing tryUntilConnected :: Int -> ByteString -> IO Connection tryUntilConnected maxRetries = @@ -55,44 +55,46 @@ tryUntilConnected maxRetries = firstDelayInMicroseconds = 1000000 retryPolicy = capDelay maxDelayInMicroseconds $ exponentialBackoff firstDelayInMicroseconds shouldRetry :: RetryStatus -> Either ConnectionError Connection -> IO Bool - shouldRetry RetryStatus{..} con = + shouldRetry RetryStatus {..} con = case con of Left err -> do putErrLn $ "Error connecting notification listener to database: " <> (toS . show) err pure $ rsIterNumber < maxRetries - 1 _ -> return False -{- | Returns a multiplexer from a channel and an IO Connection, listen for different database notifications on the provided channel using the connection produced. - - This function also spawns a thread that keeps relaying the messages from the database to the multiplexer's listeners - - To listen on channels *chat* - - @ - import Protolude - import PostgresWebsockets.HasqlBroadcast - import PostgresWebsockets.Broadcast - import Hasql.Connection - - main = do - conOrError <- H.acquire "postgres://localhost/test_database" - let con = either (panic . show) id conOrError :: Connection - multi <- newHasqlBroadcaster con - - onMessage multi "chat" (\ch -> - forever $ fmap print (atomically $ readTChan ch) - @ --} -newHasqlBroadcasterForChannel :: IO () -> Text -> Int -> IO Connection -> IO Multiplexer +-- | Returns a multiplexer from a channel and an IO Connection, listen for different database notifications on the provided channel using the connection produced. +-- +-- This function also spawns a thread that keeps relaying the messages from the database to the multiplexer's listeners +-- +-- To listen on channels *chat* +-- +-- @ +-- import Protolude +-- import PostgresWebsockets.HasqlBroadcast +-- import PostgresWebsockets.Broadcast +-- import Hasql.Connection +-- +-- main = do +-- conOrError <- H.acquire "postgres://localhost/test_database" +-- let con = either (panic . show) id conOrError :: Connection +-- multi <- newHasqlBroadcaster con +-- +-- onMessage multi "chat" (\ch -> +-- forever $ fmap print (atomically $ readTChan ch) +-- @ +newHasqlBroadcasterForChannel :: IO () -> Text -> Maybe Int -> IO Connection -> IO Multiplexer newHasqlBroadcasterForChannel onConnectionFailure ch checkInterval getCon = do multi <- newMultiplexer openProducer $ const onConnectionFailure + case checkInterval of + Just i -> superviseMultiplexer multi i shouldRestart + _ -> pure () void $ relayMessagesForever multi return multi where toMsg :: Text -> Text -> Message toMsg c m = case decode (toS m) of - Just v -> Message (channelDef c v) m - Nothing -> Message c m + Just v -> Message (channelDef c v) m + Nothing -> Message c m lookupStringDef :: Text -> Text -> Value -> Text lookupStringDef key d (Object obj) = @@ -101,12 +103,32 @@ newHasqlBroadcasterForChannel onConnectionFailure ch checkInterval getCon = do _ -> toS d lookupStringDef _ d _ = toS d channelDef = lookupStringDef "channel" + shouldRestart = do + con <- getCon + not <$> isListening con ch + openProducer msgQ = do con <- getCon listen con $ toPgIdentifier ch waitForNotifications - (\c m-> atomically $ writeTQueue msgQ $ toMsg (toS c) (toS m)) + (\c m -> atomically $ writeTQueue msgQ $ toMsg (toS c) (toS m)) con putErrLn :: Text -> IO () putErrLn = hPutStrLn stderr + +isListening :: Connection -> Text -> IO Bool +isListening con ch = do + resultOrError <- H.run session con + pure $ fromRight False resultOrError + where + session = H.statement chPattern isListeningStatement + chPattern = "listen%" <> ch <> "%" + +isListeningStatement :: H.Statement Text Bool +isListeningStatement = + H.Statement sql encoder decoder True + where + sql = "select exists (select * from pg_stat_activity where datname = current_database() and query ilike $1);" + encoder = HE.param $ HE.nonNullable HE.text + decoder = HD.singleRow (HD.column (HD.nonNullable HD.bool)) \ No newline at end of file From 197591583e01de7a099e89da9587d218e3c95d90 Mon Sep 17 00:00:00 2001 From: Diogo Biazus Date: Thu, 27 May 2021 18:14:18 -0400 Subject: [PATCH 7/8] Add README section discussing how to recover from database connection failure and about new configuration options --- README.md | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index ef66fc7..517d51a 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# postgres-websockets +# postgres-websockets ![CI](https://github.com/diogob/postgres-websockets/actions/workflows/ci.yml/badge.svg) [![Hackage Matrix CI](https://matrix.hackage.haskell.org/api/v2/packages/postgres-websockets/badge)](https://matrix.hackage.haskell.org/package/postgres-websockets) @@ -128,4 +128,11 @@ For instamce, if we use the configuration in the [sample-env](./sample-env) we w {"event":"ConnectionOpen","channel":"server-info","payload":"server-info","claims":{"mode":"rw","message_delivered_at":1.602719440727465893e9}} ``` -You can monitor these messages on another websocket connection with a proper read token for the channel `server-info` or also having an additional database listener on the `PGWS_LISTEN_CHANNEL`. \ No newline at end of file +You can monitor these messages on another websocket connection with a proper read token for the channel `server-info` or also having an additional database listener on the `PGWS_LISTEN_CHANNEL`. + +## Recovering from listener database connection failures + +The database conneciton used to wait for notification where the `LISTEN` command is issued can cause problems when it fails. To prevent these problem from completely disrupting our websockets server there are two ways to configure postgres-websockets: + +* Self healing connection - postgres-websockets comes with a connection supervisor baked in. You just need to set the configuration `PGWS_CHECK_LISTENER_INTERVAL` to a number of milliseconds that will be the maximum amount of time losing messages. There is a cost for this since at each interval an additional SELECT query will be issued to ensure the listener connection is still active. If the connecion is not found the connection thread will be killed and respawned. This method has the advantage of keeping all channels and websocket connections alive while the database connection is severed (although messages will be lost). +* Using external supervision - you can also unset `PGWS_CHECK_LISTENER_INTERVAL` and postgres-websockets will try to shutdown the server when the database connection is lost. This does not seem to work in 100% of the cases, since in theory is possible to have the database connection closed and the producer thread lingering. But in most cases it should work and some external process can then restart the server. The downside is that all websocket connections will be lost. \ No newline at end of file From 427a32fc573f792be6d2aaf79c043e80ce3bdf53 Mon Sep 17 00:00:00 2001 From: Diogo Biazus Date: Thu, 27 May 2021 18:17:41 -0400 Subject: [PATCH 8/8] Fix test configuration --- test/ServerSpec.hs | 173 +++++++++++++++++++++++---------------------- 1 file changed, 87 insertions(+), 86 deletions(-) diff --git a/test/ServerSpec.hs b/test/ServerSpec.hs index 5b2fa04..542d504 100644 --- a/test/ServerSpec.hs +++ b/test/ServerSpec.hs @@ -1,52 +1,51 @@ module ServerSpec (spec) where -import Protolude - -import Test.Hspec -import PostgresWebsockets -import PostgresWebsockets.Config - import Control.Lens import Data.Aeson.Lens - +import Network.Socket (withSocketsDo) import qualified Network.WebSockets as WS -import Network.Socket (withSocketsDo) +import PostgresWebsockets +import PostgresWebsockets.Config +import Protolude +import Test.Hspec testServerConfig :: AppConfig -testServerConfig = AppConfig - { configDatabase = "postgres://postgres:roottoor@localhost:5432/postgres_ws_test" - , configPath = Nothing - , configHost = "*" - , configPort = 8080 - , configListenChannel = "postgres-websockets-test-channel" - , configJwtSecret = "reallyreallyreallyreallyverysafe" - , configMetaChannel = Nothing - , configJwtSecretIsBase64 = False - , configPool = 10 - , configRetries = 5 - , configReconnectInterval = 0 - } +testServerConfig = + AppConfig + { configDatabase = "postgres://postgres:roottoor@localhost:5432/postgres_ws_test", + configPath = Nothing, + configHost = "*", + configPort = 8080, + configListenChannel = "postgres-websockets-test-channel", + configJwtSecret = "reallyreallyreallyreallyverysafe", + configMetaChannel = Nothing, + configJwtSecretIsBase64 = False, + configPool = 10, + configRetries = 5, + configReconnectInterval = Nothing + } startTestServer :: IO ThreadId startTestServer = do - threadId <- forkIO $ serve testServerConfig - threadDelay 500000 - pure threadId + threadId <- forkIO $ serve testServerConfig + threadDelay 500000 + pure threadId withServer :: IO () -> IO () withServer action = - bracket startTestServer - (\tid -> killThread tid >> threadDelay 500000) - (const action) + bracket + startTestServer + (\tid -> killThread tid >> threadDelay 500000) + (const action) sendWsData :: Text -> Text -> IO () sendWsData uri msg = - withSocketsDo $ - WS.runClient - "127.0.0.1" - (configPort testServerConfig) - (toS uri) - (`WS.sendTextData` msg) + withSocketsDo $ + WS.runClient + "127.0.0.1" + (configPort testServerConfig) + (toS uri) + (`WS.sendTextData` msg) testChannel :: Text testChannel = "/test/eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoicncifQ.auy9z4-pqoVEAay9oMi1FuG7ux_C_9RQCH8-wZgej18" @@ -59,62 +58,64 @@ testAndSecondaryChannel = "/eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoicnc waitForWsData :: Text -> IO (MVar ByteString) waitForWsData uri = do - msg <- newEmptyMVar - void $ forkIO $ - withSocketsDo $ - WS.runClient - "127.0.0.1" - (configPort testServerConfig) - (toS uri) - (\c -> do - m <- WS.receiveData c - putMVar msg m - ) - threadDelay 10000 - pure msg + msg <- newEmptyMVar + void $ + forkIO $ + withSocketsDo $ + WS.runClient + "127.0.0.1" + (configPort testServerConfig) + (toS uri) + ( \c -> do + m <- WS.receiveData c + putMVar msg m + ) + threadDelay 10000 + pure msg waitForMultipleWsData :: Int -> Text -> IO (MVar [ByteString]) waitForMultipleWsData messageCount uri = do - msg <- newEmptyMVar - void $ forkIO $ - withSocketsDo $ - WS.runClient - "127.0.0.1" - (configPort testServerConfig) - (toS uri) - (\c -> do - m <- replicateM messageCount (WS.receiveData c) - putMVar msg m - ) - threadDelay 1000 - pure msg + msg <- newEmptyMVar + void $ + forkIO $ + withSocketsDo $ + WS.runClient + "127.0.0.1" + (configPort testServerConfig) + (toS uri) + ( \c -> do + m <- replicateM messageCount (WS.receiveData c) + putMVar msg m + ) + threadDelay 1000 + pure msg spec :: Spec spec = around_ withServer $ - describe "serve" $ do - it "should be able to send messages to test server" $ - sendWsData testChannel "test data" - it "should be able to receive messages from test server" $ do - msg <- waitForWsData testChannel - sendWsData testChannel "test data" - msgJson <- takeMVar msg - (msgJson ^? key "payload" . _String) `shouldBe` Just "test data" - it "should be able to send messages to multiple channels in one shot" $ do - msg <- waitForWsData testChannel - secondaryMsg <- waitForWsData secondaryChannel - sendWsData testAndSecondaryChannel "test data" - msgJson <- takeMVar msg - secondaryMsgJson <- takeMVar secondaryMsg - - (msgJson ^? key "payload" . _String) `shouldBe` Just "test data" - (msgJson ^? key "channel" . _String) `shouldBe` Just "test" - (secondaryMsgJson ^? key "payload" . _String) `shouldBe` Just "test data" - (secondaryMsgJson ^? key "channel" . _String) `shouldBe` Just "secondary" - it "should be able to receive from multiple channels in one shot" $ do - msgs <- waitForMultipleWsData 2 testAndSecondaryChannel - sendWsData testAndSecondaryChannel "test data" - msgsJson <- takeMVar msgs - - forM_ - msgsJson - (\msgJson -> (msgJson ^? key "payload" . _String) `shouldBe` Just "test data") + describe "serve" $ do + it "should be able to send messages to test server" $ + sendWsData testChannel "test data" + it "should be able to receive messages from test server" $ do + msg <- waitForWsData testChannel + sendWsData testChannel "test data" + msgJson <- takeMVar msg + (msgJson ^? key "payload" . _String) `shouldBe` Just "test data" + it "should be able to send messages to multiple channels in one shot" $ do + msg <- waitForWsData testChannel + secondaryMsg <- waitForWsData secondaryChannel + sendWsData testAndSecondaryChannel "test data" + msgJson <- takeMVar msg + secondaryMsgJson <- takeMVar secondaryMsg + + (msgJson ^? key "payload" . _String) `shouldBe` Just "test data" + (msgJson ^? key "channel" . _String) `shouldBe` Just "test" + (secondaryMsgJson ^? key "payload" . _String) `shouldBe` Just "test data" + (secondaryMsgJson ^? key "channel" . _String) `shouldBe` Just "secondary" + it "should be able to receive from multiple channels in one shot" $ do + msgs <- waitForMultipleWsData 2 testAndSecondaryChannel + sendWsData testAndSecondaryChannel "test data" + msgsJson <- takeMVar msgs + + forM_ + msgsJson + (\msgJson -> (msgJson ^? key "payload" . _String) `shouldBe` Just "test data")