{-# LANGUAGE MagicHash, UnboxedTuples, RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE Unsafe #-}
module Control.Monad.STE.Internal
(
  STE(..)
  ,unSTE
  ,STERep
  ,STEret(..)
  ,runSTE
  ,throwSTE
  ,handleSTE
  
  ,unsafeInterleaveSTE
  ,liftSTE
  ,fixSTE
  ,runBasicSTE
  ,RealWorld
  ,unsafeIOToSTE
  ,unsafeSTEToIO
  )
  where
#if MIN_VERSION_ghc_prim(0,5,0)
import GHC.Prim (State#, raiseIO#, catch#)
#else
import GHC.Prim (State#, raiseIO#, catch#, realWorld#)
#endif
import qualified Control.Monad.Catch as CMC
import Control.Exception as Except
import Control.Monad (ap)
import qualified Control.Monad.Fix as MF
import Control.Monad.Primitive
import Data.Typeable
import Unsafe.Coerce (unsafeCoerce)
import GHC.IO(IO(..))
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative
#endif
#if MIN_VERSION_ghc_prim(0,5,0)
import GHC.Magic(runRW#)
#endif
-- | State-threading monad with an extra phantom parameter @e@: the type of
-- error value that 'throwSTE' can raise and that 'runSTE' / 'handleSTE'
-- report as @Left@. The wrapped value is a raw state-token transformer
-- ('STERep'); @e@ does not appear in the representation.
newtype STE e s a = STE  (STERep s a)
-- | Project out the raw state-token function wrapped by an 'STE' action.
unSTE :: STE e s a -> STERep s a
unSTE (STE rep) = rep
-- | Raw representation of an 'STE' action: consumes a state token and
-- returns the successor token paired with the result (unboxed tuple).
type STERep s a = State# s -> (# State# s, a #)
-- | Boxed (lifted) counterpart of the unboxed pair produced by 'STERep'.
-- 'liftSTE' returns this so the outcome can be bound lazily, as 'fixSTE'
-- does.
data STEret s a = STEret (State# s) a
-- | Run an 'STE' action on the given state token and package the outcome
-- as a boxed 'STEret' instead of an unboxed tuple.
liftSTE :: STE e s a -> State# s -> STEret s a
liftSTE (STE rep) s0 =
  case rep s0 of
    (# s1, val #) -> STEret s1 val
{-# NOINLINE unsafeInterleaveSTE #-}
-- | Defer an 'STE' action. The caller immediately gets back its own,
-- unchanged state token together with a thunk; only when that thunk is
-- demanded is the wrapped action run, on the captured token. The state
-- token produced by the deferred run is discarded. Unsafe: the effects
-- happen at an unpredictable time, or never if the result is not forced.
unsafeInterleaveSTE :: STE e s a -> STE e s a
unsafeInterleaveSTE (STE act) = STE deferredStep
  where
    deferredStep s0 = (# s0, lazyResult #)
      where
        -- lazy binding: 'act' is only run when lazyResult is forced
        lazyResult = case act s0 of (# _, v #) -> v
-- | Monadic fixpoint for 'STE' (backs the 'MF.MonadFix' instance).
-- @k r@ is run on the current state token, where @r@ is a lazy reference
-- to the result of that very run. 'liftSTE' boxes the outcome into an
-- 'STEret' so that the lazy pattern binding below can tie the knot.
-- Forcing @r@ before @k@ has produced it diverges.
fixSTE :: (a -> STE e s a) -> STE e s a
fixSTE k = STE $ \ s ->
    let ans       = liftSTE (k r) s
        -- lazy pattern binding: r is extracted from ans only on demand
        STEret _ r = ans
    in
    case ans of STEret s' x -> (# s', x #)
-- | Map over the result of an action; the state token is threaded through
-- untouched.
instance Functor (STE e s) where
    fmap f (STE act) = STE step
      where
        step s0 = case act s0 of
          (# s1, x #) -> (# s1, f x #)
-- | Applicative interface for 'STE'. 'pure' returns a value without
-- touching the state token. '(*>)' and '(<*>)' are expressed through the
-- monadic bind below.
instance Applicative (STE e s) where
    {-# INLINE pure  #-}
    {-# INLINE (<*>) #-}
    {-# INLINE (*> ) #-}
    pure x = STE (\ s -> (# s, x #))
    (*>) = \ m k ->  m >>= \ _ -> k
    (<*>) = ap

-- | Sequencing: the state token produced by the first action is fed to the
-- action chosen by the continuation.
--
-- 'return' and '(>>)' are the canonical aliases of 'pure' and '(*>)'.
-- Defining them independently triggers @-Wnoncanonical-monad-instances@
-- on modern GHC and risks the two definitions drifting apart.
instance Monad (STE e s) where
    {-# INLINE return #-}
    {-# INLINE (>>)   #-}
    {-# INLINE (>>=)  #-}
    return = pure
    (>>) = (*>)
    (STE m) >>= k
      = STE (\ s ->
        case (m s) of { (# new_s, r #) ->
        case (k r) of { STE k2 ->
        (k2 new_s) }})
-- | 'MF.mfix' for 'STE' is the knot-tying 'fixSTE'.
instance MF.MonadFix (STE e s) where
  mfix f = fixSTE f
-- | 'PrimMonad' instance: the primitive state token is the @s@ parameter,
-- and a raw state-token function is wrapped directly in 'STE'.
instance PrimMonad (STE e s) where
  type PrimState (STE e s) = s
  {-# INLINE primitive #-}
  primitive = STE
-- | 'PrimBase' instance: exposes the raw state-token function of an action.
instance PrimBase (STE e s) where
  {-# INLINE internal #-}
  internal (STE rep) = \ tok -> rep tok
-- | 'CMC.throwM' is available when the error parameter is
-- 'Except.SomeException'. The thrown exception is wrapped with
-- 'toException' and raised via 'throwSTE'.
instance (Except.SomeException ~ err) =>  CMC.MonadThrow (STE err s) where
  throwM = throwSTE . toException
{-# INLINE runSTE #-}
-- | Run an 'STE' computation to completion and pass its outcome to a
-- continuation. The outcome is @Left err@ if the computation aborted via
-- 'throwSTE', and @Right x@ otherwise. Exceptions other than those raised
-- by 'throwSTE' are not converted and propagate.
runSTE ::  (forall s. STE e s a) -> (Either e a  -> b) -> b
runSTE action k = k (runBasicSTE (privateCatchSTE action))
{-# INLINE handleSTE #-}
-- | 'runSTE' with its arguments swapped: the continuation that receives the
-- @Either e a@ outcome comes first, the computation second.
handleSTE :: (Either e a -> b) -> (forall s. STE e s a)  -> b
handleSTE k action = runSTE action k
-- | Abort the current 'STE' computation with the error value @err@.
-- The value is coerced to @()@, boxed inside an 'STException', and raised
-- as an IO exception. The handler installed by 'runSTE' / 'handleSTE'
-- recognises the 'STException' and reports the value as @Left@. If no such
-- handler encloses the throw, the 'STException' escapes as an ordinary
-- exception.
throwSTE :: forall e s a .  e -> STE e s a
throwSTE err = unsafeIOToSTE (IO (raiseIO# carrier))
  where
    carrier = toException (STException (Box (unsafeCoerce err)))
-- | Run an action, turning an 'STException' (as raised by 'throwSTE') into
-- @Left@ and a normal result into @Right@. Any other exception is re-raised
-- unchanged.
--
-- NOTE(review): the error payload is recovered with 'unsafeCoerce'. This is
-- only sound if every 'STException' reaching this handler was thrown at the
-- same error type @e@, i.e. by 'throwSTE' within this computation.
privateCatchSTE:: forall e s b . STE e s b  -> STE e s (Either e b)
privateCatchSTE = \ steAct  ->
      unsafeIOToSTE $
        -- the STE state-token function is coerced to a RealWorld one so
        -- that catch# can install handler' around it
        IO  (catch# (unsafeCoerce $ unSTE $ fmap Right steAct) handler')
  where
    
    -- Recognise our own STException and coerce its payload back to @e@;
    -- re-raise every other exception.
    handler' :: SomeException -> STERep RealWorld (Either e b)
    handler' e = case (fromException e) of
        Just (STException (Box val)) -> \ s -> (# s , Left $ unsafeCoerce val #)
        Nothing -> raiseIO# e
-- | Reinterpret an 'IO' action as an 'STE' action by coercing its
-- state-token function. Unsafe: the IO effects are not confined to @s@.
unsafeIOToSTE        :: IO a -> STE e s a
unsafeIOToSTE (IO ioRep) = STE (\ tok -> unsafeCoerce ioRep tok)
-- | Reinterpret an 'STE' action as an 'IO' action by coercing its
-- state-token function to one over 'RealWorld'. Errors raised with
-- 'throwSTE' are not caught here.
unsafeSTEToIO :: STE e s a -> IO a
unsafeSTEToIO act = case act of
  STE rep -> IO (unsafeCoerce rep)
#if MIN_VERSION_ghc_prim(0,5,0)
-- | Run an 'STE' computation and return its result, discarding the final
-- state token. No exceptions are caught here; 'runSTE' wraps the action
-- with 'privateCatchSTE' before calling this. This variant uses 'runRW#'.
runBasicSTE :: (forall s. STE e s a) -> a
runBasicSTE (STE st_rep) = case runRW# st_rep of (# _, a #) -> a
{-# INLINE runBasicSTE #-}
#else
-- | Run an 'STE' computation and return its result, discarding the final
-- state token (older ghc-prim without 'runRW#'). The raw representation is
-- handed to the NOINLINE helper 'runSTERep'.
runBasicSTE :: (forall s. STE e s a) -> a
runBasicSTE st = runSTERep (case st of { STE st_rep -> st_rep })
{-# INLINE runBasicSTE #-}
-- | Apply a raw state-token function to 'realWorld#' and keep the result.
runSTERep :: (forall s. STERep  s a) -> a
runSTERep st_rep = case st_rep realWorld# of
                        (# _, r #) -> r
-- NOTE(review): NOINLINE presumably prevents GHC from floating or sharing
-- the realWorld# application across call sites (as old base runST did) —
-- confirm.
{-# NOINLINE runSTERep #-}
#endif
-- | Single-field lazy box. 'STException' uses it to carry the error value,
-- which 'throwSTE' coerces to @()@.
#if MIN_VERSION_base(4,8,0)
-- NOTE(review): NOUNPACK presumably keeps the field a plain pointer even
-- under -funbox-strict-fields; the field is lazy, so confirm it is needed.
data Box a = Box {-# NOUNPACK #-} a
#else
data Box a = Box  a
#endif
-- | Untyped exception carrying a 'throwSTE' error value, coerced to @()@
-- and stored in a 'Box'. The handler installed by 'runSTE' recognises it
-- and converts it back to a typed @Left@.
data STException = STException (Box ())
  deriving (Typeable)

instance Show STException where
  showsPrec _ (STException _) =
    showString "STException(..)! did you use the Unsafe/internal STE interface?"

instance Exception STException