{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module Control.Concurrent.GoChan
(
Chan
,Result(..)
,Case(..)
,chanMake
,chanClose
,chanRecv
,chanTryRecv
,chanSend
,chanTrySend
,chanSelect)
where
import Control.Concurrent.MVar
import Control.Monad
import Control.Monad.Primitive (RealWorld)
import Data.Array.IO
import Data.IORef
import Data.List (intercalate)
import Data.Maybe (fromJust, isJust, isNothing)
import qualified Data.Vector as V
import qualified Data.Vector.Algorithms.Heap as VAH
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Mutable as VGM
import qualified Data.Vector.Mutable as VM
import Data.Word
import GHC.Prim (Any)
import System.IO.Unsafe
import System.Random
import Unsafe.Coerce
-- | A Go-style channel carrying values of type @a@.
--
-- Buffered values live in a ring buffer of capacity '_qsize' indexed by
-- '_sendx' / '_recvx'. State is mutated while '_lock' is held; the
-- non-blocking fast paths read some fields without it.
data Chan a = Chan
  { _qcount :: {-# UNPACK #-} !(IORef Int)
    -- ^ number of values currently stored in the buffer
  , _qsize :: {-# UNPACK #-} !Int
    -- ^ buffer capacity; 0 means unbuffered
  , _buf :: {-# UNPACK #-} !(IOArray Int a)
    -- ^ ring buffer storage, indices @0 .. _qsize - 1@
  , _sendx :: {-# UNPACK #-} !(IORef Int)
    -- ^ index of the next buffer slot to write
  , _recvx :: {-# UNPACK #-} !(IORef Int)
    -- ^ index of the next buffer slot to read
  , _sendq :: {-# UNPACK #-} !SuspQ
    -- ^ suspended senders
  , _recvq :: {-# UNPACK #-} !SuspQ
    -- ^ suspended receivers
  , _lock :: {-# UNPACK #-} !(MVar ())
    -- ^ mutex: full when free, taken with 'takeMVar'
  , _closed :: {-# UNPACK #-} !(IORef Bool)
    -- ^ set by 'chanClose'
  , _id :: {-# UNPACK #-} !Word64
    -- ^ unique id; 'chanSelect' sorts cases by it to order lock acquisition
  }
-- | Doubly-linked FIFO of suspended threads, linked through the '_next' /
-- '_prev' fields of each 'Suspend'.
data SuspQ = SuspQ
  { _first :: {-# UNPACK #-} !(IORef (Maybe SomeSuspend))
    -- ^ head of the queue ('Nothing' when empty)
  , _last :: {-# UNPACK #-} !(IORef (Maybe SomeSuspend))
    -- ^ tail of the queue ('Nothing' when empty)
  }
-- | A 'Suspend' with its element type hidden, so suspensions on channels of
-- different element types can share one queue node type.
data SomeSuspend =
  forall a. SomeSuspend (Suspend a)
-- | A thread blocked on a channel operation, queued in a 'SuspQ'.
--
-- NOTE: this constructor is applied positionally elsewhere in the module,
-- so the field order must not change.
data Suspend a = forall b. Suspend
  { _selectDone :: !(Maybe (IORef Bool))
    -- ^ flag shared by all suspensions of one 'chanSelect'; whoever flips it
    -- to 'True' in 'dequeue' wins that select ('Nothing' outside a select)
  , _case :: !(Maybe (Case b))
    -- ^ the select arm this suspension belongs to ('Nothing' outside a select)
  , _next :: !(IORef (Maybe SomeSuspend))
    -- ^ next node in the queue
  , _prev :: !(IORef (Maybe SomeSuspend))
    -- ^ previous node in the queue
  , _elem :: !(Maybe (IORef a))
    -- ^ slot holding the value being sent, or receiving the value
  , _chan :: !(Chan a)
    -- ^ channel this suspension is queued on
  , _park :: !(MVar (Maybe (Suspend a)))
    -- ^ the blocked thread waits here; filled with @Just@ by a completed
    -- handoff and with 'Nothing' by 'chanClose'
  , _sid :: !Word64
    -- ^ unique id, used for identity comparison in 'eqSuspend'
  }
-- | Outcome of a receive.
data Result a
  = Msg a
    -- ^ a value taken from the channel
  | Closed
    -- ^ the channel is closed and holds no more values
-- | One arm of a 'chanSelect'; @a@ is the result type of the whole select.
data Case a
  = -- | Receive from the channel and pass the 'Result' to the handler.
    forall b. Recv (Chan b)
                   (Result b -> IO a)
  | -- | Send the value on the channel, then run the action.
    forall b. Send (Chan b)
                   b
                   (IO a)
{-# INLINE caseChanId #-}
-- | The '_id' of the channel a select arm operates on.
caseChanId :: Case a -> Word64
caseChanId cas =
  case cas of
    Recv chan _ -> _id chan
    Send chan _ _ -> _id chan
{-# NOINLINE currIdRef #-}
-- | Global counter from which 'chanMake' draws channel ids.
-- Created with 'unsafePerformIO'; the NOINLINE pragma keeps it a single
-- shared 'IORef' rather than one per use site.
currIdRef :: IORef Word64
currIdRef = unsafePerformIO (newIORef 0)
{-# NOINLINE currSIdRef #-}
-- | Global counter from which each new 'Suspend' draws its '_sid'.
-- Created with 'unsafePerformIO'; the NOINLINE pragma keeps it a single
-- shared 'IORef' rather than one per use site.
currSIdRef :: IORef Word64
currSIdRef = unsafePerformIO (newIORef 0)
-- | Shuffle a mutable vector in place (Fisher-Yates): for each index @i@
-- from 1 to the last, swap element @i@ with the element at an index drawn
-- by 'randomRIO' from @0..i@.
shuffleVector
  :: (VGM.MVector v e)
  => v RealWorld e -> IO ()
shuffleVector !xs =
  forM_ [1 .. VGM.length xs - 1] $ \i ->
    randomRIO (0, i) >>= VGM.swap xs i
-- | Go-style @select@: run the continuation of one case of @cases@ that can
-- proceed.
--
-- Cases are polled in a random order; the first one that can complete
-- without blocking wins. If none can and @mdefault@ is @Just@, that action
-- runs instead. Otherwise the thread queues a suspension on every case's
-- channel and sleeps. When another thread completes one of them, it removes
-- the remaining suspensions and runs the winning continuation. If a
-- channel close wakes it instead, it re-polls.
--
-- A 'Send' case on a closed channel fails with @"send on closed channel"@;
-- a 'Recv' case on a closed, drained channel receives 'Closed'.
chanSelect
  :: [Case a]
  -> Maybe (IO a)
  -> IO a
chanSelect cases mdefault = do
  -- Random poll order so no case is systematically preferred.
  !pollOrder <-
    do vec <- V.thaw (V.fromList cases)
       shuffleVector vec
       V.unsafeFreeze vec
  let !ncases = VG.length pollOrder
  -- Cases sorted by channel id: locks are taken in this order (and released
  -- in reverse) so every select acquires channel locks in one global order.
  !lockOrder <-
    do vec <- V.thaw (V.fromList cases)
       VAH.sortBy
         (\cas1 cas2 ->
            caseChanId cas1 `compare` caseChanId cas2)
         vec
       V.unsafeFreeze vec
  selLock lockOrder
  -- pass1 (all locks held): find a case that can complete immediately.
  let pass1 !n = do
        if n /= ncases
          then case pollOrder VG.! n of
            Recv chan act -> do
              -- A sender is waiting: take its value directly.
              ms <- dequeue (_sendq chan)
              case ms of
                Just (SomeSuspend s) -> do
                  elemRef <- newIORef undefined
                  recv chan (unsafeCoerceSuspend s) (Just elemRef) (selUnlock lockOrder)
                  val <- readIORef elemRef
                  act (Msg val)
                _ -> do
                  !qcount <- readIORef (_qcount chan)
                  if qcount > 0
                    then do
                      -- Take the oldest buffered value.
                      !recvx <- readIORef (_recvx chan)
                      val <- readArray (_buf chan) recvx
                      let !recvx' =
                            let !x = recvx + 1
                            in if x == _qsize chan
                                 then 0
                                 else x
                      writeIORef (_recvx chan) $! recvx'
                      writeIORef (_qcount chan) (qcount - 1)
                      selUnlock lockOrder
                      act (Msg val)
                    else do
                      -- Empty: closed means 'Closed', otherwise try the next case.
                      !isClosed <- readIORef (_closed chan)
                      if isClosed
                        then do
                          selUnlock lockOrder
                          act Closed
                        else do
                          pass1 (n + 1)
            Send chan val act -> do
              !isClosed <- readIORef (_closed chan)
              if isClosed
                then do
                  selUnlock lockOrder
                  fail "send on closed channel"
                else do
                  -- A receiver is waiting: hand the value over directly.
                  ms <- dequeue (_recvq chan)
                  case ms of
                    Just (SomeSuspend s) -> do
                      send chan (unsafeCoerceSuspend s) val (selUnlock lockOrder)
                      act
                    _ -> do
                      !qcount <- readIORef (_qcount chan)
                      if qcount < _qsize chan
                        then do
                          -- Room in the buffer: store the value there.
                          !sendx <- readIORef (_sendx chan)
                          writeArray (_buf chan) sendx val
                          let !sendx' =
                                let !x = sendx + 1
                                in if x == _qsize chan
                                     then 0
                                     else x
                          writeIORef (_sendx chan) sendx'
                          writeIORef (_qcount chan) (qcount + 1)
                          selUnlock lockOrder
                          act
                        else do
                          pass1 (n + 1)
          else case mdefault of
            -- Nothing was ready: run the default branch if there is one.
            Just def -> do
              selUnlock lockOrder
              def
            _ -> do
              -- pass2: queue one suspension per case. They all share 'park'
              -- (where this thread sleeps) and 'selectDone' (claimed in
              -- 'dequeue' by whichever thread wins this select).
              park <- newEmptyMVar
              selectDone <- newIORef False
              ss <-
                V.generateM
                  ncases
                  (\n -> do
                     next <- newIORef Nothing
                     prev <- newIORef Nothing
                     id <-
                       atomicModifyIORef'
                         currSIdRef
                         (\currId ->
                            (currId + 1, currId))
                     case lockOrder V.! n of
                       cas@(Send chan val _) -> do
                         elemRef <- newIORef (unsafeCoerce val)
                         let !s =
                               SomeSuspend
                                 (Suspend
                                    (Just selectDone)
                                    (Just (unsafeCoerceCase cas))
                                    next
                                    prev
                                    (Just elemRef)
                                    (unsafeCoerceChan chan)
                                    park
                                    id)
                         enqueue (_sendq chan) s
                         return s
                       cas@(Recv chan _) -> do
                         elemRef <- newIORef undefined
                         let !s =
                               SomeSuspend
                                 (Suspend
                                    (Just selectDone)
                                    (Just (unsafeCoerceCase cas))
                                    next
                                    prev
                                    (Just elemRef)
                                    (unsafeCoerceChan chan)
                                    park
                                    id)
                         enqueue (_recvq chan) s
                         return s)
              selUnlock lockOrder
              -- Sleep until a peer completes one case (Just) or a
              -- 'chanClose' wakes us (Nothing).
              ms <- takeMVar park
              selLock lockOrder
              -- pass3: unlink every suspension of this select from its queue.
              let pass3 !n = do
                    case ss VG.! n of
                      someS@(SomeSuspend s) ->
                        case s of
                          (Suspend _ cas _ _ _ _ _ _) ->
                            case cas of
                              Just (Send _ _ _) -> do
                                dequeueSuspend
                                  (_sendq (_chan s))
                                  someS
                              Just (Recv _ _) -> do
                                dequeueSuspend (_recvq (_chan s)) someS
                    when ((n + 1) /= ncases) (pass3 (n + 1))
              pass3 0
              case ms of
                Just s -> do
                  -- 's' is the suspension the peer completed; run its arm.
                  case s of
                    (Suspend _ cas _ _ _ _ _ _) ->
                      case cas of
                        Just (Send chan _ act) -> do
                          selUnlock lockOrder
                          unsafeCoerceSendAction act
                        Just (Recv chan act) -> do
                          !val <- readIORef (fromJust (_elem s))
                          selUnlock lockOrder
                          unsafeCoerceRecvAction act (Msg val)
                _ -> do
                  -- Woken by a close: poll again (locks are still held).
                  pass1
                    0
  pass1 0
{-# INLINE unsafeCoerceSendAction #-}
-- | Reinterpret a 'Send' continuation at another result type. Only sound
-- when @b@ is the type the action was really built with.
unsafeCoerceSendAction :: IO a -> IO b
unsafeCoerceSendAction action = unsafeCoerce action
{-# INLINE unsafeCoerceRecvAction #-}
-- | Reinterpret a 'Recv' handler at other element/result types. Only sound
-- when the new types are the ones the handler was really built with.
unsafeCoerceRecvAction :: (Result b -> IO a) -> (Result d -> IO c)
unsafeCoerceRecvAction handler = unsafeCoerce handler
{-# INLINE unsafeCoerceSuspend #-}
-- | Recover a 'Suspend' at the element type of the channel it was queued on;
-- the type was hidden by 'SomeSuspend'.
unsafeCoerceSuspend :: Suspend a -> Suspend b
unsafeCoerceSuspend suspension = unsafeCoerce suspension
{-# INLINE unsafeCoerceChan #-}
-- | Change the phantom element type of a channel so select suspensions on
-- differently-typed channels can share one 'MVar'.
unsafeCoerceChan :: Chan a -> Chan b
unsafeCoerceChan channel = unsafeCoerce channel
{-# INLINE unsafeCoerceCase #-}
-- | Change the result type of a select arm so it can be stored in a
-- 'Suspend'.
unsafeCoerceCase :: Case a -> Case b
unsafeCoerceCase arm = unsafeCoerce arm
{-# INLINE lockCase #-}
-- | Take the lock of the channel a select arm refers to.
lockCase :: Case a -> IO ()
lockCase = \case
  Recv chan _ -> takeMVar (_lock chan)
  Send chan _ _ -> takeMVar (_lock chan)
{-# INLINE unlockCase #-}
-- | Release the lock of the channel a select arm refers to.
unlockCase :: Case a -> IO ()
unlockCase = \case
  Recv chan _ -> putMVar (_lock chan) ()
  Send chan _ _ -> putMVar (_lock chan) ()
-- | Acquire the lock of every channel referenced by @vec@, in vector order.
--
-- @vec@ is expected to be sorted by 'caseChanId' (as 'chanSelect' does), so
-- arms on the same channel are adjacent. Each distinct channel is locked
-- exactly once, including when the final arms share a channel; an empty
-- vector locks nothing.
selLock
  :: (VG.Vector v e, e ~ Case a)
  => v (Case a) -> IO ()
selLock !vec = go 0 Nothing
  where
    len = VG.length vec
    -- @prevId@ is the channel id of the previous arm, if any. Using 'Maybe'
    -- rather than a sentinel id means no real id can be mistaken for "none".
    go !n prevId
      | n >= len = return ()
      | otherwise = do
          let !cas = vec VG.! n
              !cid = caseChanId cas
          -- Skip repeats: taking the same MVar twice would deadlock.
          when (Just cid /= prevId) (lockCase cas)
          go (n + 1) (Just cid)
-- | Release the locks taken by 'selLock' on the same vector, walking it in
-- reverse order.
--
-- Each distinct channel is unlocked exactly once, including when the first
-- arms share a channel (a second 'putMVar' would block forever); an empty
-- vector unlocks nothing.
selUnlock
  :: (VG.Vector v e, e ~ Case a)
  => v (Case a) -> IO ()
selUnlock !vec = go (VG.length vec - 1) Nothing
  where
    -- @laterId@ is the channel id of the arm just after @n@, if any.
    go !n laterId
      | n < 0 = return ()
      | otherwise = do
          let !cas = vec VG.! n
              !cid = caseChanId cas
          when (Just cid /= laterId) (unlockCase cas)
          go (n - 1) (Just cid)
-- | Create a channel whose buffer holds @size@ values. With @size == 0@ the
-- channel is unbuffered: a sender waits until a receiver takes its value.
--
-- Each channel gets a fresh id from 'currIdRef'; 'chanSelect' sorts arms by
-- that id to fix the order in which channel locks are taken.
--
-- Fails with @"makechan: size out of range"@ if @size@ is negative.
chanMake
  :: Int -> IO (Chan a)
chanMake !size = do
  -- A negative capacity would otherwise produce an empty backing array
  -- paired with a negative '_qsize'. The buffered code paths can then index
  -- that array out of bounds, so reject it up front.
  when (size < 0) $ fail "makechan: size out of range"
  ary <- newArray_ (0, size - 1)
  qcount <- newIORef 0
  sendx <- newIORef 0
  recvx <- newIORef 0
  sendq_first <- newIORef Nothing
  sendq_last <- newIORef Nothing
  recvq_first <- newIORef Nothing
  recvq_last <- newIORef Nothing
  lock <- newMVar ()
  closed <- newIORef False
  chanId <- atomicModifyIORef' currIdRef (\currId -> (currId + 1, currId))
  return
    Chan
      { _qcount = qcount
      , _qsize = size
      , _buf = ary
      , _sendx = sendx
      , _recvx = recvx
      , _sendq = SuspQ sendq_first sendq_last
      , _recvq = SuspQ recvq_first recvq_last
      , _lock = lock
      , _closed = closed
      , _id = chanId
      }
-- | Send a value, blocking until it is buffered or taken by a receiver.
-- Fails with @"send on closed channel"@ if the channel is closed.
chanSend
  :: Chan a -> a -> IO ()
chanSend chan val =
  chan `seq` val `seq` (chanSendInternal chan val True >> return ())
-- | Send without blocking; 'True' if the value was buffered or handed to a
-- receiver, 'False' if that was not possible right now.
chanTrySend
  :: Chan a -> a -> IO Bool
chanTrySend chan val = chan `seq` val `seq` chanSendInternal chan val False
-- | Shared implementation of 'chanSend' (@block = True@) and 'chanTrySend'
-- (@block = False@).
--
-- The value goes to a waiting receiver if there is one, otherwise into the
-- buffer if it has room. Failing both, a non-blocking call returns 'False',
-- while a blocking call suspends on the send queue until a receiver takes
-- the value. Returns 'True' on success.
--
-- Fails with @"send on closed channel"@ if the channel is closed, or is
-- closed while the sender is suspended.
chanSendInternal :: Chan a -> a -> Bool -> IO Bool
chanSendInternal !chan !val !block = do
  -- Lock-free fast path for non-blocking sends: give up when the channel
  -- looks full. An unbuffered channel is full when NO receiver is waiting
  -- (the old code tested 'isJust', rejecting exactly the sends that could
  -- succeed). A buffered channel is full when every slot is occupied. These
  -- reads happen without the lock, so the answer may already be stale.
  !isClosed <- readIORef (_closed chan)
  !recvq_first <- readIORef (_first (_recvq chan))
  !qcount <- readIORef (_qcount chan)
  let looksFull
        | _qsize chan == 0 = isNothing recvq_first
        | otherwise = qcount == _qsize chan
  if not block && not isClosed && looksFull
    then return False
    else do
      takeMVar (_lock chan)
      !closedNow <- readIORef (_closed chan)
      if closedNow
        then do
          putMVar (_lock chan) ()
          fail "send on closed channel"
        else do
          ms <- dequeue (_recvq chan)
          case ms of
            Just (SomeSuspend s) -> do
              -- A receiver is already waiting: hand the value over directly.
              send chan (unsafeCoerceSuspend s) val (putMVar (_lock chan) ())
              return True
            Nothing -> do
              !buffered <- readIORef (_qcount chan)
              if buffered < _qsize chan
                then do
                  !sendx <- readIORef (_sendx chan)
                  writeArray (_buf chan) sendx val
                  -- Advance the write index, wrapping around the ring buffer.
                  let !sendx' =
                        let !x = sendx + 1
                        in if x == _qsize chan then 0 else x
                  writeIORef (_sendx chan) sendx'
                  writeIORef (_qcount chan) $! buffered + 1
                  putMVar (_lock chan) ()
                  return True
                else if not block
                  then do
                    putMVar (_lock chan) ()
                    return False
                  else do
                    -- Queue ourselves on the send queue and sleep until a
                    -- receiver takes the value (Just) or the channel is
                    -- closed (Nothing).
                    next <- newIORef Nothing
                    prev <- newIORef Nothing
                    elemRef <- newIORef val
                    park <- newEmptyMVar
                    sid <-
                      atomicModifyIORef'
                        currSIdRef
                        (\currId -> (currId + 1, currId))
                    let !s =
                          SomeSuspend
                            (Suspend Nothing Nothing next prev (Just elemRef) chan park sid)
                    enqueue (_sendq chan) s
                    putMVar (_lock chan) ()
                    woken <- takeMVar park
                    case woken of
                      Just _ -> return True
                      Nothing -> do
                        !closedLater <- readIORef (_closed chan)
                        unless closedLater (fail "chansend: spurious wakeup")
                        fail "send on closed channel"
-- | Complete a handoff to the suspended receiver @s@. This stores @val@ in
-- its element slot (if it has one), runs @unlock@, then wakes the receiver
-- with @Just s@.
send :: Chan a -> Suspend a -> a -> IO () -> IO ()
send !chan !s !val !unlock = do
  forM_ (_elem s) (\slot -> writeIORef slot val)
  unlock
  putMVar (_park s) (Just s)
-- | Close a channel.
--
-- Fails with @"close of closed channel"@ if it is already closed. Otherwise
-- it marks the channel closed and drains both wait queues under the lock.
-- After releasing the lock it wakes every drained thread with 'Nothing'.
-- Plain receivers then report 'Closed', blocked senders fail with
-- @"send on closed channel"@, and suspended selects re-poll their arms.
chanClose
  :: Chan a -> IO ()
chanClose !chan = do
  takeMVar (_lock chan)
  !alreadyClosed <- readIORef (_closed chan)
  when alreadyClosed $ do
    putMVar (_lock chan) ()
    fail "close of closed channel"
  writeIORef (_closed chan) True
  waiters <- drain (_recvq chan) [] >>= drain (_sendq chan)
  putMVar (_lock chan) ()
  mapM_ wake waiters
  where
    -- Pop every waiter off @q@, pushing each onto the front of @acc@.
    drain q acc =
      dequeue q >>= \case
        Nothing -> return acc
        Just w -> drain q (w : acc)
    wake (SomeSuspend s) = putMVar (_park s) Nothing
-- | Internal outcome of 'chanRecvInternal'.
data RecvResult
  = RecvWouldBlock
    -- ^ a non-blocking receive found nothing to take
  | RecvGotMessage
    -- ^ a value was received (and written to the caller's slot, if given)
  | RecvClosed
    -- ^ the channel is closed and empty
-- | Receive without blocking: 'Nothing' if no value is available right
-- now, otherwise the 'Result' of the receive.
chanTryRecv
  :: Chan a -> IO (Maybe (Result a))
chanTryRecv !chan = do
  slot <- newIORef undefined
  outcome <- chanRecvInternal chan (Just slot) False
  case outcome of
    RecvGotMessage -> Just . Msg <$> readIORef slot
    RecvClosed -> return (Just Closed)
    RecvWouldBlock -> return Nothing
-- | Receive, blocking until a value arrives or the channel is closed and
-- drained.
chanRecv
  :: Chan a -> IO (Result a)
chanRecv !chan = do
  slot <- newIORef undefined
  outcome <- chanRecvInternal chan (Just slot) True
  case outcome of
    RecvGotMessage -> Msg <$> readIORef slot
    RecvClosed -> return Closed
    RecvWouldBlock -> fail "the impossible happened"
-- | Shared implementation of 'chanRecv' (@block = True@) and 'chanTryRecv'
-- (@block = False@).
--
-- The value comes from a waiting sender if there is one, otherwise from the
-- buffer. If neither has a value, a non-blocking call returns
-- 'RecvWouldBlock', while a blocking call suspends on the receive queue. A
-- received value is written to @melemRef@ when it is @Just@. A closed
-- channel with an empty buffer yields 'RecvClosed'.
chanRecvInternal :: Chan a -> Maybe (IORef a) -> Bool -> IO RecvResult
chanRecvInternal !chan !melemRef !block = do
  -- Lock-free fast path for non-blocking receives: give up when the channel
  -- looks empty. An unbuffered channel is empty when no sender is waiting.
  -- A buffered one is empty when it holds no values; the old code tested
  -- "buffer full" here, so 'chanTryRecv' on a full channel wrongly reported
  -- nothing available. These reads happen without the lock.
  !sendq_first <- readIORef (_first (_sendq chan))
  !qcount <- atomicReadIORef (_qcount chan)
  !isClosed <- atomicReadIORef (_closed chan)
  let looksEmpty
        | _qsize chan == 0 = isNothing sendq_first
        | otherwise = qcount == 0
  if not block && looksEmpty && not isClosed
    then return RecvWouldBlock
    else do
      takeMVar (_lock chan)
      !closedNow <- readIORef (_closed chan)
      !buffered <- readIORef (_qcount chan)
      if closedNow && buffered == 0
        then do
          putMVar (_lock chan) ()
          return RecvClosed
        else do
          ms <- dequeue (_sendq chan)
          case ms of
            Just (SomeSuspend s) -> do
              -- A sender is waiting: take its value (see 'recv').
              recv chan (unsafeCoerceSuspend s) melemRef (putMVar (_lock chan) ())
              return RecvGotMessage
            Nothing
              | buffered > 0 -> do
                  !recvx <- readIORef (_recvx chan)
                  val <- readArray (_buf chan) recvx
                  forM_ melemRef (\elemRef -> writeIORef elemRef val)
                  -- Advance the read index, wrapping around the ring buffer.
                  let !recvx' =
                        let !x = recvx + 1
                        in if x == _qsize chan then 0 else x
                  writeIORef (_recvx chan) recvx'
                  writeIORef (_qcount chan) $! buffered - 1
                  putMVar (_lock chan) ()
                  return RecvGotMessage
              | not block -> do
                  putMVar (_lock chan) ()
                  return RecvWouldBlock
              | otherwise -> do
                  -- Queue ourselves on the receive queue and sleep until a
                  -- sender delivers (Just) or the channel is closed (Nothing).
                  next <- newIORef Nothing
                  prev <- newIORef Nothing
                  park <- newEmptyMVar
                  sid <-
                    atomicModifyIORef'
                      currSIdRef
                      (\currId -> (currId + 1, currId))
                  let !s =
                        SomeSuspend
                          (Suspend Nothing Nothing next prev melemRef chan park sid)
                  enqueue (_recvq chan) s
                  putMVar (_lock chan) ()
                  woken <- takeMVar park
                  return (if isJust woken then RecvGotMessage else RecvClosed)
-- | Complete a handoff from the suspended sender @s@ on @chan@.
--
-- Unbuffered channel: the sender's value is copied into @melemRef@ (if
-- given). Buffered channel (a sender is only queued once the buffer is
-- full): the oldest buffered value goes to @melemRef@, the sender's value
-- takes its slot, and both ring indices move to the next slot. Finally
-- @unlock@ runs and the sender is woken with @Just s@.
recv :: Chan a -> Suspend a -> Maybe (IORef a) -> IO () -> IO ()
recv !chan !s !melemRef !unlock = do
  if _qsize chan /= 0
    then do
      let buf = _buf chan
      !slot <- readIORef (_recvx chan)
      oldest <- readArray buf slot
      forM_ melemRef (\dst -> writeIORef dst oldest)
      !incoming <- readIORef (fromJust (_elem s))
      writeArray buf slot incoming
      let !slot' = if slot + 1 == _qsize chan then 0 else slot + 1
      writeIORef (_recvx chan) $! slot'
      writeIORef (_sendx chan) $! slot'
    else forM_ melemRef $ \dst -> do
      !v <- readIORef (fromJust (_elem s))
      writeIORef dst v
  unlock
  putMVar (_park s) (Just s)
-- | Append a suspension to the tail of a wait queue.
enqueue
  :: SuspQ -> SomeSuspend -> IO ()
enqueue !q node@(SomeSuspend s) = do
  writeIORef (_next s) Nothing
  readIORef (_last q) >>= \case
    Nothing -> do
      -- Empty queue: the node becomes both head and tail.
      writeIORef (_prev s) Nothing
      writeIORef (_first q) (Just node)
    Just tailNode@(SomeSuspend t) -> do
      writeIORef (_prev s) (Just tailNode)
      writeIORef (_next t) (Just node)
  writeIORef (_last q) (Just node)
-- | Pop the first live waiter from @q@, or return 'Nothing' once it is empty.
--
-- A waiter that belongs to a 'chanSelect' is returned only if this call is
-- the one that flips its shared @selectDone@ flag from 'False' to 'True'.
-- Waiters whose select was already won elsewhere are unlinked and skipped.
dequeue
  :: SuspQ -> IO (Maybe SomeSuspend)
dequeue !q = do
  !front <- readIORef (_first q)
  case front of
    Nothing -> return Nothing
    Just node@(SomeSuspend s) -> do
      unlinkHead s
      claim node (_selectDone s)
  where
    -- Remove the current head @s@ from the queue.
    unlinkHead :: Suspend b -> IO ()
    unlinkHead s = do
      !after <- readIORef (_next s)
      case after of
        Nothing -> do
          writeIORef (_first q) Nothing
          writeIORef (_last q) Nothing
        Just successor@(SomeSuspend y) -> do
          writeIORef (_prev y) Nothing
          writeIORef (_first q) (Just successor)
      writeIORef (_next s) Nothing
    -- Decide whether the unlinked @node@ may be handed out.
    claim :: SomeSuspend -> Maybe (IORef Bool) -> IO (Maybe SomeSuspend)
    claim node Nothing = return (Just node)
    claim node (Just doneRef) = do
      alreadyDone <- readIORef doneRef
      if alreadyDone
        then dequeue q
        else do
          wasDone <- atomicModifyIORef' doneRef (\old -> (True, old))
          if wasDone then dequeue q else return (Just node)
{-# INLINE atomicReadIORef #-}
-- | Read an 'IORef' via 'atomicModifyIORef'' by writing the same value back.
-- Unlike a plain 'readIORef', this acts as a memory barrier.
atomicReadIORef :: IORef a -> IO a
atomicReadIORef ref = ref `seq` atomicModifyIORef' ref dup
  where
    dup v = (v, v)
-- | Identity comparison of two suspensions by their unique '_sid'.
eqSuspend
  :: Suspend a -> Suspend b -> Bool
eqSuspend !s1 !s2 = compare (_sid s1) (_sid s2) == EQ
-- | Unlink a specific suspension from @q@, wherever it sits.
--
-- A node with neither neighbour is either the only element (then the queue
-- is emptied) or already off the queue (then nothing happens).
dequeueSuspend :: SuspQ -> SomeSuspend -> IO ()
dequeueSuspend !q (SomeSuspend s) = do
  !before <- readIORef (_prev s)
  !after <- readIORef (_next s)
  case (before, after) of
    (Just p@(SomeSuspend x), Just n@(SomeSuspend y)) -> do
      -- Middle node: splice the neighbours together.
      writeIORef (_next x) (Just n)
      writeIORef (_prev y) (Just p)
      writeIORef (_next s) Nothing
      writeIORef (_prev s) Nothing
    (Just p@(SomeSuspend x), Nothing) -> do
      -- Tail node: the predecessor becomes the new tail.
      writeIORef (_next x) Nothing
      writeIORef (_last q) (Just p)
      writeIORef (_prev s) Nothing
    (Nothing, Just n@(SomeSuspend y)) -> do
      -- Head node: the successor becomes the new head.
      writeIORef (_prev y) Nothing
      writeIORef (_first q) (Just n)
      writeIORef (_next s) Nothing
    (Nothing, Nothing) -> do
      !front <- readIORef (_first q)
      forM_ front $ \(SomeSuspend f) ->
        when (f `eqSuspend` s) $ do
          writeIORef (_first q) Nothing
          writeIORef (_last q) Nothing
-- | Snapshot of a wait queue from head to tail (debugging aid).
waitqToList
  :: SuspQ -> IO [SomeSuspend]
waitqToList q = readIORef (_first q) >>= maybe (return []) sleeperChain
-- | The given node followed by every node reachable through '_next'.
sleeperChain :: SomeSuspend -> IO [SomeSuspend]
sleeperChain node@(SomeSuspend s) = do
  rest <- readIORef (_next s) >>= maybe (return []) sleeperChain
  return (node : rest)
-- | Debugging aid: print a wait queue as a chain of (suspension id,
-- channel id) pairs.
printSuspQ :: SuspQ -> IO ()
printSuspQ q = do
  nodes <- waitqToList q
  let describe (SomeSuspend s) =
        "(SID: " ++ show (_sid s) ++ ", CID: " ++ show (_id (_chan s)) ++ ")"
      !chain = intercalate "->" (map describe nodes)
  putStrLn ("WAITQ: " ++ chain)