{-# LANGUAGE UnboxedTuples, MagicHash, RankNTypes, FlexibleInstances, MultiParamTypeClasses, UndecidableInstances #-}

-- | A monad transformer that encapsulates the implementation details of array manipulation behind an array-transformer interface.  This is likely to be the most efficient array transformer implementation in this library, but 'unsafeReadAt' and 'unsafeWriteAt' perform no bounds checking, so improper use can lead to segfaults.
module Control.Monad.Array.ArrayT (ArrayT, ArrayM, runArrayT, runArrayT_, runArrayM, runArrayM_) where

import GHC.Exts
import GHC.ST(ST(..))

import Prelude hiding (getContents)
import Control.Monad.ST.Trans
import Control.Monad.ST
import Control.Monad.Fix
import Control.Monad.Array.Class
import Control.Monad
import Control.Monad.Trans
import Control.Monad.RWS.Class
import Control.Monad.State

-- | Internal handle on the mutable array.  The fields are, in order: the
-- current number of slots (strict and unpacked), the default element that
-- freshly allocated slots are filled with (kept lazy, so it may be an
-- 'error' thunk), and the underlying 'MutableArray#'.
data MArr s e = MArr {-# UNPACK #-} !Int e (MutableArray# s e)

-- | Monad transformer that safely grants the underlying monad access to a mutable array.
-- The array handle is threaded through an internal 'StateT' over 'STT'.  The
-- rank-2 @forall s@ ties the handle to the state thread started by
-- 'runArrayT', so the array cannot escape the computation.
newtype ArrayT e m a = ArrayT {runArrT :: forall s . StateT (MArr s e) (STT s m) a}
-- | Non-transformer variant of 'ArrayT' whose array lives directly in 'ST'.
-- Run it with 'runArrayM'.
newtype ArrayM e a = ArrayM {runArrM :: forall s . StateT (MArr s e) (ST s) a}

-- | Runs an 'ArrayT' computation against a freshly allocated array whose
-- slots all start out holding the given default element, and returns the
-- computation's result in the underlying monad.
runArrayT :: Monad m
	=> Int		-- ^ Initial array size.
	-> e		-- ^ Default element stored in every slot of the new array.
	-> ArrayT e m a	-- ^ Computation to run against the array.
	-> m a		-- ^ Result of the computation, in the underlying monad.
runArrayT size dflt action = runSTT $ do
	arr <- liftST (newMArr size dflt)
	evalStateT (runArrT action) arr

-- | Like 'runArrayT', but every slot starts out as 'emptyElement', so forcing
-- a slot that was never written raises an error.
runArrayT_ :: Monad m => Int -> ArrayT e m a -> m a
runArrayT_ size action = runArrayT size emptyElement action

-- | Pure counterpart of 'runArrayT': runs an 'ArrayM' computation against a
-- freshly allocated array of the given size, every slot of which starts out
-- holding the given default element, and returns the computation's result.
runArrayM :: Int -> e -> ArrayM e a -> a
runArrayM size dflt action = runST $ do
	arr <- newMArr size dflt
	evalStateT (runArrM action) arr

-- | Like 'runArrayM', but every slot starts out as 'emptyElement', so forcing
-- a slot that was never written raises an error.
runArrayM_ :: Int -> ArrayM e a -> a
runArrayM_ size action = runArrayM size emptyElement action

-- | Placeholder default element used by 'runArrayT_' and 'runArrayM_'.
-- Forcing it (for example, by evaluating a slot that was never written)
-- raises an 'error' call.
emptyElement :: e
emptyElement = error "Undefined array element"

-- Each method unwraps to the internal StateT-over-STT action and rewraps it.
instance Monad m => Monad (ArrayT e m) where
	return a = ArrayT (return a)
	ArrayT act >>= next = ArrayT (act >>= \ a -> runArrT (next a))
	fail msg = ArrayT (lift (fail msg))

-- Each method unwraps to the internal StateT-over-ST action and rewraps it.
instance Monad (ArrayM e) where
	return a = ArrayM (lift (return a))
	ArrayM act >>= next = ArrayM (act >>= \ a -> runArrM (next a))
	ArrayM lhs >> ArrayM rhs = ArrayM (lhs >> rhs)
	fail msg = ArrayM (fail msg)

-- Value recursion delegated to the internal StateT-over-ST monad.
instance MonadFix (ArrayM e) where
	mfix f = ArrayM (mfix (\ x -> runArrM (f x)))

-- Lifts through both internal layers: first into 'STT', then into the
-- 'StateT' that holds the array.
instance MonadTrans (ArrayT e) where
	lift action = ArrayT (lift (lift action))

-- Array operations for 'ArrayT', backed by the 'MArr' held in the internal
-- 'StateT' over 'STT'.
instance Monad m => MonadArray e (ArrayT e m) where
	{-# INLINE unsafeReadAt #-}
	{-# INLINE unsafeWriteAt #-}
	{-# INLINE getSize #-}
	{-# INLINE resize #-}
	-- Unchecked: the index is handed to 'readMArr' without validation.
	unsafeReadAt i = ArrayT $ do
		arr <- get
		liftST $ readMArr arr i
	-- Unchecked: the index is handed to 'writeMArr' without validation.
	unsafeWriteAt i x = ArrayT $ do
		arr <- get
		liftST $ writeMArr arr i x
	getSize = ArrayT $ do
		MArr n _ _ <- get
		return n
	-- Replaces the current array with a fresh one of size n' whose slots start
	-- out as the default element, then copies over the old elements that fit.
	-- Only the first @min n n'@ elements are copied.  When shrinking, copying
	-- all n old elements would write past the end of the new array.
	resize n' = ArrayT $ do
		a@(MArr n d _) <- get
		a' <- liftST $ newMArr n' d
		liftST $ mapM_ (\ i -> readMArr a i >>= writeMArr a' i) [0 .. min n n' - 1]
		put a'

-- Array operations for 'ArrayM', backed by the 'MArr' held in the internal
-- 'StateT' over 'ST'.
instance MonadArray e (ArrayM e) where
	{-# INLINE unsafeReadAt #-}
	{-# INLINE unsafeWriteAt #-}
	{-# INLINE getSize #-}
	{-# INLINE resize #-}
	-- Unchecked: the index is handed to 'readMArr' without validation.
	unsafeReadAt i = ArrayM $ do
		arr <- get
		lift $ readMArr arr i
	-- Unchecked: the index is handed to 'writeMArr' without validation.
	unsafeWriteAt i x = ArrayM $ do
		arr <- get
		lift $ writeMArr arr i x
	getSize = ArrayM $ do
		MArr n _ _ <- get
		return n
	-- Replaces the current array with a fresh one of size n' whose slots start
	-- out as the default element, then copies over the old elements that fit.
	-- Only the first @min n n'@ elements are copied.  When shrinking, copying
	-- all n old elements would write past the end of the new array.
	resize n' = ArrayM $ do
		a@(MArr n d _) <- get
		a' <- lift $ newMArr n' d
		lift $ mapM_ (\ i -> readMArr a i >>= writeMArr a' i) [0 .. min n n' - 1]
		put a'

-- Exposes the underlying monad's state.  The array itself lives in ArrayT's
-- own internal StateT and is not reachable through this instance.
instance MonadState s m => MonadState s (ArrayT e m) where
	get = lift get
	put newState = lift (put newState)

-- Lifts the underlying monad's reader environment through 'ArrayT'.
-- NOTE(review): 'local' is broken as written.  @(lift . local f . return =<<)@
-- first runs the argument action in the *unmodified* environment, then
-- applies 'local' only to a trivial @return@ of its result.  So @local f m@
-- behaves like @m@ and ignores @f@; e.g. @local f ask@ yields the original
-- environment rather than @f@ applied to it.  A correct lifting must run @m@
-- itself under @local f@, e.g. @local f (ArrayT m) = ArrayT (local f m)@.
-- That requires a @MonadReader r (STT s m)@ instance; confirm one exists
-- before switching.
instance MonadReader r m => MonadReader r (ArrayT e m) where
	ask = lift ask
	local f = (lift . local f . return =<<)

-- Lifts the underlying monad's writer output through 'ArrayT'.
-- NOTE(review): 'listen' and 'pass' are broken as written.  Both run the
-- argument action first, then apply 'listen' / 'pass' only to a trivial
-- @return@ of its result:
--   * @listen m@ pairs m's result with the output of that @return@ (i.e.
--     'mempty'), not with the output m actually produced;
--   * @pass m@ never transforms m's output.  It instead applies the returned
--     function to the empty output of @return@ and emits that.
-- A correct lifting must run @m@ itself under 'listen' / 'pass' in the inner
-- monad (e.g. via a @MonadWriter w (STT s m)@ instance); confirm such an
-- instance exists before switching.
instance MonadWriter w m => MonadWriter w (ArrayT e m) where
	tell = lift . tell
	listen = (lift . listen . return =<<)
	pass = (lift . pass . return =<<)

-- Choice is delegated to the internal StateT, which hands the same array
-- handle to both alternatives.
-- NOTE(review): because the array is mutable, writes made by the left branch
-- are presumably not undone when falling back to the right branch.  Confirm
-- against STT's MonadPlus behaviour.
instance MonadPlus m => MonadPlus (ArrayT e m) where
	mzero = lift mzero
	mplus (ArrayT left) (ArrayT right) = ArrayT (mplus left right)

-- Value recursion delegated to the internal StateT-over-STT monad.
instance MonadFix m => MonadFix (ArrayT e m) where
	mfix f = ArrayT (mfix (\ x -> runArrT (f x)))

-- IO actions are lifted into the underlying monad, then through 'ArrayT'.
instance MonadIO m => MonadIO (ArrayT e m) where
	liftIO io = lift (liftIO io)

-- ST actions of the underlying monad's state thread are lifted into it, then
-- through 'ArrayT'.  This is separate from ArrayT's own internal array thread.
instance MonadST s m => MonadST s (ArrayT e m) where
	liftST st = lift (liftST st)

-- | Allocates a mutable array with the given number of slots, every slot
-- initialised to the given default element.  The default element is also
-- recorded in the returned 'MArr'.  A negative size is rejected with an
-- 'error' rather than being passed on to 'newArray#', which does not
-- validate its size argument.
newMArr :: Int -> e -> ST s (MArr s e)
newMArr size@(I# n#) d
	| size < 0	= error ("Control.Monad.Array.ArrayT.newMArr: negative array size " ++ show size)
	| otherwise	= ST (\ s -> case newArray# n# d s of
		(# s', arr #) -> (# s', MArr size d arr #))

-- | Reads the element stored at the given index.  The index is not checked
-- against the array's size; callers must keep it within @[0, size)@.
readMArr :: MArr s e -> Int -> ST s e
readMArr (MArr _ _ arr#) (I# i#) = ST (\ s -> readArray# arr# i# s)

-- | Overwrites the element at the given index.  The index is not checked
-- against the array's size; callers must keep it within @[0, size)@.
writeMArr :: MArr s e -> Int -> e -> ST s ()
writeMArr (MArr _ _ arr#) (I# i#) x = ST (\ s -> case writeArray# arr# i# x s of s' -> (# s', () #))