{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP          #-}
{-# LANGUAGE RankNTypes   #-}

module Snap.Internal.Http.Server.Thread
  ( SnapThread
  , fork
  , forkOn
  , cancel
  , wait
  , cancelAndWait
  , isFinished
  ) where

#if !MIN_VERSION_base(4,8,0)
import           Control.Applicative         ((<$>))
#endif
import           Control.Concurrent          (MVar, ThreadId, killThread, newEmptyMVar, putMVar, readMVar)
#if MIN_VERSION_base(4,7,0)
import           Control.Concurrent          (tryReadMVar)
#else
import           Control.Concurrent          (tryTakeMVar)
import           Control.Monad               (when)
import           Data.Maybe                  (fromJust, isJust)
#endif
import           Control.Concurrent.Extended (forkIOLabeledWithUnmaskBs, forkOnLabeledWithUnmaskBs)
import qualified Control.Exception           as E
import           Control.Monad               (void)
import qualified Data.ByteString.Char8       as B
import           GHC.Exts                    (inline)

#if !MIN_VERSION_base(4,7,0)
-- | Compatibility shim for bases that predate 'tryReadMVar' (base < 4.7):
-- peek at the contents of an 'MVar' without blocking.
--
-- Returns 'Nothing' if the 'MVar' is currently empty, otherwise 'Just' the
-- value, which is put back before returning.
--
-- The take/put pair runs under 'E.mask_' so that an asynchronous exception
-- cannot arrive after the value has been taken but before it has been put
-- back; without the mask such an exception would leave the 'MVar' empty
-- forever and every later reader would block.
--
-- This is still not atomic with respect to other threads: a concurrent
-- reader can observe the 'MVar' as empty while the value is briefly out.
tryReadMVar :: MVar a -> IO (Maybe a)
tryReadMVar mv = E.mask_ $ do
    m <- tryTakeMVar mv
    -- 'isJust' / 'fromJust' are kept on purpose: their imports sit behind the
    -- same CPP guard as this definition, and 'fromJust' is only reached when
    -- 'isJust' has already succeeded.
    when (isJust m) $ putMVar mv (fromJust m)
    return m
#endif

------------------------------------------------------------------------------
-- | Handle on a thread started by 'fork' or 'forkOn'.
data SnapThread = SnapThread
    { -- | Identifier of the forked thread; this is what 'cancel' kills.
      _snapThreadId       :: {-# UNPACK #-} !ThreadId
      -- | Completion signal: starts out empty and is filled with @()@ by the
      --   wrapper around the thread's action once that action has returned
      --   or thrown. 'wait' and 'isFinished' read it.
    , _snapThreadFinished :: {-# UNPACK #-} !(MVar ())
    }

-- | A 'SnapThread' is rendered exactly like its underlying 'ThreadId'; the
-- completion 'MVar' is not shown.
instance Show SnapThread where
  show t = show (_snapThreadId t)


------------------------------------------------------------------------------
-- | Start a labeled thread on a specific capability and return a handle that
-- can be cancelled and waited on.
--
-- The user action is wrapped with 'wrapAction', so the handle's completion
-- 'MVar' is filled whether the action returns or throws.
--
-- Thread creation and construction of the handle both happen inside
-- 'E.uninterruptibleMask_'.
forkOn :: B.ByteString                          -- ^ thread label
       -> Int                                   -- ^ capability
       -> ((forall a . IO a -> IO a) -> IO ())  -- ^ user thread action, taking
                                                --   a restore function
       -> IO SnapThread
forkOn label cap action = do
    done <- newEmptyMVar
    E.uninterruptibleMask_ $
        forkOnLabeledWithUnmaskBs label cap (wrapAction done action) >>= \tid ->
            -- Force the handle to WHNF before handing it back.
            return $! SnapThread tid done


------------------------------------------------------------------------------
-- | Start a labeled thread and return a handle that can be cancelled and
-- waited on. Same as 'forkOn' except that no capability is specified.
--
-- The user action is wrapped with 'wrapAction', so the handle's completion
-- 'MVar' is filled whether the action returns or throws.
--
-- Thread creation and construction of the handle both happen inside
-- 'E.uninterruptibleMask_'.
fork :: B.ByteString                          -- ^ thread label
     -> ((forall a . IO a -> IO a) -> IO ())  -- ^ user thread action, taking
                                              --   a restore function
     -> IO SnapThread
fork label action = do
    done <- newEmptyMVar
    E.uninterruptibleMask_ $
        forkIOLabeledWithUnmaskBs label (wrapAction done action) >>= \tid ->
            -- Force the handle to WHNF before handing it back.
            return $! SnapThread tid done


------------------------------------------------------------------------------
-- | Ask the thread to stop by calling 'killThread' on its 'ThreadId'. This
-- does not wait for the completion signal; use 'wait' or 'cancelAndWait'
-- for that.
cancel :: SnapThread -> IO ()
cancel t = killThread (_snapThreadId t)


------------------------------------------------------------------------------
-- | Block until the thread's completion 'MVar' has been filled. The value is
-- read, not taken, so any number of callers may wait on the same handle.
wait :: SnapThread -> IO ()
wait t = void (readMVar (_snapThreadFinished t))


------------------------------------------------------------------------------
-- | 'cancel' the thread, then 'wait' for its completion signal.
cancelAndWait :: SnapThread -> IO ()
cancelAndWait t = do
    cancel t
    wait t


------------------------------------------------------------------------------
-- | Non-blocking check of the completion signal: 'True' once the thread's
-- completion 'MVar' has been filled, 'False' while it is still empty.
isFinished :: SnapThread -> IO Bool
isFinished t = signalled <$> tryReadMVar (_snapThreadFinished t)
  where
    -- Only the presence of a value matters, not the value itself.
    signalled :: Maybe () -> Bool
    signalled Nothing  = False
    signalled (Just _) = True


------------------------------------------------------------------------------
-- Internal functions follow
------------------------------------------------------------------------------
-- | Wrap a user thread action so that the given 'MVar' is filled with @()@
-- when the action is over, however it ends.
--
-- If the action returns normally the 'MVar' is filled afterwards. If it
-- throws, the exception (anything catchable as 'E.SomeException') is caught
-- and discarded, and the 'MVar' is filled from the handler instead. Either
-- way the wrapped action itself returns normally.
wrapAction :: MVar ()
           -> ((forall a . IO a -> IO a) -> IO ())
           -> ((forall a . IO a -> IO a) -> IO ())
wrapAction mv action restore =
    E.catch (action restore >> inline signalDone) onFailure
  where
    -- The exception is forced but otherwise ignored: the only job here is to
    -- publish completion.
    onFailure :: E.SomeException -> IO ()
    onFailure !_ = inline signalDone

    -- Publish completion with asynchronous exceptions fully blocked.
    signalDone :: IO ()
    signalDone = E.uninterruptibleMask_ (putMVar mv $! ())