{-# LANGUAGE TypeFamilies, GeneralizedNewtypeDeriving, UnboxedTuples, MagicHash, Rank2Types, FlexibleInstances, MultiParamTypeClasses, UndecidableInstances #-}

-- | A monad transformer that abstracts away the implementation details of manipulating a mutable array inside a monad.  In general, this is likely to be the most efficient array transformer implementation made available in this library, but improper use of its unchecked operations may lead to segfaults.
module Control.Monad.Array.ArrayT (ArrayM, ArrayT, runArrayM, runArrayM_, runArrayT, runArrayT_) where

import Control.Applicative (Alternative, Applicative)
import Control.Monad
import Control.Monad.Fix
import Control.Monad.ST
import GHC.Exts
import GHC.ST(ST(..))

import Control.Monad.Reader.Class
import Control.Monad.State.Strict
import Control.Monad.Trans
import Control.Monad.Writer.Class

import Control.Monad.Array.Class
import Control.Monad.ST.Class

-- | A primitive mutable array bundled with its bookkeeping, in field order:
-- the element count (strict and unpacked), the default element the array
-- was allocated with ('resize' passes it on when allocating a replacement),
-- and the raw 'MutableArray#' itself.
data MArr s e = MArr {-# UNPACK #-} !Int e (MutableArray# s e)

-- | Monad controlling access to a single underlying mutable array.  This is
-- 'ArrayT' specialised to the strict @'ST' s@ monad; run it with 'runArrayM'.
type ArrayM s e = ArrayT e (ST s)
-- | Monad transformer that grants computations in the underlying monad @m@
-- access to one mutable array whose elements have type @e@.  The array is
-- threaded through as strict 'StateT' state, and every derived instance
-- below is inherited from that 'StateT' layer.
newtype ArrayT e m a = ArrayT {runArrT :: StateT (MArr (StateThread m) e) m a}
  -- Functor/Applicative and Alternative are mandatory superclasses of
  -- Monad and MonadPlus since base-4.8 (GHC 7.10); deriving them here keeps
  -- the Monad/MonadPlus deriving valid on modern compilers.
  deriving (Functor, Applicative, Alternative, Monad, MonadFix, MonadPlus, MonadReader r, MonadWriter w)

-- | Lifting runs the action in the underlying monad through the inner
-- 'StateT', leaving the array state untouched.
instance MonadTrans (ArrayT e) where
  lift action = ArrayT (lift action)

-- | 'get' and 'put' address the underlying monad's state, not the array:
-- both are forwarded through the inner 'StateT' with 'lift'.
instance MonadState s m => MonadState s (ArrayT e m) where
  get = ArrayT (lift get)
  put x = ArrayT (lift (put x))

-- | Run an 'ArrayM' computation purely, over a freshly allocated array of
-- @n@ elements that all start out as @d@, and return its result.
runArrayM :: Int -> e -> (forall s . ArrayM s e a) -> a
runArrayM n d m = runST (runArrayT n d m)

-- | Like 'runArrayM', but every slot starts out as a placeholder that
-- raises an error if it is forced before being overwritten.
runArrayM_ :: Int -> (forall s . ArrayM s e a) -> a
runArrayM_ n m = runArrayM n emptyElement m

-- | Run an 'ArrayT' computation in any 'MonadST' monad, starting from a
-- freshly allocated array of @n@ elements that all start out as @d@.
-- The final array is discarded; only the computation's result is returned.
runArrayT :: (MonadST m, Monad m) => Int -> e -> ArrayT e m a -> m a
runArrayT n d m = do
  initial <- liftST (newMArr n d)
  evalStateT (runArrT m) initial

-- | Like 'runArrayT', but every slot starts out as a placeholder that
-- raises an error if it is forced before being overwritten.
runArrayT_ :: (MonadST m, Monad m) => Int -> ArrayT e m a -> m a
runArrayT_ n m = runArrayT n emptyElement m

-- | Placeholder used to fill arrays created by 'runArrayM_' and
-- 'runArrayT_'.  Forcing it raises an error.
emptyElement :: e
emptyElement = error message
  where
    message = "Undefined array element"

-- | Array operations act directly on the primitive array held in the
-- 'StateT' state.  'unsafeReadAt' and 'unsafeWriteAt' perform no bounds
-- checking (the underlying primops do not check either), so an
-- out-of-range index is undefined behaviour.
instance (MonadST m, Monad m) => MonadArray (ArrayT e m) where
  {-# SPECIALIZE instance MonadArray (ArrayM s e) #-}
  {-# INLINE unsafeReadAt #-}
  {-# INLINE unsafeWriteAt #-}
  {-# INLINE askSize #-}
  {-# INLINE resize #-}

  type ArrayElem (ArrayT e m) = e

  unsafeReadAt i = ArrayT $ do
    arr <- get
    liftST $ readMArr arr i

  unsafeWriteAt i x = ArrayT $ do
    arr <- get
    liftST $ writeMArr arr i x

  askSize = ArrayT $ do
    MArr n _ _ <- get
    return n

  -- Allocate a replacement array of size n', filled with the original
  -- default element, then carry over the existing contents.  Only indices
  -- below (min n n') exist in both arrays: copying all n old elements when
  -- shrinking would write past the end of the new array through the
  -- unchecked writeArray#.
  resize n' = ArrayT $ do
    old@(MArr n d _) <- get
    fresh <- liftST $ newMArr n' d
    liftST $ mapM_ (\i -> readMArr old i >>= writeMArr fresh i) [0 .. min n n' - 1]
    put fresh

-- | Allocate a fresh primitive array of @n@ elements, each initialised to
-- @d@.  The default @d@ is also stored in the resulting 'MArr'.
--
-- 'newArray#' does not validate its size argument, so a negative @n@ is
-- rejected here with an explicit error instead of reaching the primop.
newMArr :: Int -> e -> ST s (MArr s e)
newMArr n@(I# n#) d
  | n < 0 = error ("Control.Monad.Array.ArrayT.newMArr: negative array size " ++ show n)
  | otherwise = ST $ \s -> case newArray# n# d s of
      (# s', arr #) -> (# s', MArr n d arr #)

-- | Read the element at the given index.  No bounds check is performed.
readMArr :: MArr s e -> Int -> ST s e
readMArr (MArr _ _ marr#) (I# idx#) = ST $ \st ->
  case readArray# marr# idx# st of
    (# st', v #) -> (# st', v #)

-- | Overwrite the element at the given index.  No bounds check is performed.
writeMArr :: MArr s e -> Int -> e -> ST s ()
writeMArr (MArr _ _ marr#) (I# idx#) v = ST $ \st ->
  case writeArray# marr# idx# v st of
    st' -> (# st', () #)