-- |
-- Module:     Control.Wire.Trans.Fork
-- Copyright:  (c) 2011 Ertugrul Soeylemez
-- License:    BSD3
-- Maintainer: Ertugrul Soeylemez <es@ertes.de>
--
-- Wire concurrency.
--
-- /Warning/: This module is highly experimental and currently causes
-- space leaks.  Please use wire concurrency only for short-lived
-- threads.

module Control.Wire.Trans.Fork
    ( -- * Embedding concurrent wires
      WFork(..),

      -- * Wire thread manager
      WireMgr,
      startWireMgr,
      stopWireMgr,
      withWireMgr,

      -- * Wire threads
      -- ** Channels
      WireChan,
      feedWireChan,
      readWireChan,
      -- ** Threads
      WireThread,
      killWireThread
    )
    where

import qualified Data.Map as M
import Control.Applicative
import Control.Arrow
import Control.Concurrent.Lifted
import Control.Concurrent.STM
import Control.Exception.Lifted
import Control.Monad
import Control.Monad.Fix
import Control.Monad.Trans.Control
import Control.Monad.Trans
import Control.Wire.Types
import Data.Map (Map)
import Data.Monoid


{-# WARNING WFork "Wire concurrency is not stable at the moment!" #-}


-- | Arrows that support wire concurrency.  'forkWire' starts a wire in
-- its own thread and hands back a 'WireChan' for talking to it;
-- 'feedWire' and 'queryWire' then send input to and fetch output from
-- that thread.

class Arrow (>~) => WFork (>~) where
    -- | Send one input value to a wire thread through its channel.
    --
    -- * Depends: Current instant.
    feedWire ::
        Wire e (>~) (WireChan a b, a) ()

    -- | Start the input wire in a new thread tracked by the input wire
    -- manager.  Produces the channel for communicating with the new
    -- thread together with a handle to the thread itself.
    --
    -- Note: This wire forks at every instant.  In many cases you will
    -- want to use the 'swallow' wire transformer with this.
    --
    -- * Depends: Current instant.
    forkWire ::
        Wire e (>~)
            (Wire e (>~) a b, WireMgr)
            (WireChan a b, WireThread)

    -- | Fetch the next output the wire thread has produced.
    --
    -- * Depends: Current instant.
    --
    -- * Inhibits: When no output is available.
    queryWire ::
        Monoid e => Wire e (>~) (WireChan a b) b

instance (MonadBaseControl IO m, MonadIO m) => WFork (Kleisli m) where
    -- feedWire: push the input value onto the thread's input channel.
    -- Always produces (), never inhibits.
    feedWire =
        mkFixM $ \(wc, x') -> do
            let ichan = wcInputChan wc
            liftIO . atomically $ writeTChan ichan x'
            return (Right ())

    -- forkWire: allocate the channels and control flags, start the wire
    -- in a new thread and register that thread with the manager.  The
    -- fork and the registration happen inside 'mgrOp', so they are
    -- serialised with the other operations on the same manager.
    forkWire =
        mkFixM $ \(thrW, mgr) -> do
            ichan <- liftIO newTChanIO
            ochan <- liftIO newTChanIO
            doneVar <- liftIO (newTVarIO False)
            quitVar <- liftIO (newTVarIO False)

            let wc = WireChan { wcInputChan = ichan,
                                wcOutputChan = ochan }

            mgrOp mgr $ do
                tid <- fork (thread ichan ochan quitVar doneVar thrW)

                let wt = WireThread { wtDoneVar  = doneVar,
                                      wtThreadId = tid,
                                      wtQuitVar  = quitVar }

                let thrsVar = wmThrsVar mgr
                liftIO . atomically $ do
                    thrs <- readTVar thrsVar
                    writeTVar thrsVar (M.insert tid wt thrs)

                return (Right (wc, wt))

        where
        -- Body of the forked thread.  Every value read from ichan steps
        -- the wire once.  Produced outputs are written to ochan, and
        -- inhibited instants produce nothing.  Pending input is preferred
        -- over the quit request (the input read is tried first).
        --
        -- doneVar is set from a 'finally' handler, so it becomes True even
        -- when stepping the wire throws.  Otherwise 'killWireThread' and
        -- 'stopWireMgr', which wait for doneVar, would block forever on a
        -- thread that has already died.
        thread ichan ochan quitVar doneVar w0 =
            fix step w0 `finally` liftIO (atomically $ writeTVar doneVar True)
            where
            step loop w' = do
                mx' <- liftIO . atomically $
                          Just <$> readTChan ichan <|>
                          Nothing <$ (readTVar quitVar >>= check)
                case mx' of
                  Just x' -> do
                      (mx, w) <- toGenM w' x'
                      either (const $ return ()) (liftIO . atomically . writeTChan ochan) mx
                      loop w
                  Nothing -> return ()

    -- queryWire: take the next buffered output if there is one.
    -- Otherwise inhibit with 'mempty' immediately instead of blocking.
    queryWire =
        mkFixM $ \wc -> do
            let ochan = wcOutputChan wc
            liftIO . atomically $
                Right <$> readTChan ochan <|>
                return (Left mempty)


-- | Communication endpoint of a concurrently running wire.  Values of
-- type @a@ are sent to the wire, and values of type @b@ come back from it.

data WireChan a b =
    WireChan
    { -- | Inputs queued for the wire thread.
      wcInputChan  :: !(TChan a)
      -- | Outputs produced by the wire thread.
    , wcOutputChan :: !(TChan b)
    }


-- | A wire thread manager keeps track of the wire threads created
-- through it and serialises the operations performed on them.

data WireMgr =
    WireMgr
    { -- | Manager lock: 'True' while no manager operation is running.
      wmFreeVar :: !(TVar Bool)
      -- | Threads currently registered with this manager, keyed by
      -- their thread id.
    , wmThrsVar :: !(TVar (Map ThreadId WireThread))
    }


-- | Handle to a concurrently running wire.

data WireThread =
    WireThread
    { -- | Becomes 'True' once the wire has quit.
      wtDoneVar  :: !(TVar Bool)
      -- | Id of the thread running the wire.
    , wtThreadId :: !ThreadId
      -- | Set to 'True' to ask the wire to terminate.
    , wtQuitVar  :: !(TVar Bool)
    }


-- | Send one input value to the wire thread behind the given channel by
-- queueing it on the channel's input side.

feedWireChan :: WireChan a b -> a -> IO ()
feedWireChan wc x = atomically $ writeTChan (wcInputChan wc) x


-- | Kill the given wire thread.
--
-- Sets the thread's quit flag, then blocks until its done flag is 'True'
-- and finally removes it from the manager's thread table.  Everything
-- runs under the manager lock ('mgrOp'), so other manager operations on
-- the same manager wait until this has finished.  If the thread is
-- already done, the wait for it completes immediately.

killWireThread :: WireMgr -> WireThread -> IO ()
killWireThread mgr thr = do
    let WireThread { wtDoneVar  = doneVar,
                     wtThreadId = tid,
                     wtQuitVar  = quitVar } = thr
        thrsVar = wmThrsVar mgr
    mgrOp mgr $ do
        -- Request termination.
        atomically (writeTVar quitVar True)
        -- Wait for the done flag and unregister the thread in a single
        -- transaction, so the thread table is read and written atomically
        -- rather than through a snapshot taken in an earlier transaction.
        atomically $ do
            readTVar doneVar >>= check
            thrs <- readTVar thrsVar
            writeTVar thrsVar (M.delete tid thrs)


-- | Perform a manager operation safely.
--
-- The manager's 'wmFreeVar' acts as a lock.  This waits until it is
-- 'True', sets it to 'False', runs the given action and sets it back to
-- 'True' afterwards, even if the action throws.  Operations on the same
-- manager are therefore serialised.  The lock is not reentrant: calling
-- 'mgrOp' on the same manager from inside the action blocks forever.

mgrOp :: (MonadBaseControl IO m, MonadIO m) => WireMgr -> m a -> m a
mgrOp mgr = bracket_ acquire release
    where
    freeVar = wmFreeVar mgr

    -- 'bracket_' masks asynchronous exceptions between taking the lock
    -- and installing the release handler, so an exception arriving in
    -- that window cannot leave the lock held.  The blocking 'check' stays
    -- interruptible, but an interrupted transaction does not commit and
    -- so never takes the lock.
    acquire =
        liftIO . atomically $ do
            readTVar freeVar >>= check
            writeTVar freeVar False

    release = liftIO (atomically $ writeTVar freeVar True)


-- | Wait for the next output of the wire thread behind the given channel
-- and return it.  Blocks while no output is available.

readWireChan :: WireChan a b -> IO b
readWireChan wc = atomically $ readTChan (wcOutputChan wc)


-- | Create a new wire manager.  Its lock starts out free and it tracks no
-- threads yet.

startWireMgr :: IO WireMgr
startWireMgr = WireMgr <$> newTVarIO True <*> newTVarIO M.empty


-- | Stop a wire manager, terminating every thread it keeps track of.
--
-- Runs under the manager lock.  First every tracked thread is asked to
-- quit.  Then, one thread at a time, this waits for the thread's done
-- flag and kills it.  Finally the thread table is cleared.

stopWireMgr :: WireMgr -> IO ()
stopWireMgr mgr = mgrOp mgr stopAll
    where
    registry = wmThrsVar mgr

    stopAll = do
        entries <- M.toList <$> readTVarIO registry
        mapM_ (requestQuit . snd) entries
        mapM_ awaitAndKill entries
        atomically (writeTVar registry M.empty)

    -- Raise the quit flag of a single thread.
    requestQuit wt = atomically (writeTVar (wtQuitVar wt) True)

    -- Block until the thread reports done, then kill it.
    awaitAndKill (tid, wt) = do
        atomically (readTVar (wtDoneVar wt) >>= check)
        killThread tid


-- | Convenient wrapper around 'startWireMgr' and 'stopWireMgr'.  Runs
-- the given action with a fresh manager and stops that manager, along
-- with all threads it keeps track of, afterwards, even if the action
-- throws.

withWireMgr :: (MonadBaseControl IO m, MonadIO m) => (WireMgr -> m a) -> m a
withWireMgr = bracket (liftIO startWireMgr) (liftIO . stopWireMgr)