{-# LANGUAGE DeriveDataTypeable #-}
-- |Intended for internal use: Simple timeout mechanism
module System.SimpleTimeout
    ( TimeoutHandle
    , timeoutHandle
    , timeout
    ) where

import Control.Exception (Exception, fromException, handle, mask, throwIO, try)
import Control.Concurrent (forkIO, threadDelay, throwTo, ThreadId, myThreadId)
import Control.Concurrent.MVar (MVar, newMVar, newEmptyMVar, takeMVar, putMVar, swapMVar, modifyMVar)

import Data.Time.Clock (UTCTime, getCurrentTime, diffUTCTime)
import Data.Typeable (Typeable)

------------------------

-- |Timeout exception.
--
-- Thrown by the timer thread of a 'TimeoutHandle' to every operation still
-- running under 'timeout' when the time limit is over.
-- The @Double@ parameter is documented at 'timeout'.
data TimeOutException
    = TimeOutException Double
        deriving (Eq, Typeable)

-- | Shows the parameter as a percentage rounded to the nearest integer.
instance Show TimeOutException where
    show (TimeOutException d) = "<<timeout at " ++ show (round (100 * d) :: Integer) ++ "%>>"

instance Exception TimeOutException

---------------

-- |Abstract handle created by 'timeoutHandle' and consumed by 'timeout'.
newtype TimeoutHandle
    = TimeutHandle (MVar (Maybe [(ThreadId, UTCTime)]))
            -- ^ Shared state of the handle:
            --
            --  * @Nothing@: the timeout has already happened.
            --
            --  * @Just xs@: there is time left. @xs@ lists the threads that
            --    will receive a 'TimeOutException' when the time is over,
            --    each paired with the time it registered. That 'UTCTime' is
            --    used to compute the @Double@ parameter of the exception.

-- |Creates a 'TimeoutHandle'.
--
-- The @Double@ parameter is the time limit in seconds.
-- All operations behind 'timeout' will be stopped
-- at the current time plus the time limit.
timeoutHandle :: Double -> IO TimeoutHandle
timeoutHandle limit = do
    -- Take the reference time here, before the timer thread is forked, so
    -- that no 'timeout' call can register earlier than it (which would
    -- otherwise produce a negative fraction).
    start <- getCurrentTime
    th <- newMVar $ Just []
    _ <- forkIO $ killLater start th
    return $ TimeutHandle th
  where
    -- Waits for the time limit, marks the handle as expired (Nothing) and
    -- throws a 'TimeOutException' to every registered thread, carrying the
    -- fraction of the limit that had elapsed when that thread registered.
    killLater :: UTCTime -> MVar (Maybe [(ThreadId, UTCTime)]) -> IO ()
    killLater start th = do
        threadDelay $ round $ 1000000 * limit
        registered <- swapMVar th Nothing
        end <- getCurrentTime
        let whole = realToFrac (end `diffUTCTime` start) :: Double

            -- Clamped to [0, 1]. A zero-length interval (e.g. limit <= 0 with
            -- a coarse clock) counts as fully elapsed instead of dividing by 0.
            fraction :: UTCTime -> Double
            fraction time
                | whole <= 0 = 1
                | otherwise  = max 0 . min 1 $ realToFrac (time `diffUTCTime` start) / whole

            kill :: (ThreadId, UTCTime) -> IO ()
            kill (pid, time) = throwTo pid $ TimeOutException $ fraction time

        -- Only this thread ever stores Nothing, so registered is expected to
        -- be Just; 'maybe' avoids a partial pattern match all the same.
        maybe (return ()) (mapM_ kill) registered


-- | Stop an operation at a time given by 'timeoutHandle'.
--
-- The operation runs in a separate thread. If the handle's time limit is
-- reached while it is running, the operation is interrupted and the timeout
-- handler runs instead. If the time limit has already passed when 'timeout'
-- is called, the handler runs immediately with @1@. Any other exception
-- raised by the operation or by the handler is re-thrown in the caller.
--
-- The @Double@ parameter given to the handler is a fraction between 0 and 1:
--
--  * 0: 'timeout' was called right after the 'TimeoutHandle' was created.
--
--  * 1: 'timeout' was called after the time of the timeout.
--
--  * near to 1: 'timeout' was called right before the time of the timeout.
--
--  * Other values: proportional to the part of the time limit that had
--    already elapsed when 'timeout' was called.
timeout
    :: TimeoutHandle    -- ^ knows the time of the timeout and the creation time of itself
    -> (Double -> IO a) -- ^ timeout handler; receives the fraction described above
    -> IO a             -- ^ the operation to time out
    -> IO a
timeout (TimeutHandle th) handleTimeout operation = do
    result <- newEmptyMVar

    -- Asynchronous exceptions stay masked in the worker except while the
    -- operation or the handler runs. A 'TimeOutException' therefore cannot
    -- hit the bookkeeping code, and one arriving after the operation has
    -- finished cannot trigger the handler or a second putMVar.
    _ <- forkIO $ mask $ \restore -> do
        registered <- modifyMVar th $ \state -> case state of
            Nothing -> return (Nothing, False)
            Just xs -> do
                pid <- myThreadId
                time <- getCurrentTime
                return (Just $ (pid, time) : xs, True)
        outcome <- try . restore $
            if registered then operation else handleTimeout 1
        final <- case outcome of
            Left e | Just (TimeOutException d) <- fromException e
                -> try . restore $ handleTimeout d
            _   -> return outcome
        -- The worker is the only writer of result, so this never blocks.
        putMVar result final

    -- Re-raise failures in the caller instead of blocking forever.
    takeMVar result >>= either throwIO return