module Network.Pusher.WebSockets.Internal where
import Control.Concurrent (ThreadId)
import Control.Exception (IOException, Exception, SomeException, catch, toException)
import qualified Control.Exception as E
import Data.String (IsString(..))
import Data.Word (Word16)
import Control.Concurrent.STM (STM, TVar, atomically, newTVar, modifyTVar')
import Control.Concurrent.STM.TQueue
import qualified Control.Concurrent.STM as STM
import Control.DeepSeq (NFData(..), force)
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Trans.Reader (ReaderT, runReaderT)
import qualified Control.Monad.Trans.Reader as R
import Data.Aeson (Value(..))
import Data.Hashable (Hashable(..))
import qualified Data.HashMap.Strict as H
import qualified Data.Set as S
import Data.Text (Text, unpack)
import Data.Time.Clock (UTCTime, getCurrentTime)
import Network.Socket (HostName, PortNumber)
import Network.WebSockets (ConnectionException, HandshakeException)
-- | The monad that Pusher event handlers run in: a read-only 'Pusher'
-- environment over 'IO' (handlers have type @Value -> PusherClient ()@).
--
-- The class instances are derived through the newtype. NOTE(review): this
-- needs GeneralizedNewtypeDeriving, which is not visible in this chunk —
-- confirm the pragma is enabled.
newtype PusherClient a = PusherClient (ReaderT Pusher IO a)
  deriving (Functor, Applicative, Monad, MonadIO)
-- | Execute a 'PusherClient' computation, supplying the given client state
-- as its environment.
runPusherClient :: Pusher -> PusherClient a -> IO a
runPusherClient pusher client = case client of
  PusherClient readerAction -> R.runReaderT readerAction pusher
-- | Shared state of one Pusher client. Apart from 'options', every field is
-- a 'TVar' or 'TQueue', so it is read and updated through STM.
data Pusher = Pusher
  { commandQueue :: TQueue PusherCommand
    -- ^ Pending commands; 'sendCommand' writes to this queue.
  , connState :: TVar ConnectionState
    -- ^ Current connection state; starts as 'Initialized'. 'sendCommand'
    -- refuses to enqueue once this is 'Disconnected'.
  , options :: Options
    -- ^ Configuration the client was created with (not mutable).
  , idleTimer :: TVar (Maybe Int)
    -- ^ Starts as 'Nothing'. NOTE(review): the meaning and unit of the
    -- 'Int' are not shown in this chunk — presumably an idle timeout; confirm.
  , lastReceived :: TVar UTCTime
    -- ^ Initialised to the creation time by 'defaultPusher'; presumably
    -- updated whenever data arrives (not shown here).
  , socketId :: TVar (Maybe Text)
    -- ^ Starts as 'Nothing'. Presumably the server-assigned socket id once
    -- known — confirm.
  , threadStore :: TVar (S.Set ThreadId)
    -- ^ Set of threads associated with this client (presumably tracked so
    -- they can be stopped later — confirm).
  , eventHandlers :: TVar (H.HashMap Binding Handler)
    -- ^ Registered event handlers, keyed by their 'Binding'.
  , nextBinding :: TVar Binding
    -- ^ Next 'Binding' value; starts at @Binding 0@.
  , allChannels :: TVar (S.Set Channel)
    -- ^ Starts empty. Presumably the channels this client is subscribed
    -- to — confirm.
  , presenceChannels :: TVar (H.HashMap Channel (Value, H.HashMap Text Value))
    -- ^ Per-channel data for presence channels. NOTE(review): the meaning of
    -- the pair is not shown here — presumably (own user data, members keyed
    -- by id); confirm.
  }
-- | A command placed on a client's 'commandQueue' via 'sendCommand'.
-- NOTE(review): how commands are consumed is not shown in this chunk; the
-- constructor notes below are inferred from their names — confirm.
data PusherCommand
  = SendMessage Value
    -- ^ Presumably: send this JSON value over the connection.
  | SendLocalMessage Value
    -- ^ Presumably: dispatch this value locally without sending it.
  | Subscribe Channel Value
    -- ^ Presumably: subscribe to the channel; the 'Value' is likely the
    -- subscription payload.
  | Terminate
    -- ^ Presumably: shut the client down.
  deriving (Eq, Show)
-- | Exception carrying an optional 'Word16' code. NOTE(review): presumably
-- thrown to stop a client's threads, with the code being a WebSocket close
-- code — confirm against the code that throws it (not in this chunk).
data TerminatePusher = TerminatePusher (Maybe Word16)
  deriving (Eq, Ord, Read, Show)

instance Exception TerminatePusher
-- | Thrown by 'sendCommand' when the client is already 'Disconnected'.
-- It carries the optional close code stored in that 'Disconnected' state.
data PusherClosed = PusherClosed (Maybe Word16)
  deriving (Eq, Ord, Read, Show)

instance Exception PusherClosed
-- | Lifecycle state of a client connection. 'defaultPusher' starts every
-- client in 'Initialized'. Once the state is 'Disconnected', 'sendCommand'
-- rejects new commands with 'PusherClosed'.
data ConnectionState
  = Initialized
  | Connecting
  | Connected
  | Unavailable
  | Disconnected (Maybe Word16)
    -- ^ Optional close code. NOTE(review): presumably a WebSocket close
    -- status code — confirm.
  deriving (Eq, Ord, Read, Show)
-- | Create the state for a new, not-yet-connected client.
--
-- All mutable cells are allocated in one STM transaction:
--
-- * the connection starts 'Initialized';
-- * optional values start as 'Nothing' and collections start empty;
-- * binding identifiers start at @Binding 0@;
-- * 'lastReceived' is set to the current time.
defaultPusher :: Options -> IO Pusher
defaultPusher opts = do
  createdAt <- getCurrentTime
  atomically $
    Pusher
      <$> newTQueue              -- commandQueue
      <*> newTVar Initialized    -- connState
      <*> pure opts              -- options
      <*> newTVar Nothing        -- idleTimer
      <*> newTVar createdAt      -- lastReceived
      <*> newTVar Nothing        -- socketId
      <*> newTVar S.empty        -- threadStore
      <*> newTVar H.empty        -- eventHandlers
      <*> newTVar (Binding 0)    -- nextBinding
      <*> newTVar S.empty        -- allChannels
      <*> newTVar H.empty        -- presenceChannels
-- | Enqueue a command for the client.
--
-- The connection-state check and the enqueue happen in a single STM
-- transaction. Previously the state was read in one transaction and the
-- queue written in another. A client that became 'Disconnected' in between
-- would then silently accept a command that is never processed.
--
-- Throws 'PusherClosed' if the client is already 'Disconnected'. The
-- exception carries the close code stored in that state. Nothing is
-- enqueued in that case, because throwing aborts the transaction.
sendCommand :: Pusher -> PusherCommand -> IO ()
sendCommand pusher cmd = atomically $ do
  cstate <- STM.readTVar (connState pusher)
  case cstate of
    Disconnected ccode -> STM.throwSTM (PusherClosed ccode)
    _ -> writeTQueue (commandQueue pusher) cmd
-- | Client configuration. 'defaultOptions' builds one from an 'AppKey',
-- with the defaults noted on each field.
data Options = Options
  { appKey :: AppKey
    -- ^ The application key.
  , encrypted :: Bool
    -- ^ Defaults to 'True'. Presumably selects a TLS connection — confirm.
  , authorisationURL :: Maybe String
    -- ^ Defaults to 'Nothing'. Presumably the endpoint used to authorise
    -- private/presence subscriptions — confirm.
  , cluster :: Cluster
    -- ^ Defaults to 'MT1'.
  , pusherURL :: Maybe (HostName, PortNumber, String)
    -- ^ Defaults to 'Nothing'. Presumably an explicit (host, port, path)
    -- that overrides the cluster — confirm.
  } deriving (Eq, Ord, Show)
-- | Fully evaluates 'Options'. The 'PortNumber' inside 'pusherURL' is only
-- evaluated to WHNF with 'seq' rather than 'rnf'. NOTE(review): presumably
-- because 'PortNumber' lacks an 'NFData' instance in the supported network
-- versions — confirm.
instance NFData Options where
  rnf (Options key enc authURL clus url) =
    rnf key `seq` rnf enc `seq` rnf authURL `seq` rnf clus `seq` rnfURL url
    where
      rnfURL Nothing = ()
      rnfURL (Just (host, port, path)) = port `seq` rnf host `seq` rnf path
-- | The Pusher cluster to connect to. 'defaultOptions' uses 'MT1'.
data Cluster
  = MT1
  | EU
  | AP1
  deriving (Eq, Ord, Bounded, Enum, Read, Show)

-- Every constructor is nullary, so matching the constructor (WHNF) already
-- evaluates a 'Cluster' completely.
instance NFData Cluster where
  rnf MT1 = ()
  rnf EU = ()
  rnf AP1 = ()
-- | A Pusher application key. The 'IsString' instance allows a string
-- literal to be used directly when OverloadedStrings is enabled.
newtype AppKey = AppKey String
  deriving (Eq, Ord, Show, Read)

instance IsString AppKey where
  fromString str = AppKey str

instance NFData AppKey where
  rnf key = case key of
    AppKey str -> rnf str
-- | Options for the given app key, with every other setting at its default:
-- the 'MT1' cluster, @encrypted = True@, no authorisation URL and no
-- explicit server URL.
defaultOptions :: AppKey -> Options
defaultOptions key =
  Options
    { appKey = key
    , cluster = MT1
    , encrypted = True
    , authorisationURL = Nothing
    , pusherURL = Nothing
    }
-- | A registered event handler. It holds an optional 'Text' (presumably the
-- event name), an optional 'Channel', and the action invoked with a JSON
-- 'Value'. NOTE(review): presumably 'Nothing' in either position means
-- "match any" — confirm against the dispatch code.
data Handler = Handler (Maybe Text) (Maybe Channel) (Value -> PusherClient ())

-- Forces the two filter fields; the handler action itself is not evaluated.
instance NFData Handler where
  rnf (Handler eventName channel _) = rnf eventName `seq` rnf channel
-- | A Pusher channel, identified by its name.
newtype Channel = Channel { unChannel :: Text }
  deriving (Eq, Ord)

instance NFData Channel where
  rnf = rnf . unChannel

-- Debug rendering: the name wrapped as "<<channel NAME>>".
instance Show Channel where
  show chan = concat ["<<channel ", unpack (unChannel chan), ">>"]
-- Hashes exactly like the wrapped 'Text', which keeps hashing consistent
-- with the derived 'Eq'.
instance Hashable Channel where
  hashWithSalt salt = hashWithSalt salt . unChannel
-- | Identifier of a registered event handler. It is the key type of
-- 'eventHandlers'; 'nextBinding' holds the next value (starting at 0).
newtype Binding = Binding { unBinding :: Int }
  deriving (Eq, Ord)

instance NFData Binding where
  rnf = rnf . unBinding

-- Debug rendering: the number wrapped as "<<binding N>>".
instance Show Binding where
  show bnd = "<<binding " ++ shows (unBinding bnd) ">>"
-- Hashes exactly like the wrapped 'Int', which keeps hashing consistent
-- with the derived 'Eq'.
instance Hashable Binding where
  hashWithSalt salt = hashWithSalt salt . unBinding
-- | Fetch the client state inside a 'PusherClient' computation.
ask :: PusherClient Pusher
ask = PusherClient (R.asks id)
-- | Update a 'TVar' and evaluate the new value to normal form before it is
-- stored, so no thunks build up inside the variable.
--
-- 'force' must be applied to the /result/ of @f@. The previous version,
-- @modifyTVar' tvar . force@, forced the function @f@ itself. Forcing a
-- function only reaches WHNF, so the stored value was merely WHNF-forced
-- and the 'NFData' constraint did nothing.
--
-- If evaluating the new value throws, the exception propagates out of the
-- transaction and the variable keeps its old value.
strictModifyTVar :: NFData a => TVar a -> (a -> a) -> STM ()
strictModifyTVar tvar f = modifyTVar' tvar (force . f)
-- | 'strictModifyTVar' run as its own transaction in any 'MonadIO' monad.
strictModifyTVarIO :: (MonadIO m, NFData a) => TVar a -> (a -> a) -> m ()
strictModifyTVarIO tvar update =
  liftIO (atomically (strictModifyTVar tvar update))
-- | 'STM.readTVarIO' generalised to any 'MonadIO' monad: reads the current
-- value of a 'TVar' without an explicit 'atomically' block.
readTVarIO :: MonadIO m => TVar a -> m a
readTVarIO tvar = liftIO (STM.readTVarIO tvar)
-- | Run an action, returning @fallback@ instead if it throws any exception.
--
-- NOTE(review): this goes through 'catchAll', i.e. catches 'SomeException'.
-- That also includes asynchronous exceptions such as @ThreadKilled@.
-- Confirm that swallowing those is intended at every call site.
ignoreAll :: a -> IO a -> IO a
ignoreAll fallback act = act `catchAll` \_ -> pure fallback
-- | Run @act@. Whenever it fails with a network-related exception (as
-- filtered by 'catchNetException'), run @prere@ and then try @act@ again.
-- Returns the first successful result; any other exception, including one
-- thrown by @prere@ itself, propagates.
--
-- NOTE(review): there is no retry limit or delay between attempts. A
-- persistently failing action is retried immediately and indefinitely.
-- Confirm this is intended, or that @prere@ provides the back-off.
reconnecting :: IO a -> IO () -> IO a
reconnecting act prere = retry
  where
    retry = act `catchNetException` \_ -> do
      prere
      retry
-- | Like 'catch', but only intercepts network-related failures:
-- 'IOException', and the websockets 'HandshakeException' and
-- 'ConnectionException'. The handler receives the exception wrapped as
-- 'SomeException'; every other exception propagates unchanged.
catchNetException :: forall a. IO a -> (SomeException -> IO a) -> IO a
catchNetException act handler =
  act `E.catches` [E.Handler onIO, E.Handler onHandshake, E.Handler onConnection]
  where
    onIO :: IOException -> IO a
    onIO err = handler (toException err)

    onHandshake :: HandshakeException -> IO a
    onHandshake err = handler (toException err)

    onConnection :: ConnectionException -> IO a
    onConnection err = handler (toException err)
-- | 'catch' specialised to 'SomeException', so it intercepts every
-- exception thrown by the action.
catchAll :: IO a -> (SomeException -> IO a) -> IO a
catchAll act handler = act `catch` handler