{-# language BangPatterns #-}
{-# language DataKinds #-}
{-# language DerivingStrategies #-}
{-# language GeneralizedNewtypeDeriving #-}
{-# language LambdaCase #-}
{-# language MagicHash #-}
{-# language MultiWayIf #-}
{-# language NamedFieldPuns #-}
{-# language ScopedTypeVariables #-}
{-# language UnboxedTuples #-}
module Socket.EventManager
(
manager
, register
, reader
, writer
, Token
, unready
, wait
, unreadyAndWait
, persistentUnreadyAndWait
, persistentUnready
, interruptibleWait
, interruptibleWaitCounting
, isInterrupt
) where
import Control.Applicative (liftA2,(<|>))
import Control.Concurrent (getNumCapabilities,forkOn,rtsSupportsBoundThreads)
import Control.Concurrent.STM (TVar)
import Control.Monad (when)
import Control.Monad.STM (atomically)
import Data.Bits (countLeadingZeros,finiteBitSize,unsafeShiftL,(.|.),(.&.))
import Data.Bits (unsafeShiftR)
import Data.Primitive.Unlifted.Array (MutableUnliftedArray(..))
import Data.Primitive (MutableByteArray(..),MutablePrimArray(..))
import Data.Primitive (Prim)
import Data.Word (Word64,Word32)
import Foreign.C.Error (Errno(..),eINTR)
import Foreign.C.Types (CInt)
import GHC.Conc.Sync (TVar(..),yield)
import GHC.Exts (RealWorld,Int(I#),(*#),TVar#,ArrayArray#,MutableArrayArray#)
import GHC.Exts (Any,MutableArray#,unsafeCoerce#,(==#),isTrue#,casArray#)
import GHC.IO (IO(..))
import Numeric (showIntAtBase)
import Socket.Error (die)
import Socket.Debug (debug,whenDebugging,debugging)
import System.IO.Unsafe (unsafePerformIO)
import System.Posix.Types (Fd)
import qualified Control.Monad.STM as STM
import qualified Control.Concurrent.STM as STM
import qualified Linux.Epoll as Epoll
import qualified Control.Monad.Primitive as PM
import qualified Data.Primitive as PM
import qualified Data.Primitive.Unlifted.Array as PM
import qualified GHC.Exts as Exts
-- | Start watching @fd@ with the manager's epoll instance.
--
-- First makes sure the tier 2 array that covers @fd@ exists (allocating it
-- on first use). Then it resets both the read and the write token with
-- 'resetToken', which leaves them ready. Finally it adds @fd@ to epoll as
-- edge-triggered, with interest in input, output and read-hangup. The
-- event payload is @fd@ itself. An @epoll_ctl@ failure is reported through
-- 'die'.
register ::
     Manager
  -> Fd
  -> IO ()
register mngr@Manager{epoll} !fd = do
  (slot, tier2) <- constructivelyLookupTier1 (fdToInt fd) mngr
  readVar <- readTVarArray tier2 (2 * slot)
  writeVar <- readTVarArray tier2 (2 * slot + 1)
  atomically (STM.modifyTVar' readVar resetToken *> STM.modifyTVar' writeVar resetToken)
  ev <- PM.newPrimArray 1
  debug ("register: registering fd " ++ show fd)
  PM.writePrimArray ev 0 Epoll.Event
    { Epoll.events = Epoll.input <> Epoll.output <> Epoll.edgeTriggered <> Epoll.readHangup
    , Epoll.payload = fd
    }
  Epoll.uninterruptibleControlMutablePrimArray epoll Epoll.add fd ev >>= \case
    Left (Errno code) ->
      die ("Socket.EventManager.register: epoll_ctl error " ++ show code)
    Right () -> pure ()
-- | A mutable unlifted array in the 'IO' ('RealWorld') state thread.
type MUArray = MutableUnliftedArray RealWorld
-- | Shared state of the event manager.
data Manager = Manager
  { variables :: !(MUArray (MUArray (TVar Token)))
    -- ^ Tier 1 table, indexed by the first component of 'decompose'. Each
    -- slot is either the shared 'novars' sentinel (not yet allocated) or a
    -- tier 2 array. A tier 2 array holds two 'TVar's per descriptor: the
    -- read token at index @2 * ix@ and the write token at @2 * ix + 1@,
    -- where @ix@ is the second component of 'decompose'.
  , novars :: !(MUArray (TVar Token))
    -- ^ Empty array that marks a tier 1 slot as "not allocated yet";
    -- compared by identity.
  , epoll :: !Fd
    -- ^ Descriptor of the epoll instance that all sockets are added to.
  }
-- | Process-wide event manager, created on first use.
--
-- Creating it does the following:
--
-- * fails unless the program runs on the threaded RTS;
-- * allocates the tier 1 table and fills every slot with the shared empty
--   'novars' sentinel;
-- * creates an epoll instance (with 'Epoll.closeOnExec');
-- * forks one polling thread ('loopManager') per capability, or a single
--   thread on capability 0 when 'debugging' is on.
manager :: Manager
{-# noinline manager #-}
manager = unsafePerformIO $ do
  when (not rtsSupportsBoundThreads) $ do
    fail $ "Socket.Event.manager: threaded runtime required"
  !novars <- PM.unsafeNewUnliftedArray 0
  !variables <- PM.unsafeNewUnliftedArray tier1Size
  -- unsafeNewUnliftedArray leaves its slots uninitialised, so every valid
  -- index (0 .. tier1Size - 1) must hold the sentinel before use. The
  -- write primitive is unchecked, so starting at tier1Size would write one
  -- element past the end of the array.
  let goX !ix = if ix >= 0
        then do
          writeMutableUnliftedArrayArray variables ix novars
          goX (ix - 1)
        else pure ()
  goX (tier1Size - 1)
  Epoll.uninterruptibleCreate1 Epoll.closeOnExec >>= \case
    Left (Errno code) ->
      die $ "Socket.EventManager.manager: epoll_create error code " ++ show code
    Right !epoll -> do
      capNum <- getNumCapabilities
      whenDebugging $ do
        when (capNum < 1) $ do
          die $ "Socket.EventManager.manager: non-positive number of capabilities"
      -- Fork onto capabilities lastCap, lastCap - 1, .., 0, giving one poller
      -- per capability. forkOn reduces its index modulo the number of
      -- capabilities, so starting at capNum would put a second poller on
      -- capability 0.
      let lastCap = if debugging then 0 else capNum - 1
          go !ix = if ix > (-1)
            then do
              _ <- forkOn ix $ do
                let !initSz = if debugging then 1 else 8
                !initArr <- newPinnedPrimArray initSz
                loopManager initArr initSz epoll variables
              go (ix - 1)
            else pure ()
      go lastCap
      pure (Manager {variables,novars,epoll})
  where
    -- Number of tier 1 buckets. 'decompose' sends descriptor n to bucket
    -- floor (log2 (n + 1)). So 32 buckets cover every n with n + 1 < 2^32,
    -- which includes every non-negative 32-bit descriptor.
    tier1Size :: Int
    tier1Size = 32
-- | The 'TVar' holding the read-readiness 'Token' for @fd@. This is slot
-- @2 * ix@ of the descriptor's tier 2 array.
--
-- NOTE(review): the lookup is unchecked. It appears to assume @fd@ was
-- already passed to 'register'; otherwise the tier 1 slot still holds the
-- empty 'novars' array.
reader :: Manager -> Fd -> IO (TVar Token)
reader Manager{variables = tier1} !fd = lookupGeneric 0 (fdToInt fd) tier1
-- | The 'TVar' holding the write-readiness 'Token' for @fd@. This is slot
-- @2 * ix + 1@ of the descriptor's tier 2 array.
--
-- NOTE(review): the lookup is unchecked. It appears to assume @fd@ was
-- already passed to 'register'; otherwise the tier 1 slot still holds the
-- empty 'novars' array.
writer :: Manager -> Fd -> IO (TVar Token)
writer Manager{variables = tier1} !fd = lookupGeneric 1 (fdToInt fd) tier1
-- | Fetch both token 'TVar's for descriptor number @fd@ from the tier 1
-- table @arr@. The read token (slot @2 * ix@) is read first, then the
-- write token (slot @2 * ix + 1@), where @ix@ is the second component of
-- 'decompose'.
--
-- NOTE(review): there is no check that the tier 2 array was allocated.
-- Callers presumably pass only registered descriptors.
lookupBoth ::
     Int
  -> MUArray (MUArray (TVar Token))
  -> IO (TVar Token,TVar Token)
lookupBoth !fd !arr = do
  tier2 <- readMutableUnliftedArrayArray arr outer
  readSide <- readTVarArray tier2 base
  writeSide <- readTVarArray tier2 (base + 1)
  pure (readSide, writeSide)
  where
    (outer, inner) = decompose fd
    base = inner * 2
-- | Fetch one token 'TVar' for descriptor number @fd@ from the tier 1
-- table @arr@. @rw = 0@ selects the read token and @rw = 1@ the write
-- token, as used by 'reader' and 'writer'.
--
-- NOTE(review): unchecked; it assumes the descriptor's tier 2 array exists.
lookupGeneric ::
     Int
  -> Int
  -> MutableUnliftedArray RealWorld (MutableUnliftedArray RealWorld (TVar Token))
  -> IO (TVar Token)
lookupGeneric !rw !fd !arr =
  readMutableUnliftedArrayArray arr outer >>= \tier2 ->
    readTVarArray tier2 (2 * inner + rw)
  where
    (outer, inner) = decompose fd
-- | Return @fd@'s index within its tier 2 array, together with that array.
-- If the tier 1 slot still holds the 'novars' sentinel, the tier 2 array
-- is allocated first.
--
-- A new tier 2 array has @exp2succ ix1@ slots, two per descriptor in the
-- bucket. Each slot is a fresh 'TVar' initialised to 'emptyToken'. The
-- array is published with a compare-and-swap on the tier 1 slot. If
-- another thread installed an array first, the CAS fails, the new array is
-- dropped, and the already-installed array is returned instead.
constructivelyLookupTier1 ::
     Int
  -> Manager
  -> IO (Int, MUArray (TVar Token))
constructivelyLookupTier1 !fd Manager{variables,novars} = do
  let (outer, inner) = decompose fd
  current <- readMutableUnliftedArrayArray variables outer
  if not (PM.sameMutableUnliftedArray current novars)
    then pure (inner, current)
    else do
      let !len = exp2succ outer
      fresh <- PM.unsafeNewUnliftedArray len
      mapM_
        (\slot -> STM.newTVarIO emptyToken >>= writeTVarArray fresh slot)
        [len - 1, len - 2 .. 0]
      (success, installed) <- casMutableUnliftedArrayArray variables outer novars fresh
      debug ("constructivelyLookupTier1: Created tier 2 array of length " ++ show len ++ " at index " ++ show outer ++ " with success " ++ show success)
      pure (inner, installed)
-- | Poll the epoll instance @epfd@ forever. Each round yields, then runs
-- one 'stepManager' round. The event buffer and capacity it returns
-- (possibly a larger, freshly allocated buffer) are used for the next
-- round. Never returns.
loopManager ::
     MutablePrimArray RealWorld (Epoll.Event 'Epoll.Response Fd)
  -> Int
  -> Fd
  -> MUArray (MUArray (TVar Token))
  -> IO ()
loopManager !evs0 !sz0 !epfd !tier1 = spin evs0 sz0
  where
    spin !buf !cap = do
      yield
      (!buf', !cap') <- stepManager buf cap epfd tier1
      spin buf' cap'
-- | Run one round of event collection on the epoll instance @epfd@ and
-- dispatch the result to 'handleEvents'. @evs0@ is the event buffer and
-- @sz0@ its capacity in events.
--
-- The wait is attempted in up to three stages, stopping at the first that
-- yields at least one event:
--
-- 1. 'Epoll.uninterruptibleWaitMutablePrimArray'. No timeout argument is
--    passed. NOTE(review): presumably a non-blocking poll; confirm against
--    "Linux.Epoll".
-- 2. After a 'yield', the same call again.
-- 3. 'Epoll.waitMutablePrimArray' with a timeout of @-1@, retried whenever
--    it fails with 'eINTR'.
--
-- Any other errno, and a zero-length result from stage 3, are reported via
-- 'die'. Returns the buffer and capacity to use for the next round, as
-- chosen by 'handleEvents'.
stepManager ::
     MutablePrimArray RealWorld (Epoll.Event 'Epoll.Response Fd)
  -> Int
  -> Fd
  -> MUArray (MUArray (TVar Token))
  -> IO (MutablePrimArray RealWorld (Epoll.Event 'Epoll.Response Fd),Int)
stepManager !evs0 !sz0 !epfd !tier1 = do
  -- Stage 1.
  Epoll.uninterruptibleWaitMutablePrimArray epfd evs0 (intToCInt sz0) >>= \case
    Left (Errno code) -> die $ "Socket.EventManager.stepManager: A " ++ show code
    Right len0 -> if len0 > 0
      then handleEvents evs0 (cintToInt len0) sz0 tier1
      else do
        debug "stepManager: first attempt returned no events"
        -- Stage 2: let other Haskell threads run, then try again.
        yield
        Epoll.uninterruptibleWaitMutablePrimArray epfd evs0 (intToCInt sz0) >>= \case
          Left (Errno code) -> die $ "Socket.EventManager.stepManager: B " ++ show code
          Right len1 -> if len1 > 0
            then do
              debug "stepManager: second attempt succeeded"
              handleEvents evs0 (cintToInt len1) sz0 tier1
            else do
              debug "stepManager: second attempt returned no events"
              -- Debug-only sanity check that sz0 matches the buffer's real size.
              whenDebugging $ do
                actualSize <- PM.getSizeofMutablePrimArray evs0
                when (actualSize /= sz0) (die "stepManager: bad size")
              -- Stage 3: wait with timeout -1, retrying on EINTR.
              let go = Epoll.waitMutablePrimArray epfd evs0 (intToCInt sz0) (-1) >>= \case
                    Left err@(Errno code) -> if err == eINTR
                      then go
                      else die $ "Socket.EventManager.stepManager: C " ++ show code
                    Right len2 -> if len2 > 0
                      then do
                        -- Debug-only dump of the raw Word32 words 0-2 and,
                        -- when sz0 > 1, words 3-5 of the buffer. These line
                        -- up with events 0 and 1 only if each event takes
                        -- 12 bytes (packed epoll_event) — NOTE(review):
                        -- confirm. Words 3-5 are read even when len2 == 1.
                        whenDebugging $ do
                          let !(MutablePrimArray evs0#) = evs0
                          let untypedEvs0 = MutableByteArray evs0#
                          debug ("stepManager: third attempt succeeded, len=" ++ show len2 ++ ",sz=" ++ show sz0)
                          (w0 :: Word32) <- PM.readByteArray untypedEvs0 0
                          (w1 :: Word32) <- PM.readByteArray untypedEvs0 1
                          (w2 :: Word32) <- PM.readByteArray untypedEvs0 2
                          debug $ "stepManager: element 0 raw after third attempt " ++
                            lpad 32 (showIntAtBase 2 binChar w0 "") ++ " " ++
                            lpad 32 (showIntAtBase 2 binChar w1 "") ++ " " ++
                            lpad 32 (showIntAtBase 2 binChar w2 "")
                          when (sz0 > 1) $ do
                            (w0a :: Word32) <- PM.readByteArray untypedEvs0 3
                            (w1a :: Word32) <- PM.readByteArray untypedEvs0 4
                            (w2a :: Word32) <- PM.readByteArray untypedEvs0 5
                            debug $ "stepManager: element 1 raw after third attempt " ++
                              lpad 32 (showIntAtBase 2 binChar w0a "") ++ " " ++
                              lpad 32 (showIntAtBase 2 binChar w1a "") ++ " " ++
                              lpad 32 (showIntAtBase 2 binChar w2a "")
                        handleEvents evs0 (cintToInt len2) sz0 tier1
                      else die $ "Socket.EventManager.stepManager: D"
              go
-- | Left-pad @xs@ with @'0'@ to exactly @m@ characters. Input longer than
-- @m@ is cut down to its first @m@ characters.
lpad :: Int -> String -> String
lpad m xs =
  let kept = take m xs
      missing = m - length kept
   in replicate missing '0' <> kept
-- | Narrow an 'Int' to a 'CInt'. Out-of-range values wrap, as with
-- 'fromIntegral'.
intToCInt :: Int -> CInt
intToCInt n = fromIntegral n
-- | Widen a 'CInt' to an 'Int'.
cintToInt :: CInt -> Int
cintToInt c = fromIntegral c
-- | Publish readiness for the first @len@ events in @evs@, then choose the
-- buffer for the next wait. @sz@ is the capacity of @evs@ and @vars@ is
-- the tier 1 table.
--
-- For each event, the descriptor's token 'TVar's are looked up in @vars@:
--
-- * the read token is bumped on input, read-hangup, error or hangup;
-- * the write token is bumped on output, error or hangup.
--
-- Tokens are normally bumped with 'readyToken'. A read-hangup makes the
-- read token persistent ('persistentReadyToken').
--
-- If the wait filled the whole buffer (@len == sz@), a fresh pinned buffer
-- of twice the capacity is returned. Otherwise @evs@ is reused.
-- @len > sz@ is reported via 'die'.
handleEvents ::
     MutablePrimArray RealWorld (Epoll.Event 'Epoll.Response Fd)
  -> Int
  -> Int
  -> MUArray (MUArray (TVar Token))
  -> IO (MutablePrimArray RealWorld (Epoll.Event 'Epoll.Response Fd),Int)
handleEvents !evs !len !sz !vars = do
  traverseMutablePrimArray_ onEvent evs 0 len
  nextBuffer
  where
    -- Publish the readiness reported by a single epoll event.
    onEvent :: Epoll.Event 'Epoll.Response Fd -> IO ()
    onEvent Epoll.Event{Epoll.events, Epoll.payload = fd} = do
      let has = Epoll.containsAnyEvents events
          -- Error and hangup wake both directions.
          wakeReader = has (Epoll.input <> Epoll.readHangup <> Epoll.error <> Epoll.hangup)
          wakeWriter = has (Epoll.output <> Epoll.error <> Epoll.hangup)
          -- A persistent read token stays ready across a later 'unready'.
          stickyRead = has Epoll.readHangup
      whenDebugging $ do
        let Epoll.Events bits = events
        debug $
          "handleEvents: fd " ++ show fd ++
          " bitmask " ++ showIntAtBase 2 binChar bits "" ++ " read [" ++ show (has Epoll.input) ++
          "] write [" ++ show (has Epoll.output) ++ "] hangup [" ++ show (has Epoll.hangup) ++
          "] readHangup [" ++ show (has Epoll.readHangup) ++
          "] error [" ++ show (has Epoll.error) ++
          "] priority [" ++ show (has Epoll.priority) ++ "]"
      (readVar, writeVar) <- lookupBoth (fdToInt fd) vars
      when wakeReader $ atomically $ do
        old <- STM.readTVar readVar
        STM.writeTVar readVar $! (if stickyRead then persistentReadyToken else readyToken) old
      when wakeWriter $ atomically $ STM.modifyTVar' writeVar readyToken
    -- A completely full buffer suggests more events were pending, so the
    -- next wait gets a fresh pinned buffer of twice the capacity.
    nextBuffer = case compare len sz of
      LT -> pure (evs, sz)
      EQ -> do
        let newSz = sz * 2
        debug ("handleEvents: doubling size of array to " ++ show newSz)
        newBuf <- newPinnedPrimArray newSz
        pure (newBuf, newSz)
      GT -> die "Socket.EventManager.handleEvents: len > sz"
-- | Render a binary digit for 'showIntAtBase'. Any value other than 0 or 1
-- is rendered as @'x'@.
binChar :: Int -> Char
binChar d
  | d == 0 = '0'
  | d == 1 = '1'
  | otherwise = 'x'
-- | Apply @f@ to the elements of @a@ at indices @off .. end - 1@, in
-- ascending order. Nothing happens when @off >= end@. The caller must keep
-- the range within the array.
traverseMutablePrimArray_ ::
     Prim a
  => (a -> IO ())
  -> MutablePrimArray RealWorld a
  -> Int
  -> Int
  -> IO ()
{-# inline traverseMutablePrimArray_ #-}
traverseMutablePrimArray_ f a off end = loop off
  where
    loop !ix
      | ix >= end = pure ()
      | otherwise = do
          debug ("traverseMutablePrimArray_: index " ++ show ix)
          x <- PM.readPrimArray a ix
          f x
          loop (ix + 1)
-- | @2 ^ n@ computed as a left shift. Defined only for @0 <= n < 64@.
exp2 :: Int -> Int
{-# INLINE exp2 #-}
exp2 = unsafeShiftL 1
-- | @2 ^ (n + 1)@ computed as a left shift. This is the tier 2 array
-- length for tier 1 bucket @n@ (two 'TVar's per descriptor).
exp2succ :: Int -> Int
{-# INLINE exp2succ #-}
exp2succ n = 1 `unsafeShiftL` (n + 1)
-- | Split descriptor number @n@ into (tier 1 index, tier 2 index).
--
-- With @m = n + 1@, the tier 1 index is @floor (log2 m)@ and the tier 2
-- index is @m - 2 ^ floor (log2 m)@. For example: 0 -> (0,0), 1 -> (1,0),
-- 2 -> (1,1), 6 -> (2,3). Tier 1 bucket @k@ therefore covers @2 ^ k@
-- descriptors. Both components are forced before the pair is built.
decompose :: Int -> (Int,Int)
{-# INLINE decompose #-}
decompose n = (tier, offset)
  where
    !m = n + 1
    !tier = finiteBitSize m - countLeadingZeros m - 1
    !offset = m - exp2 tier
-- | Widen a file descriptor to an 'Int' index.
fdToInt :: Fd -> Int
{-# INLINE fdToInt #-}
fdToInt fd = fromIntegral fd
-- | Readiness token for one direction (read or write) of a descriptor.
--
-- The 'Word64' holds flag bits at the top:
--
-- * bit 63 ('readyBit'): ready;
-- * bit 62: persistently ready. It is set by 'persistentReadyToken',
--   survives 'unreadyToken', and is cleared by 'resetToken' and
--   'persistentUnreadyToken';
-- * bit 61: set in 'interruptToken'.
--
-- Every transition function adds one to the word before adjusting the
-- flags, so the low bits act as a version counter. This is what lets
-- 'unready' tell whether a token changed since it was read.
newtype Token = Token Word64

-- | Bit 63: the token is ready.
readyBit :: Word64
readyBit = 0x8000000000000000

-- | Bits 63 and 62: ready, and (bit 62) still ready after 'unreadyToken'.
persistentReadyBits :: Word64
persistentReadyBits = 0xC000000000000000

-- | Mask that clears bit 63 and keeps every other bit.
unreadyBit :: Word64
unreadyBit = 0x7FFFFFFFFFFFFFFF

-- | Compare the raw words of two tokens.
eqToken :: Token -> Token -> Bool
eqToken (Token x) (Token y) = x == y

-- | Value given to every freshly allocated token 'TVar': ready, counter 0.
emptyToken :: Token
emptyToken = Token readyBit

-- | Sentinel returned by the interruptible waits (bit 61 only; not ready).
interruptToken :: Token
interruptToken = Token 0x2000000000000000

-- | A token is ready when bit 62 or bit 63 is set.
isTokenReady :: Token -> Bool
isTokenReady (Token w) = w >= 0x4000000000000000

-- | Is this exactly 'interruptToken'?
isInterrupt :: Token -> Bool
isInterrupt t = eqToken t interruptToken

-- | Bump the counter and set bit 63.
readyToken :: Token -> Token
readyToken (Token w) = Token ((w + 1) .|. readyBit)

-- | Bump the counter and set bits 63 and 62.
persistentReadyToken :: Token -> Token
persistentReadyToken (Token w) = Token ((w + 1) .|. persistentReadyBits)

-- | Bump the counter, set bit 63, and clear bits 62 and 61.
resetToken :: Token -> Token
resetToken (Token w) = Token (0x9FFFFFFFFFFFFFFF .&. ((w + 1) .|. readyBit))

-- | Bump the counter and clear bit 63. Bit 62 is kept, so a persistently
-- ready token stays ready.
unreadyToken :: Token -> Token
unreadyToken (Token w) = Token ((w + 1) .&. unreadyBit)

-- | Bump the counter and clear bits 63, 62 and 61.
persistentUnreadyToken :: Token -> Token
persistentUnreadyToken (Token w) = Token ((w + 1) .&. 0x1FFFFFFFFFFFFFFF)
-- | Mark @tv@ not ready ('unreadyToken'), but only if it still holds exactly
-- @oldToken@. If it has changed since @oldToken@ was read, it is left
-- alone.
unready ::
     Token
  -> TVar Token
  -> IO ()
unready !oldToken !tv = atomically $ do
  current <- STM.readTVar tv
  when (eqToken oldToken current) $
    STM.writeTVar tv $! unreadyToken oldToken
-- | Like 'unready', but uses 'persistentUnreadyToken', which also clears the
-- persistent-ready bit. The write happens only if @tv@ still holds exactly
-- @oldToken@.
persistentUnready ::
     Token
  -> TVar Token
  -> IO ()
persistentUnready !oldToken !tv = atomically $ do
  current <- STM.readTVar tv
  when (eqToken oldToken current) $
    STM.writeTVar tv $! persistentUnreadyToken oldToken
-- | 'unready' followed by 'wait': clear readiness if the token is unchanged
-- since @oldToken@, then block until @tv@ is ready again and return the
-- new token.
unreadyAndWait ::
     Token
  -> TVar Token
  -> IO Token
unreadyAndWait !oldToken !tv = unready oldToken tv >> wait tv
-- | 'persistentUnready' followed by 'wait'.
persistentUnreadyAndWait ::
     Token
  -> TVar Token
  -> IO Token
persistentUnreadyAndWait !oldToken !tv = persistentUnready oldToken tv >> wait tv
-- | Return the token in @tv@ once it is ready.
--
-- The fast path is a plain 'STM.readTVarIO'. If that token is not ready,
-- the function blocks in an STM transaction, retrying via 'STM.check'
-- until 'isTokenReady' holds.
wait :: TVar Token -> IO Token
wait !tv = do
  !initial@(Token bits) <- STM.readTVarIO tv
  debug ("wait: initial token value " ++ lpad 64 (showIntAtBase 2 binChar bits ""))
  if isTokenReady initial
    then pure initial
    else atomically awaitReady
  where
    awaitReady = do
      current <- STM.readTVar tv
      STM.check (isTokenReady current)
      pure current
-- | Like 'wait', but gives up with 'interruptToken' once @interrupt@ is
-- 'True'.
--
-- Both variables are first read outside a transaction. If @interrupt@ is
-- set, 'interruptToken' is returned; if the token is already ready, it is
-- returned. Otherwise the function blocks in STM until the interrupt flag
-- is set or the token becomes ready. The interrupt branch is tried first.
interruptibleWait ::
     TVar Bool
  -> TVar Token
  -> IO Token
interruptibleWait !interrupt !tv = do
  interrupted <- STM.readTVarIO interrupt
  if interrupted
    then pure interruptToken
    else do
      token0 <- STM.readTVarIO tv
      if isTokenReady token0
        then pure token0
        else atomically (whenInterrupted <|> whenReady)
  where
    whenInterrupted = do
      STM.readTVar interrupt >>= STM.check
      pure interruptToken
    whenReady = do
      token1 <- STM.readTVar tv
      STM.check (isTokenReady token1)
      pure token1
-- | Block in a single STM transaction until @interrupt@ is 'True' (returns
-- 'interruptToken') or the token in @tv@ is ready (returns that token).
-- The interrupt branch is tried first. When a ready token is returned,
-- @counter@ is incremented in the same transaction.
interruptibleWaitCounting :: TVar Int -> TVar Bool -> TVar Token -> IO Token
interruptibleWaitCounting !counter !interrupt !tv =
  atomically (whenInterrupted <|> whenReadyCounted)
  where
    whenInterrupted = do
      STM.readTVar interrupt >>= STM.check
      pure interruptToken
    whenReadyCounted = do
      token1 <- STM.readTVar tv
      STM.check (isTokenReady token1)
      STM.modifyTVar' counter (+1)
      pure token1
-- | Allocate a pinned 'MutablePrimArray' with room for @n@ elements of type
-- @a@. Pinned memory is not moved by the garbage collector. The contents
-- are uninitialised.
newPinnedPrimArray :: forall a. Prim a
  => Int -> IO (MutablePrimArray RealWorld a)
{-# INLINE newPinnedPrimArray #-}
newPinnedPrimArray (I# count#) = PM.primitive $ \s0 ->
  case Exts.newPinnedByteArray# (count# *# PM.sizeOf# (undefined :: a)) s0 of
    (# s1, bytes# #) -> (# s1, MutablePrimArray bytes# #)
-- | Read the 'TVar' stored at index @ix@ of an unlifted array of 'TVar's.
-- The array primitive used performs no bounds check.
readTVarArray :: forall a.
     MutableUnliftedArray RealWorld (TVar a)
  -> Int
  -> IO (TVar a)
readTVarArray (MutableUnliftedArray slots#) (I# ix#) =
  PM.primitive (\s0 -> case Exts.readArrayArrayArray# slots# ix# s0 of
    (# s1, raw# #) -> (# s1, TVar ((unsafeCoerce# :: ArrayArray# -> TVar# RealWorld a) raw#) #))
-- | Read the inner array stored at index @ix@ of an array of arrays (a tier 1
-- slot). The array primitive used performs no bounds check.
readMutableUnliftedArrayArray
  :: MutableUnliftedArray RealWorld (MutableUnliftedArray RealWorld a)
  -> Int
  -> IO (MutableUnliftedArray RealWorld a)
readMutableUnliftedArrayArray (MutableUnliftedArray outer#) (I# ix#) =
  PM.primitive (\s0 -> case Exts.readArrayArrayArray# outer# ix# s0 of
    (# s1, inner# #) -> (# s1, MutableUnliftedArray ((unsafeCoerce# :: ArrayArray# -> MutableArrayArray# RealWorld) inner#) #))
-- | Store a 'TVar' at index @ix@ of an unlifted array of 'TVar's. The
-- array primitive used performs no bounds check.
writeTVarArray :: forall a.
     MutableUnliftedArray RealWorld (TVar a)
  -> Int
  -> TVar a
  -> IO ()
writeTVarArray (PM.MutableUnliftedArray slots#) (I# ix#) (TVar tv#) =
  PM.primitive_ $
    Exts.writeArrayArrayArray# slots# ix# ((unsafeCoerce# :: TVar# RealWorld a -> ArrayArray#) tv#)
-- | Store an inner array at index @ix@ of an array of arrays (a tier 1
-- slot). The array primitive used performs no bounds check.
writeMutableUnliftedArrayArray :: forall a.
     MutableUnliftedArray RealWorld (MutableUnliftedArray RealWorld a)
  -> Int
  -> MutableUnliftedArray RealWorld a
  -> IO ()
writeMutableUnliftedArrayArray (PM.MutableUnliftedArray outer#) (I# ix#) (MutableUnliftedArray inner#) =
  PM.primitive_ $
    Exts.writeArrayArrayArray# outer# ix# ((unsafeCoerce# :: MutableArrayArray# RealWorld -> ArrayArray#) inner#)
-- | Atomically compare-and-swap slot @i@ of an array of arrays. If the slot
-- currently holds exactly @old@ (pointer identity), it is replaced by @new@.
--
-- On success, returns @True@ and @new@. On failure, returns @False@ and the
-- array actually present in the slot.
--
-- Implemented by coercing the array to a 'MutableArray#' of 'Any' and
-- calling 'casArray#'. This relies on the two array kinds sharing a heap
-- representation.
--
-- NOTE(review): the bang patterns on @uold@/@unew@ evaluate unlifted array
-- pointers that have been coerced to the lifted type 'Any'. This seems to
-- rely on the optimiser dropping those evaluations; confirm it is safe on
-- the GHC versions in use.
casMutableUnliftedArrayArray ::
     MutableUnliftedArray RealWorld (MutableUnliftedArray RealWorld a)
  -> Int
  -> (MutableUnliftedArray RealWorld a)
  -> (MutableUnliftedArray RealWorld a)
  -> IO (Bool,MutableUnliftedArray RealWorld a)
{-# INLINE casMutableUnliftedArrayArray #-}
casMutableUnliftedArrayArray (MutableUnliftedArray arr#) (I# i#) (MutableUnliftedArray old) (MutableUnliftedArray new) =
  IO $ \s0 ->
    let !uold = (unsafeCoerce# :: MutableArrayArray# RealWorld -> Any) old
        !unew = (unsafeCoerce# :: MutableArrayArray# RealWorld -> Any) new
     in case casArray# ((unsafeCoerce# :: MutableArrayArray# RealWorld -> MutableArray# RealWorld Any) arr#) i# uold unew s0 of
          -- casArray# returns 0# on success, plus the element now in the slot.
          (# s1, n, ur #) -> (# s1, (isTrue# (n ==# 0# ),MutableUnliftedArray ((unsafeCoerce# :: Any -> MutableArrayArray# RealWorld) ur)) #)