{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE PackageImports #-}
{-# LANGUAGE RankNTypes #-}
{- |
Module      :  Control.Monad.Fibre
Copyright   :  (c) Anupam Jain 2011
License     :  GNU GPL Version 3 (see the file LICENSE)

Maintainer  :  ajnsit@gmail.com
Stability   :  experimental
Portability :  non-portable (uses ghc extensions)

Represents computations with Choice and Parallelism.
-}

module Control.Monad.Fibre (
  module Control.Monad.Bi,
  Fibre(..),
  ffmap,
  runFibre,
) where

import Control.Applicative (Applicative(..))
import Control.Concurrent (ThreadId, forkIO, killThread)
import Control.Exception (SomeException, finally, throwIO, try)
import Control.Monad (liftM, liftM2)

import "mtl" Control.Monad.Trans (MonadTrans(..), MonadIO(..))
import Control.Concurrent.STM (STM, TMVar, newEmptyTMVar, takeTMVar, putTMVar, atomically, orElse)

import Control.Monad.Bi

---------------------------------------------
-- FIBRE DATA STRUCTURE (MONAD) DEFINITION --
---------------------------------------------

-- | A computation over a base monad @m@ that produces a result of type @o@.
-- It is built as a tree of the constructors below and interpreted by
-- 'runFibre'.
data Fibre m o where
  -- | A pure result, with no effect in the base monad.
  Ret   :: (Monad m) => o -> Fibre m o
  -- | A single action in the base monad.
  Lift  :: (Monad m) => m o -> Fibre m o
  -- | Parallel composition: both fibres are run and their results are paired.
  (:&&:) :: (Monad m) => Fibre m o1 -> Fibre m o2 -> Fibre m (o1,o2)
  -- | Choice: both fibres are started and the result of one of them is used
  -- (see 'runFibre' for which one).
  (:||:) :: (Monad m) => Fibre m o -> Fibre m o -> Fibre m o
  -- | Sequencing: run the first fibre and feed its result to the continuation.
  (:>>:) :: (Monad m) => Fibre m i -> (i -> Fibre m o) -> Fibre m o

-- | 'pure' wraps a value in a 'Ret' node. '<*>' is defined through ':>>:'.
-- It agrees with the 'Monad' instance, so the function fibre and the argument
-- fibre are sequenced one after the other, not run in parallel.
--
-- This instance is required on GHC >= 7.10, where 'Applicative' is a
-- superclass of 'Monad'. The 'Functor m' constraint keeps it valid on older
-- compilers, where 'Monad m' does not imply 'Functor m'.
instance (Monad m, Functor m) => Applicative (Fibre m) where
    pure = Ret
    mf <*> mx = mf :>>: \f -> mx :>>: (Ret . f)

-- | Binding only builds a ':>>:' node and 'return' builds a 'Ret' node.
-- Nothing is executed until the fibre is interpreted by 'runFibre'.
instance (Monad m) => Monad (Fibre m) where
    return = Ret
    (>>=) = (:>>:)

-- | An action of the base monad becomes a fibre by wrapping it in a 'Lift'
-- node.
instance MonadTrans Fibre where
    lift action = Lift action

-- | Mapping a function over a fibre appends a continuation that wraps the
-- transformed result in 'Ret'. The original fibre is left untouched.
instance (Monad m, Functor m) => Functor (Fibre m) where
    fmap g fibre = fibre :>>: (Ret . g)

-- | Kind of like 'fmap', but instead of mapping over the result it applies a
-- transformation of the base monad to every 'Lift'ed action in the fibre.
-- It works only with generic functions of type @(forall i. m i -> m i)@.
--
-- The transformation is also pushed through the continuation of ':>>:', so
-- actions produced by a continuation are transformed as well, not just those
-- visible before the first bind. The shape of the fibre is preserved.
ffmap :: (Monad m) => (forall i. m i -> m i) -> Fibre m o -> Fibre m o
ffmap _ (Ret o)      = Ret o
ffmap f (Lift m)     = Lift $ f m
ffmap f (t1 :&&: t2) = ffmap f t1 :&&: ffmap f t2
ffmap f (t1 :||: t2) = ffmap f t1 :||: ffmap f t2
ffmap f (t :>>: ft)  = ffmap f t :>>: (ffmap f . ft)

-- | An IO action is first lifted into the base monad and then wrapped in a
-- 'Lift' node.
instance (MonadIO m) => MonadIO (Fibre m) where
  liftIO io = Lift (liftIO io)


---------------------
-- RUNNING A FIBRE --
---------------------

-- | Interpret a 'Fibre' in its base monad.
--
-- Running a fibre requires the base monad to be a @(MonadBi m IO)@. 'lower'
-- turns a sub-fibre's computation into an @IO@ action that can be forked onto
-- its own thread, and 'raise' embeds @IO@ back into @m@.
--
-- * @'Ret' o@ returns @o@, and @'Lift' m@ runs @m@.
-- * @t1 ':&&:' t2@ runs both branches on separate threads and returns both
--   results. If a branch throws, that exception is re-thrown to the caller.
-- * @t1 ':||:' t2@ runs both branches on separate threads and returns the
--   first successful result. The losing branch's thread is then killed. If
--   both branches throw, the exception of the branch that failed last is
--   re-thrown.
-- * @t ':>>:' f@ runs @t@, then the fibre that @f@ builds from its result.
runFibre :: MonadBi m IO => Fibre m o -> m o
runFibre (Ret o) = return o
runFibre (Lift m) = m
runFibre (t1 :&&: t2) = do
    t1' <- lower (runFibre t1)
    t2' <- lower (runFibre t2)
    raise $ fibreBoth t1' t2'
runFibre (t1 :||: t2) = do
    t1' <- lower (runFibre t1)
    t2' <- lower (runFibre t2)
    raise $ fibreRace t1' t2'
runFibre (t :>>: f) = do
    i <- runFibre t
    runFibre $ f i

-- | Fork @act@ onto a new thread. Its outcome (result or exception) is put
-- into @var@. The caller passes a fresh empty TMVar that only this thread
-- writes, so the put never blocks, even when the thread is killed.
fibreFork :: TMVar (Either SomeException a) -> IO a -> IO ThreadId
fibreFork var act = forkIO $ try act >>= atomically . putTMVar var

-- | Run two actions concurrently and wait for both results. A failure of
-- either action is re-thrown. Both threads are killed on the way out, which
-- is a no-op for threads that have already finished.
fibreBoth :: IO a -> IO b -> IO (a, b)
fibreBoth act1 act2 = do
    var1 <- atomically newEmptyTMVar
    var2 <- atomically newEmptyTMVar
    tid1 <- fibreFork var1 act1
    tid2 <- fibreFork var2 act2
    let collect = do
            x1 <- atomically (takeTMVar var1) >>= either throwIO return
            x2 <- atomically (takeTMVar var2) >>= either throwIO return
            return (x1, x2)
    collect `finally` (killThread tid1 >> killThread tid2)

-- | Run two actions concurrently and return the first successful result. If
-- the first action to finish failed, the other action's outcome is used
-- instead, so an exception escapes only when both actions fail. Both threads
-- are killed on the way out, which stops the losing branch.
fibreRace :: IO a -> IO a -> IO a
fibreRace act1 act2 = do
    var1 <- atomically newEmptyTMVar
    var2 <- atomically newEmptyTMVar
    tid1 <- fibreFork var1 act1
    tid2 <- fibreFork var2 act2
    -- Each TMVar is written exactly once. After one has been taken, a second
    -- 'takeAny' therefore yields the other branch's outcome.
    let takeAny = atomically (takeTMVar var1 `orElse` takeTMVar var2)
        settle = do
            outcome <- takeAny
            case outcome of
                Right x -> return x
                Left _  -> takeAny >>= either throwIO return
    settle `finally` (killThread tid1 >> killThread tid2)


-------------------------------
-- UTILITY MONADIC FUNCTIONS --
-------------------------------

-- Nothing here yet