{- This file is part of time-out.
 -
 - Written in 2016 by fr33domlover <fr33domlover@riseup.net>.
 -
 - ♡ Copying is an act of love. Please copy, reuse and share.
 -
 - The author(s) have dedicated all copyright and related and neighboring
 - rights to this software to the public domain worldwide. This software is
 - distributed without any warranty.
 -
 - You should have received a copy of the CC0 Public Domain Dedication along
 - with this software. If not, see
 - <http://creativecommons.org/publicdomain/zero/1.0/>.
 -}

{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses      #-}

-- | Monad transformer for running actions with a time limit. Provides a
-- scalable @MonadTimeout@ instance (at least for the case of a constant number
-- of long-running threads). If you need to use timeouts often in a
-- computation, this is probably better than "Control.Timeout".
module Control.Monad.Trans.Timeout
    ( TimeoutT ()
    , runTimeoutT
    , withTimeoutThrow
    , withTimeoutThrow'
    , withTimeoutCatch
    , withTimeoutCatch'
    )
where

import Control.Monad.Catch
import Control.Monad.Fix (MonadFix)
import Control.Monad.IO.Class
import Control.Monad.Timeout.Class
import Control.Monad.Trans.Alarm
import Control.Monad.Trans.Class
import Data.Time.Units (TimeUnit)

-- | A monad transformer that equips a monad stack with the ability to run
-- actions under a time limit and abort them when they do not finish in time.
--
-- Nothing is time-limited implicitly: an action brought in with 'lift' or
-- 'liftIO' runs the ordinary way, without any timeout. To impose a limit, wrap
-- the action with one of the timeout functions, for example
-- 'withTimeoutThrow'.
--
-- The type is a newtype over 'AlarmT', and all of its instances are derived
-- rather than written by hand.
newtype TimeoutT m a = TimeoutT { unTT :: AlarmT m a }
    deriving
        ( Functor, Applicative, Monad        -- the basic hierarchy
        , MonadFix                           -- extra monad class from base
        , MonadIO                            -- thread operations are IO
        , MonadTrans                         -- it is a transformer, after all
        , MonadThrow, MonadCatch, MonadMask  -- exception handling
        )

-- | Timeouts for 'TimeoutT' over a base monad @m@ that supports 'MonadIO' and
-- 'MonadCatch'. Both methods simply forward to this module's functions that
-- take an explicit time argument.
instance (MonadIO m, MonadCatch m) => MonadTimeout (TimeoutT m) m where
    -- Delegates to the variant that throws 'Timeout' when no result comes back.
    timeoutThrow limit action = withTimeoutThrow' limit action
    -- Delegates to the variant that reports the outcome as a 'Maybe'.
    timeoutCatch limit action = withTimeoutCatch' limit action

-- | Run a 'TimeoutT' computation in the underlying monad.
--
-- The computation is unwrapped to the 'AlarmT' action it contains and handed,
-- together with the given time value, to 'runAlarmT'.
--
-- NOTE(review): the time value is presumably the default limit applied by
-- 'withTimeoutThrow' and 'withTimeoutCatch', which take no limit of their
-- own - confirm against the documentation of 'runAlarmT'.
runTimeoutT :: (TimeUnit t, MonadIO m, MonadMask m) => TimeoutT m a -> t -> m a
runTimeoutT (TimeoutT inner) time = runAlarmT inner time

-- | Run an action of the underlying monad through 'withTimeoutCatch' and turn
-- a missing result into an exception: when 'withTimeoutCatch' yields
-- 'Nothing', 'Timeout' is thrown with 'throwM'; otherwise the action's result
-- is returned.
--
-- No time value is passed here; see 'withTimeoutThrow'' for the variant that
-- takes one explicitly.
withTimeoutThrow :: (MonadIO m, MonadCatch m) => m a -> TimeoutT m a
withTimeoutThrow action =
    withTimeoutCatch action >>= maybe (throwM Timeout) return

-- | Like 'withTimeoutThrow', but with an explicitly supplied time value. The
-- action is run through 'withTimeoutCatch''; a 'Nothing' outcome is converted
-- into a thrown 'Timeout', and a 'Just' outcome is unwrapped and returned.
withTimeoutThrow'
    :: (TimeUnit t, MonadIO m, MonadCatch m)
    => t    -- ^ Time value forwarded to 'withTimeoutCatch''
    -> m a  -- ^ Action of the underlying monad to run
    -> TimeoutT m a
withTimeoutThrow' limit action =
    withTimeoutCatch' limit action >>= maybe (throwM Timeout) return

-- | Run an action of the underlying monad via 'alarm' and wrap the outcome in
-- 'TimeoutT'. The result is a 'Maybe'; 'withTimeoutThrow' treats 'Nothing' as
-- a timeout.
--
-- No time value is passed here; see 'withTimeoutCatch'' for the variant that
-- takes one explicitly.
withTimeoutCatch :: (MonadIO m, MonadCatch m) => m a -> TimeoutT m (Maybe a)
withTimeoutCatch = TimeoutT . alarm

-- | Like 'withTimeoutCatch', but with an explicitly supplied time value: the
-- action is run via 'alarm'' and the outcome is wrapped in 'TimeoutT'. The
-- result is a 'Maybe'; 'withTimeoutThrow'' treats 'Nothing' as a timeout.
withTimeoutCatch'
    :: (TimeUnit t, MonadCatch m, MonadIO m)
    => t    -- ^ Time value forwarded to 'alarm''
    -> m a  -- ^ Action of the underlying monad to run
    -> TimeoutT m (Maybe a)
withTimeoutCatch' limit = TimeoutT . alarm' limit