{-# LANGUAGE Trustworthy, Rank2Types, MagicHash, UnboxedTuples, BangPatterns #-}

module Control.Monad.IOT (IOT, run, module Control.Monad.Trans, module Control.Monad.Identity, module Control.Monad.Morph) where

import GHC.IO (IO(IO))
import GHC.Prim
import Control.Monad.Trans (MonadIO(..))
import Control.Monad.Identity
import Control.Monad.Morph
import Control.Monad
import Control.Applicative
import Control.Concurrent.MVar
import Data.Typeable
import Unsafe.Coerce

-- | Internal state threaded through 'IOT': the primitive 'RealWorld' token
-- paired with a guard 'MVar'. Code in this module (e.g. 'liftIO') takes the
-- guard before using the token and raises an error if it is already empty.
data State
	= State
		(State# RealWorld)
		!(MVar ())

-- | An IO monad transformer.
--
-- 'IOT' cannot be unwrapped in the usual way -- instead, the monad inside it
-- has to be reduced to 'Identity'. This is done with 'hoist' from mmorph,
-- after which 'run' turns the result into an 'IO' action.
--
-- Most of the safety of the IO monad is ensured statically.
-- However, to ensure that the same RealWorld token is not
-- used multiple times, a runtime check is necessary. Among
-- the alternatives that perform I/O, the first alternative
-- forced by a chain of 'hoist's will contain a result,
-- and subsequent alternatives will be errors.
--
-- Therefore, a chain of 'hoist's out of a monad defines
-- at most one path of RealWorld token use. Here is an example using
-- a binary tree monad (@Tree@ is not defined in this module; it is assumed
-- to be @data Tree a = Leaf a | Node (Tree a) (Tree a)@ with '>>=' substituting
-- at the leaves):
--
-- >>> let io = lift (Node (Leaf 1) (Leaf 2)) >>= liftIO . print :: IOT Tree ()
--
-- >>> run $ hoist (\(Node (Leaf x) _) -> Identity x) io
-- 1
--
-- >>> run $ hoist (\(Node _ (Leaf x)) -> Identity x) io
-- 2
--
-- >>> run $ hoist (\(Node (Leaf _) (Leaf x)) -> Identity x) io
-- 1
-- *** Exception: IOT: double RealWorld use
--
newtype IOT m a = IOT (State -> m (State, a))

-- | Sequencing runs each step on the 'State' produced by the previous one;
-- 'return' leaves the 'State' unchanged.
instance (Monad m) => Monad (IOT m) where
	{-# INLINE return #-}
	return a = IOT (\st -> return (st, a))
	{-# INLINE (>>=) #-}
	IOT step >>= k = IOT $ \st0 -> do
		(st1, a) <- step st0
		let IOT next = k a
		next st1

-- | 'pure' is defined directly instead of as @pure = return@. GHC reports
-- that form as non-canonical (-Wnoncanonical-monad-instances), and under
-- the "Monad of no return" plan, where 'return' becomes an alias for 'pure',
-- it would loop forever. '<*>' is derived from the 'Monad' instance.
instance (Monad m) => Applicative (IOT m) where
	{-# INLINE pure #-}
	pure x = IOT $ \s -> return (s, x)
	{-# INLINE (<*>) #-}
	(<*>) = ap

-- | Mapping goes through the 'Monad' instance: the 'State' is passed
-- along untouched and only the result is transformed.
instance (Monad m) => Functor (IOT m) where
	{-# INLINE fmap #-}
	fmap = liftM

-- | Failure raised when a guard 'MVar' is found already empty, meaning a
-- 'State' (and its RealWorld token) is being used a second time. It is
-- used as the 'Nothing' case after 'tryTakeMVar' on a guard.
err :: a
err = error "IOT: double RealWorld use"

-- | Runs an 'IO' action on the token carried by the incoming 'State'.
--
-- The incoming guard 'MVar' is taken first; if it is already empty the
-- token has been used before and 'err' is raised. After the action, a
-- fresh full guard is created and paired with the resulting token, so the
-- next use of the returned 'State' is checked the same way.
instance (Monad m) => MonadIO (IOT m) where
	{-# INLINE liftIO #-}
	liftIO m = IOT $ \(State s mv) -> let
			-- Consume the old guard, run the action, then create a new guard.
			IO f = do
				tryTakeMVar mv >>= maybe err return
				liftM2 (,) m (newMVar ());
			-- This binding contains an unlifted 'State#', so GHC treats it as
			-- strict: 'f s' is evaluated before the 'return' below is produced.
			(# s', (x, mv') #) = f s in
		return (State s' mv', x)

-- | Lifting runs a base-monad action and passes the 'State' through
-- unchanged; no RealWorld token is consumed.
instance MonadTrans IOT where
	{-# INLINE lift #-}
	lift act = IOT $ \st -> act >>= \a -> return (st, a)

-- Applies a natural transformation to the underlying monad of an 'IOT',
-- leaving the threaded 'State' untouched.
{-# INLINE _hoist #-}
_hoist :: (forall a. m a -> n a) -> IOT m b -> IOT n b
_hoist nat (IOT step) = IOT $ \st -> nat (step st)

-- Squashes together two layers of IOTs.
--
-- Steps, in order:
--   1. A fresh guard 'MVar' for the inner layer is created (itself an
--      'liftIO' step on the current state).
--   2. The outer layer's function is applied to the current state, and the
--      resulting inner computation is run on the same token paired with the
--      fresh guard.
--   3. The inner layer's final state becomes the new state of the result.
--      The outer layer's final state is only checked: its guard is taken and
--      'err' is raised if it is already empty. Its token is discarded.
--
-- NOTE(review): both layers start from the same RealWorld token; only the
-- separate guards distinguish them. Confirm this ordering is intended.
{-# INLINE _squash #-}
_squash :: (Monad m) => IOT (IOT m) t -> IOT m t
_squash (IOT f) = do
	-- Guard for the inner layer's starting state.
	mv <- liftIO $ newMVar ()
	-- Run outer then inner layer; m is the outer layer's final guard.
	(State _ m, x) <- IOT (\st@(State s _) -> let IOT g = f st in g (State s mv))
	-- Fail if the outer layer's final state was already consumed.
	liftIO (tryTakeMVar m) >>= maybe err return
	return x

-- | 'hoist' maps the underlying monad of an 'IOT' with a natural
-- transformation, passing the 'State' through unchanged.
instance MFunctor IOT where
	{-# INLINE hoist #-}
	hoist nat (IOT step) = IOT (\st -> nat (step st))

-- | 'embed' turns each base action into an 'IOT' computation and then
-- flattens the two resulting 'IOT' layers into one.
instance MMonad IOT where
	{-# INLINE embed #-}
	embed nat act = _squash (_hoist nat act)

-- | Run an @'IOT' 'Identity'@ computation as an ordinary 'IO' action.
--
-- 'Identity' adds no structure of its own, so @'IOT' 'Identity'@ is a thin
-- wrapper around 'IO'. A fresh guard 'MVar' is created for the initial
-- state. After the computation finishes, the guard of the final state is
-- taken; if it is already empty, the \"double RealWorld use\" error is raised.
{-# INLINE run #-}
run :: IOT Identity t -> IO t
run (IOT body) = do
	guard0 <- newMVar ()
	(guardN, result) <- IO (\tok -> case runIdentity (body (State tok guard0)) of
		(State tok' fin, r) -> (# tok', (fin, r) #))
	ok <- tryTakeMVar guardN
	maybe err return ok
	return result

-- Rewrite rules for MVar bookkeeping that is unobservable:
--   * a freshly created MVar that is immediately discarded is replaced by
--     @return ()@;
--   * taking from a freshly created full MVar is replaced by the stored value.
-- Presumably these are meant to strip the guard MVars after 'run' / 'liftIO'
-- are inlined.
-- NOTE(review): no phase control is given, so these may not fire if
-- 'newMVar' is inlined first. Confirm they fire (e.g. -ddump-rule-firings).
{-# RULES
"void/newMVar" forall x. void (newMVar x) = return ()
"newMVar/tryTakeMVar" forall x. newMVar x >>= tryTakeMVar = return (Just x)
  #-}