{- This file is part of time-out.
 -
 - Written in 2016 by fr33domlover <fr33domlover@riseup.net>.
 -
 - ♡ Copying is an act of love. Please copy, reuse and share.
 -
 - The author(s) have dedicated all copyright and related and neighboring
 - rights to this software to the public domain worldwide. This software is
 - distributed without any warranty.
 -
 - You should have received a copy of the CC0 Public Domain Dedication along
 - with this software. If not, see
 - <http://creativecommons.org/publicdomain/zero/1.0/>.
 -}

module Control.TimeOut
    ( timeout
    , delay
    )
where

import Control.Concurrent
import Control.Monad (when)
import Control.Monad.Catch
import Control.Monad.IO.Class
import Data.List (genericReplicate)
import Data.Maybe (isJust)
import Data.Time.Units
import Data.Unique (Unique, newUnique)

-- | The exception used to interrupt an action whose time limit has expired.
--
-- Every call to 'timeout' creates its own 'Unique' tag. This lets nested
-- calls tell their own expiry apart from that of an enclosing 'timeout'.
-- Without the tag, an inner 'timeout' would catch the exception meant for an
-- outer one, and the outer action would keep running past its deadline.
newtype Timeout = Timeout Unique
    deriving Eq

instance Show Timeout where
    show _ = "Timeout"

instance Exception Timeout

-- | Run the action and return 'Just' its result. If the given 'Timeout'
-- exception is thrown during the action, catch it and return 'Nothing'.
-- Other exceptions are rethrown, including 'Timeout's carrying a different
-- tag (i.e. belonging to another 'timeout' call).
catchTimeout :: MonadCatch m => Timeout -> m a -> m (Maybe a)
catchTimeout ex action =
    catchIf (== ex) (Just <$> action) $ \ _ -> return Nothing

-- | Run a monadic action with a time limit. If it finishes before that time
-- passes and returns value @x@, then @Just x@ is returned. If the time limit
-- passes first, the action is aborted (by an asynchronous exception thrown to
-- the calling thread) and @Nothing@ is returned. If the action throws an
-- exception, the timer is stopped and the exception is rethrown.
--
-- Nested calls are independent: an outer limit expiring inside an inner
-- 'timeout' aborts the whole outer action.
--
-- >>> timeout (3 :: Second) $ delay (1 :: Second) >> return "hello"
-- Just "hello"
--
-- >>> timeout (3 :: Second) $ delay (5 :: Second) >> return "hello"
-- Nothing
--
-- >>> timeout (1 :: Second) $ ioError (userError "hello")
-- *** Exception: user error (hello)
--
-- >>> timeout (1 :: Second) $ timeout (5 :: Second) (delay (3 :: Second)) >> return "hello"
-- Nothing
timeout :: (TimeUnit t, MonadIO m, MonadCatch m) => t -> m a -> m (Maybe a)
timeout time action = do
    tidMain <- liftIO myThreadId
    ex <- liftIO $ Timeout <$> newUnique
    -- The timer is stopped *inside* the region where 'ex' is caught. A timer
    -- that fires just as the action finishes, or while the timer is being
    -- killed, therefore makes 'timeout' return 'Nothing'. The exception never
    -- escapes to the caller. Only 'MonadCatch' is available (no masking), so
    -- this replaces a 'bracket'.
    catchTimeout ex $ do
        tidTimer <- liftIO $ forkIO $ delay time >> throwTo tidMain ex
        let stopTimer = liftIO $ killThread tidTimer
        result <- action `onException` stopTimer
        stopTimer
        return result

-- | Lift 'threadDelay' into any 'MonadIO': suspend the current thread for the
-- given number of microseconds, which must fit in an 'Int'.
delayInt :: MonadIO m => Int -> m ()
delayInt = liftIO . threadDelay

-- | Suspend the current thread for the given number of microseconds. The
-- duration may exceed @maxBound :: Int@. It is slept off in chunks of
-- @maxBound :: Int@ microseconds followed by the remainder, because
-- 'delayInt' only accepts an 'Int'. Non-positive durations return at once
-- without sleeping.
delayInteger :: MonadIO m => Integer -> m ()
delayInteger usec
    | usec <= 0 = return ()
    | otherwise = go usec
  where
    chunk = toInteger (maxBound :: Int)
    -- Sleep whole chunks while at least one remains, then the (possibly
    -- zero) remainder.
    go remaining
        | remaining >= chunk = delayInt maxBound >> go (remaining - chunk)
        | otherwise          = delayInt (fromInteger remaining)

-- | Suspend the current thread for the given amount of time, expressed in any
-- 'TimeUnit'. Durations too long for a single 'threadDelay' are supported.
--
-- Example:
--
-- > delay (5 :: Second)
delay :: (TimeUnit t, MonadIO m) => t -> m ()
delay duration = delayInteger (toMicroseconds duration)