{-# LANGUAGE UndecidableInstances, MultiParamTypeClasses, TypeFamilies, FlexibleInstances#-}
module Control.Monad.Array.Class where

import qualified Control.Monad.State.Lazy as LazyS
import qualified Control.Monad.State.Strict as StrictS
import Control.Monad.Reader
import Control.Monad.List
import qualified Control.Monad.Writer.Lazy as LazyW
import qualified Control.Monad.Writer.Strict as StrictW
import Control.Monad.Maybe
import Data.Monoid
import Control.Monad

-- | Abstraction over monads that have access to an underlying mutable array indexed by 'Int's.
--
-- Minimal implementation: 'readAt' or 'unsafeReadAt', 'writeAt' or 'unsafeWriteAt', 'askSize',
-- and 'resize' or 'ensureSize'.  The members of each of those pairs default to one another, so
-- an instance has to override at least one member of every pair; otherwise the defaults are
-- mutually recursive.
class Monad m => MonadArray m where
	-- | Type of the elements stored in the array carried by @m@.
	type ArrayElem m

	-- | Read the element at an index.  The default checks @0 <= i < size@ and calls 'fail' when the check does not hold.
	readAt :: Int -> m (ArrayElem m)
	{-# INLINE readAt #-}

	-- | Read variant that, judging by its name, is meant to skip the bounds check.  The default forwards to 'readAt'.
	unsafeReadAt :: Int -> m (ArrayElem m)
	{-# INLINE unsafeReadAt #-}

	-- | Store an element at an index.  The default performs the same bounds check as 'readAt'.
	writeAt :: Int -> ArrayElem m -> m ()
	{-# INLINE writeAt #-}

	-- | Write variant that, judging by its name, is meant to skip the bounds check.  The default forwards to 'writeAt'.
	unsafeWriteAt :: Int -> ArrayElem m -> m ()
	{-# INLINE unsafeWriteAt #-}

	-- | Store a new element at an index and return the element that was there before.
	replaceAt :: Int -> ArrayElem m -> m (ArrayElem m)
	{-# INLINE replaceAt #-}

	-- | All elements, in index order (by default the second components of 'askAssocs').
	askElems :: m [ArrayElem m]
	{-# INLINE askElems #-}

	-- | All @(index, element)@ pairs for indices @0@ up to @size - 1@.
	askAssocs :: m [(Int, ArrayElem m)]

	-- | Current size of the array.
	askSize :: m Int
	{-# INLINE askSize #-}

	-- | Change the size of the array.  The default forwards to 'ensureSize', so it never shrinks.
	resize :: Int -> m ()
	{-# INLINE resize #-}

	-- | Make the array at least the given size.  The default calls 'resize' only when the current size is smaller.
	ensureSize :: Int -> m ()
	{-# INLINE ensureSize #-}

	-- Checked accessors: validate the index against 'askSize', then defer to
	-- the unchecked variant.
	readAt i = do
		size <- askSize
		if 0 <= i && i < size
			then unsafeReadAt i
			else fail "Index out of bounds"
	writeAt i x = do
		size <- askSize
		if 0 <= i && i < size
			then unsafeWriteAt i x
			else fail "Index out of bounds"

	-- Unchecked accessors fall back to the checked ones.
	unsafeReadAt i = readAt i
	unsafeWriteAt i x = writeAt i x

	-- Every index below the current size is valid, so the unchecked read is used.
	askAssocs = do
		size <- askSize
		forM [0 .. size - 1] $ \ ix -> do
			x <- unsafeReadAt ix
			return (ix, x)
	askElems = askAssocs >>= return . map snd

	-- Growing is expressed through 'resize'; 'resize' in turn defaults to 'ensureSize'.
	ensureSize n = askSize >>= \ current -> unless (current >= n) (resize n)
	resize n = ensureSize n

	replaceAt i x = readAt i >>= \ previous -> writeAt i x >> return previous

-- | Lazy 'LazyS.StateT' on top of an array monad: every array operation is
-- forwarded to the underlying monad with 'lift', and the element type is that
-- of the underlying monad.
instance MonadArray m => MonadArray (LazyS.StateT s m) where
	type ArrayElem (LazyS.StateT s m) = ArrayElem m
	readAt = lift . readAt
	unsafeReadAt = lift . unsafeReadAt
	writeAt i x = lift (writeAt i x)
	unsafeWriteAt i x = lift (unsafeWriteAt i x)
	replaceAt i x = lift (replaceAt i x)
	askElems = lift askElems
	-- Forward 'askAssocs' too, so that an override in the underlying monad is
	-- used rather than the class default (one lifted 'unsafeReadAt' per index).
	askAssocs = lift askAssocs
	askSize = lift askSize
	resize = lift . resize
	ensureSize = lift . ensureSize

-- | Strict 'StrictS.StateT' on top of an array monad: every array operation is
-- forwarded to the underlying monad with 'lift', and the element type is that
-- of the underlying monad.
instance MonadArray m => MonadArray (StrictS.StateT s m) where
	type ArrayElem (StrictS.StateT s m) = ArrayElem m
	readAt = lift . readAt
	unsafeReadAt = lift . unsafeReadAt
	writeAt i x = lift (writeAt i x)
	unsafeWriteAt i x = lift (unsafeWriteAt i x)
	replaceAt i x = lift (replaceAt i x)
	askElems = lift askElems
	-- Forward 'askAssocs' too, so that an override in the underlying monad is
	-- used rather than the class default (one lifted 'unsafeReadAt' per index).
	askAssocs = lift askAssocs
	askSize = lift askSize
	resize = lift . resize
	ensureSize = lift . ensureSize

-- | 'ReaderT' on top of an array monad: every array operation is forwarded to
-- the underlying monad with 'lift', and the element type is that of the
-- underlying monad.
instance MonadArray m => MonadArray (ReaderT r m) where
	type ArrayElem (ReaderT r m) = ArrayElem m
	readAt = lift . readAt
	unsafeReadAt = lift . unsafeReadAt
	writeAt i x = lift (writeAt i x)
	unsafeWriteAt i x = lift (unsafeWriteAt i x)
	replaceAt i x = lift (replaceAt i x)
	askElems = lift askElems
	-- Forward 'askAssocs' too, so that an override in the underlying monad is
	-- used rather than the class default (one lifted 'unsafeReadAt' per index).
	askAssocs = lift askAssocs
	askSize = lift askSize
	resize = lift . resize
	ensureSize = lift . ensureSize

-- | Strict 'StrictW.WriterT' on top of an array monad: every array operation is
-- forwarded to the underlying monad with 'lift', and the element type is that
-- of the underlying monad.
instance (Monoid w, MonadArray m) => MonadArray (StrictW.WriterT w m) where
	type ArrayElem (StrictW.WriterT w m) = ArrayElem m
	readAt = lift . readAt
	unsafeReadAt = lift . unsafeReadAt
	writeAt i x = lift (writeAt i x)
	unsafeWriteAt i x = lift (unsafeWriteAt i x)
	replaceAt i x = lift (replaceAt i x)
	askElems = lift askElems
	-- Forward 'askAssocs' too, so that an override in the underlying monad is
	-- used rather than the class default (one lifted 'unsafeReadAt' per index).
	askAssocs = lift askAssocs
	askSize = lift askSize
	resize = lift . resize
	ensureSize = lift . ensureSize

-- | Lazy 'LazyW.WriterT' on top of an array monad: every array operation is
-- forwarded to the underlying monad with 'lift', and the element type is that
-- of the underlying monad.
instance (Monoid w, MonadArray m) => MonadArray (LazyW.WriterT w m) where
	type ArrayElem (LazyW.WriterT w m) = ArrayElem m
	readAt = lift . readAt
	unsafeReadAt = lift . unsafeReadAt
	writeAt i x = lift (writeAt i x)
	unsafeWriteAt i x = lift (unsafeWriteAt i x)
	replaceAt i x = lift (replaceAt i x)
	askElems = lift askElems
	-- Forward 'askAssocs' too, so that an override in the underlying monad is
	-- used rather than the class default (one lifted 'unsafeReadAt' per index).
	askAssocs = lift askAssocs
	askSize = lift askSize
	resize = lift . resize
	ensureSize = lift . ensureSize

-- | 'MaybeT' on top of an array monad: every array operation is forwarded to
-- the underlying monad with 'lift', and the element type is that of the
-- underlying monad.
instance MonadArray m => MonadArray (MaybeT m) where
	type ArrayElem (MaybeT m) = ArrayElem m
	readAt = lift . readAt
	unsafeReadAt = lift . unsafeReadAt
	writeAt i x = lift (writeAt i x)
	unsafeWriteAt i x = lift (unsafeWriteAt i x)
	replaceAt i x = lift (replaceAt i x)
	askElems = lift askElems
	-- Forward 'askAssocs' too, so that an override in the underlying monad is
	-- used rather than the class default (one lifted 'unsafeReadAt' per index).
	askAssocs = lift askAssocs
	askSize = lift askSize
	resize = lift . resize
	ensureSize = lift . ensureSize

-- | 'ListT' on top of an array monad: every array operation is forwarded to
-- the underlying monad with 'lift', and the element type is that of the
-- underlying monad.
instance MonadArray m => MonadArray (ListT m) where
	type ArrayElem (ListT m) = ArrayElem m
	readAt = lift . readAt
	unsafeReadAt = lift . unsafeReadAt
	writeAt i x = lift (writeAt i x)
	unsafeWriteAt i x = lift (unsafeWriteAt i x)
	replaceAt i x = lift (replaceAt i x)
	askElems = lift askElems
	-- Forward 'askAssocs' too, so that an override in the underlying monad is
	-- used rather than the class default (one lifted 'unsafeReadAt' per index).
	askAssocs = lift askAssocs
	askSize = lift askSize
	resize = lift . resize
	ensureSize = lift . ensureSize