{-# LANGUAGE CPP, MagicHash, UnboxedTuples, Rank2Types, FlexibleInstances,
    MultiParamTypeClasses, UndecidableInstances, RecursiveDo #-}
{- |
   Module      :  Control.Monad.ST.Trans
   Copyright   :  Josef Svenningsson 2008-2010
                  (c) The University of Glasgow, 1994-2000
   License     :  BSD
 
   Maintainer  :  josef.svenningsson@gmail.com
   Stability   :  experimental
   Portability :  non-portable (GHC Extensions)

   This library provides a monad transformer version of the ST monad.

   Warning! This monad transformer should not be used with monads that
   can contain multiple answers, like the list monad. The reason is that
   the state token will be duplicated across the different answers and
   this causes Bad Things to happen (such as loss of referential
   transparency). Safe monads include the monads State, Reader, Writer,
   Maybe and combinations of their corresponding monad transformers.

-}
module Control.Monad.ST.Trans(
      -- * The ST Monad Transformer
      STT,
      runST,
      -- * Mutable references
      STRef,
      newSTRef,
      readSTRef,
      writeSTRef,
      -- * Mutable arrays
      STArray,
      newSTArray,
      readSTArray,
      writeSTArray,
      boundsSTArray,
      numElementsSTArray,
      freezeSTArray,
      thawSTArray,
      runSTArray,
      -- * Unsafe Operations
      unsafeIOToSTT,
      unsafeSTToIO,
      unsafeSTRefToIORef,
      unsafeIORefToSTRef
      )where

import GHC.Base
import GHC.Arr (Ix(..), safeRangeSize, safeIndex, 
                Array(..), arrEleBottom)

import Control.Monad.ST.Trans.Internal

import Control.Monad.Fix
import Control.Monad.Trans
import Control.Monad.Error.Class
import Control.Monad.Reader.Class
import Control.Monad.State.Class
import Control.Monad.Writer.Class

import Control.Applicative

import Data.IORef

import Unsafe.Coerce
import System.IO.Unsafe

-- Compatibility shim: when compiled with a GHC older than 7.8 this module
-- supplies its own isTrue# that is the identity on Bool, so the uses of
-- isTrue# further down in this file compile on either side of that version.
-- NOTE(review): presumably needed because the comparison primops returned
-- Bool on those older compilers; confirm against the GHC 7.8 release notes.
#if __GLASGOW_HASKELL__ < 708
isTrue# :: Bool -> Bool
isTrue# x = x
#endif

-- Sequencing runs each step in the underlying monad and hands the state
-- token produced by one step to the next; failure is delegated to @m@.
instance Monad m => Monad (STT s m) where
  return x = STT (\tok -> return (STTRet tok x))
  STT step >>= next = STT (\tok0 ->
    step tok0 >>= \(STTRet tok1 x) -> unSTT (next x) tok1)
  fail = lift . fail

-- Lifting runs the inner action and passes the state token through untouched.
instance MonadTrans (STT s) where
  lift act = STT (\tok -> act >>= \x -> return (STTRet tok x))
      
-- Unwrap an 'STT' computation and apply it to the given state token.
liftSTT :: STT s m a -> State# s -> m (STTRet s a)
liftSTT stt tok = case stt of
  STT run -> run tok

-- The fixed point is taken in the underlying monad: the value component of
-- the result is fed back into the body, while the state token is the one
-- supplied from outside.
instance (MonadFix m) => MonadFix (STT s m) where
  mfix body = STT (\tok -> mdo
    result@(STTRet _ val) <- liftSTT (body val) tok
    return result)

-- Maps over the result value and leaves the state token alone.
instance Functor (STTRet s) where
  fmap g ret = case ret of
    STTRet tok x -> STTRet tok (g x)

-- Maps through both layers: the underlying functor @m@, then 'STTRet'.
instance Functor m => Functor (STT s m) where
  fmap f (STT run) = STT (\tok -> fmap (fmap f) (run tok))

-- Runs the function computation first, then the argument computation,
-- threading the state token from one to the other.
instance (Monad m, Functor m) => Applicative (STT s m) where
  pure x = STT (\tok -> return (STTRet tok x))
  STT mf <*> STT mx = STT (\tok0 ->
    mf tok0 >>= \(STTRet tok1 f) ->
    mx tok1 >>= \(STTRet tok2 x) ->
    return (STTRet tok2 (f x)))

-- | Mutable references.
--
-- An @STRef s a@ belongs to state thread @s@ and holds a value of type @a@.
-- It is a single-constructor box around the primitive 'MutVar#'.
data STRef s a = STRef (MutVar# s a)

-- | Allocate a fresh reference holding the given initial value.
newSTRef :: Monad m => a -> STT s m (STRef s a)
newSTRef initial = STT (\tok0 ->
  case newMutVar# initial tok0 of
    (# tok1, cell #) -> return (STTRet tok1 (STRef cell)))

-- | Fetch the value currently stored in a reference.
readSTRef :: Monad m => STRef s a -> STT s m a
readSTRef (STRef cell) = STT (\tok0 ->
  case readMutVar# cell tok0 of
    (# tok1, x #) -> return (STTRet tok1 x))

-- | Overwrite the value stored in a reference.
writeSTRef :: Monad m => STRef s a -> a -> STT s m ()
writeSTRef (STRef cell) val = STT (\tok0 ->
  -- The write is performed by the case scrutinee, before 'return' is called.
  case writeMutVar# cell val tok0 of
    tok1 -> return (STTRet tok1 ()))

-- Two references are equal exactly when they wrap the same mutable variable.
instance Eq (STRef s a) where
  (==) (STRef lhs) (STRef rhs) = isTrue# (sameMutVar# lhs rhs)

-- | Executes a computation in the 'STT' monad transformer.
--
-- The computation is started on 'realWorld#'; the final state token is
-- discarded and only the result is returned in the underlying monad @m@.
runST :: Monad m => (forall s. STT s m a) -> m a
runST m = do
  STTRet _ a <- unSTT m realWorld#
  return a
-- Every run starts from the same 'realWorld#' token.  If this function were
-- inlined, the optimiser could see that shared token at each call site and
-- might float or share allocations between runs that must stay independent,
-- so inlining is switched off.
{-# NOINLINE runST #-}

-- Instances of other monad classes

-- Errors are thrown in the underlying monad.  A handler is started with the
-- state token that was current when the protected computation began.
instance MonadError e m => MonadError e (STT s m) where
  throwError = lift . throwError
  catchError (STT body) handler = STT (\tok ->
    catchError (body tok) (\err -> unSTT (handler err) tok))

-- The environment lives in the underlying monad; 'local' is applied to the
-- inner action once the state token has been supplied.
instance MonadReader r m => MonadReader r (STT s m) where
  ask = lift ask
  local adjust (STT body) = STT (\tok -> local adjust (body tok))

-- The user-visible state is that of the underlying monad; both operations
-- are simply lifted.
instance MonadState s m => MonadState s (STT s' m) where
  get = lift get
  put = lift . put

-- Writer operations are delegated to the underlying monad; 'listen' and
-- 'pass' rearrange the result so that the state token stays in 'STTRet'.
instance MonadWriter w m => MonadWriter w (STT s m) where
  tell = lift . tell
  listen (STT body) = STT (\tok0 ->
    listen (body tok0) >>= \(STTRet tok1 x, out) ->
    return (STTRet tok1 (x, out)))
  pass (STT body) = STT (\tok0 -> pass (
    body tok0 >>= \(STTRet tok1 (x, f)) ->
    return (STTRet tok1 x, f)))

-- Mutable arrays. See the definition in GHC.Arr

-- | Mutable arrays.
--
-- The constructor stores, in order: the lower bound, the upper bound, the
-- number of elements, and the underlying primitive 'MutableArray#'.  The two
-- bounds and the element count are strict fields.
data STArray s i e = STArray !i !i !Int (MutableArray# s e)

-- Arrays are compared by identity of the underlying mutable array; bounds
-- and size are ignored.
instance Eq (STArray s i e) where
  (==) (STArray _ _ _ lhs) (STArray _ _ _ rhs) =
    isTrue# (sameMutableArray# lhs rhs)

-- | Allocate a mutable array with the given bounds, every slot holding the
-- supplied initial element.
newSTArray :: (Ix i, Monad m) => (i,i) -> e -> STT s m (STArray s i e)
newSTArray (lo, hi) initial = STT (\tok0 ->
  case safeRangeSize (lo, hi) of
    size@(I# size#) ->
      case newArray# size# initial tok0 of
        (# tok1, fresh #) -> return (STTRet tok1 (STArray lo hi size fresh)))

-- | The pair of lowest and highest indices the array was created with.
boundsSTArray :: STArray s i e -> (i,i)
boundsSTArray arr = case arr of
  STArray lo hi _ _ -> (lo, hi)

-- | The element count recorded in the array.
numElementsSTArray :: STArray s i e -> Int
numElementsSTArray arr = case arr of
  STArray _ _ size _ -> size

-- | Read the element stored at an index.  The index is translated to a
-- flat offset with 'safeIndex' before the underlying read.
readSTArray :: (Ix i, Monad m) => STArray s i e -> i -> STT s m e
readSTArray arr ix = case arr of
  STArray lo hi size _ ->
    unsafeReadSTArray arr (safeIndex (lo, hi) size ix)

-- Read the slot at a flat offset; the offset is used as given, without
-- any check against the array size.
unsafeReadSTArray :: (Ix i, Monad m) => STArray s i e -> Int -> STT s m e
unsafeReadSTArray (STArray _ _ _ slots) (I# off) = STT (\tok0 ->
  case readArray# slots off tok0 of
    (# tok1, x #) -> return (STTRet tok1 x))

-- | Store an element at an index.  The index is translated to a flat
-- offset with 'safeIndex' before the underlying write.
writeSTArray :: (Ix i, Monad m) => STArray s i e -> i -> e -> STT s m ()
writeSTArray arr ix val = case arr of
  STArray lo hi size _ ->
    unsafeWriteSTArray arr (safeIndex (lo, hi) size ix) val

-- Write the slot at a flat offset; the offset is used as given, without
-- any check against the array size.
unsafeWriteSTArray :: (Ix i, Monad m) => STArray s i e -> Int -> e -> STT s m ()
unsafeWriteSTArray (STArray _ _ _ slots) (I# off) val = STT (\tok0 ->
  -- The write is performed by the case scrutinee, before 'return' is called.
  case writeArray# slots off val tok0 of
    tok1 -> return (STTRet tok1 ()))

-- | Copy a mutable array and turn it into an immutable array.
--
-- A new mutable array of the same size is allocated, every element is
-- copied into it, and that copy is then frozen in place.  The array passed
-- in is only read.
freezeSTArray :: (Ix i,Monad m) => STArray s i e -> STT s m (Array i e)
freezeSTArray (STArray l u n@(I# n#) marr#) = STT $ \s1# ->
    -- Allocate the destination, initially filled with 'arrEleBottom'.
    case newArray# n# arrEleBottom s1#  of { (# s2#, marr'# #) ->
    -- copy i# s: copies slots i# .. n#-1, threading the state token, and
    -- returns the token once i# reaches n#.
    let copy i# s3# | isTrue# (i# ==# n#) = s3#
                    | otherwise           =
            case readArray# marr# i# s3# of { (# s4#, e #) ->
            case writeArray# marr'# i# e s4# of { s5# ->
            copy (i# +# 1#) s5# }} in
    case copy 0# s2#                    of { s3# ->
    -- Freezing the private copy in place is fine: nothing else refers to it.
    case unsafeFreezeArray# marr'# s3#  of { (# s4#, arr# #) ->
    return (STTRet s4# (Array l u n arr# )) }}}

-- Turn a mutable array into an immutable one without copying.  The result
-- is built from the very same underlying array, keeping bounds and size.
unsafeFreezeSTArray :: (Ix i, Monad m) => STArray s i e -> STT s m (Array i e)
unsafeFreezeSTArray (STArray lo hi size slots) = STT (\tok0 ->
  case unsafeFreezeArray# slots tok0 of
    (# tok1, frozen #) -> return (STTRet tok1 (Array lo hi size frozen)))

-- | Copy an immutable array and turn it into a mutable array.
--
-- A new mutable array of the same size is allocated and every element of
-- the immutable array is written into it; the result keeps the same bounds
-- and element count.
thawSTArray :: (Ix i, Monad m) => Array i e -> STT s m (STArray s i e)
thawSTArray (Array l u n@(I# n#) arr#) = STT $ \s1# ->
    -- Allocate the destination, initially filled with 'arrEleBottom'.
    case newArray# n# arrEleBottom s1#  of { (# s2#, marr# #) ->
    -- copy i# s: copies slots i# .. n#-1, threading the state token, and
    -- returns the token once i# reaches n#.
    let copy i# s3# | isTrue# (i# ==# n#) = s3#
                    | otherwise           =
            case indexArray# arr# i#    of { (# e #) ->
            case writeArray# marr# i# e s3# of { s4# ->
            copy (i# +# 1#) s4# }} in
    case copy 0# s2#                    of { s3# ->
    return (STTRet s3# (STArray l u n marr# )) }}

-- Turn an immutable array into a mutable one without copying.  The result
-- is built from the very same underlying array, keeping bounds and size.
unsafeThawSTArray :: (Ix i, Monad m) => Array i e -> STT s m (STArray s i e)
unsafeThawSTArray (Array lo hi size frozen) = STT (\tok0 ->
  case unsafeThawArray# frozen tok0 of
    (# tok1, slots #) -> return (STTRet tok1 (STArray lo hi size slots)))

-- | Build an array with mutable operations and hand it back as an
-- immutable array.  The array produced by the computation is frozen in
-- place rather than copied.
runSTArray :: (Ix i, Monad m)
           => (forall s . STT s m (STArray s i e))
           -> m (Array i e)
runSTArray build = runST (do
  marr <- build
  unsafeFreezeSTArray marr)


-- Embed an 'IO' action into 'STT'.  The action is run with
-- 'unsafePerformIO' and its result is forced (via '$!') before being
-- wrapped with 'return'.
--
-- NOTE(review): the action is not threaded through the state token, so it
-- presumably runs when this 'STT' value is first evaluated rather than each
-- time the computation is executed -- confirm callers do not rely on it
-- being re-run or on its ordering relative to other 'STT' effects.
{-# NOINLINE unsafeIOToSTT #-} 
unsafeIOToSTT :: (Monad m) => IO a -> STT s m a
unsafeIOToSTT m = return $! unsafePerformIO m

-- Run an 'STT' computation over 'IO' directly in 'IO'.  The computation is
-- coerced so that it fits the rank-2 argument type of 'runST'.
unsafeSTToIO :: STT s IO a -> IO a
unsafeSTToIO action = runST (unsafeCoerce action)

-- Reinterpret an 'STRef' as an 'IORef' by coercion.  This relies on the two
-- types having the same internal representation, which is expected to hold
-- but is not checked.
unsafeSTRefToIORef  :: STRef s a -> IORef a
unsafeSTRefToIORef = unsafeCoerce

-- Reinterpret an 'IORef' as an 'STRef' by coercion; the inverse direction
-- of 'unsafeSTRefToIORef', with the same unchecked representation
-- assumption.
unsafeIORefToSTRef :: IORef a -> STRef s a
unsafeIORefToSTRef = unsafeCoerce