{-# LANGUAGE DataKinds, KindSignatures, ScopedTypeVariables,
    MultiParamTypeClasses, FlexibleContexts, FlexibleInstances, TypeFamilies,
    TypeOperators, RankNTypes, MagicHash, GeneralizedNewtypeDeriving #-}

-- | State bag monad transformer which runs on a PrimMonad stack.
module Control.Monad.Trans.StateBag.Primitive (
    StateBaggerT,
    runBagger,
    addItem,
    topItem,
    stackItem,
    StateBagT,
    makeBag,
    getItem,
    putItem,
    modifyItemM,
    ElementCount(),
    ElementIndex(),
) where

import Control.Applicative
import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.Reader
import Control.Monad.Trans.State
import Control.Monad.Trans.StateBag.Internal
import Control.Monad.IO.Class
import Control.Monad.Primitive
import Data.Proxy
import GHC.Prim (Any, unsafeCoerce#)
import qualified Data.Vector.Mutable as V

newtype BagImpl s (bag :: [*]) = BagImpl (V.MVector s Any)

-- | Monad transformer for building state bags. 
newtype StateBaggerT full (bag :: [*]) m a
    = StateBaggerT (ReaderT (BagImpl (PrimState m) full) m a)
    deriving (Functor, Applicative, Monad, MonadIO)

instance MonadTrans (StateBaggerT full bag) where
    lift = StateBaggerT . lift

instance (PrimMonad m) => PrimMonad (StateBaggerT full bag m) where
  type PrimState (StateBaggerT full bag m) = PrimState m
  primitive = lift . primitive

-- | Run an empty state bagger on top of a monad stack.
runBagger :: forall full m a. (PrimMonad m, ElementCount full) =>
    StateBaggerT full '[] m a -> m a
runBagger (StateBaggerT r) = do
    vec <- V.new $ elemCount (Proxy :: Proxy full)
    runReaderT r $ BagImpl vec

-- | Run a state bagger with one additional item.
addItem :: forall item full bag m a. (PrimMonad m, ElementIndex item full) =>
    item -> StateBaggerT full (item ': bag) m a -> StateBaggerT full bag m a
addItem item (StateBaggerT chain) = StateBaggerT $ do
    (BagImpl vec) <- ask
    V.write vec (elemIndex (Proxy :: Proxy item) (Proxy :: Proxy full)) $
        unsafeCoerce# item
    lift $ runReaderT chain $ BagImpl vec

-- | Get the value of the top item in a state bagger.
topItem :: forall item full bag m. (PrimMonad m, ElementIndex item full) =>
    StateBaggerT full (item ': bag) m item
topItem = StateBaggerT $ do
    (BagImpl vec) <- ask
    fmap unsafeCoerce# $ V.read vec $
        elemIndex (Proxy :: Proxy item) (Proxy :: Proxy full)

-- | Run a state bagger with one additional item and capture the final value of
-- that item on return.
stackItem :: forall item full bag m a. (PrimMonad m, ElementIndex item full) =>
    item -> StateBaggerT full (item ': bag) m a ->
    StateBaggerT full bag m (a, item)
stackItem item chain =
    addItem item $ liftM2 (,) chain topItem

-- | State bag monad transformer where the state items are represented by the
-- type-level list @bag@.
newtype StateBagT bag m a = StateBagT (ReaderT (BagImpl (PrimState m) bag) m a)
    deriving (Functor, Applicative, Monad, MonadIO)

instance MonadTrans (StateBagT bag) where
    lift = StateBagT . lift

instance (PrimMonad m) => PrimMonad (StateBagT bag m) where
  type PrimState (StateBagT bag m) = PrimState m
  primitive = lift . primitive

-- | Runs a state bag with the items prepared in a state bagger.
makeBag ::
    forall bag m a. (PrimMonad m, ElementCount bag) =>
    StateBagT bag m a -> StateBaggerT bag bag m a
makeBag (StateBagT r) = StateBaggerT r

itemImpl :: forall m item bag.
    (PrimMonad m, ElementIndex item bag) =>
    StateBagT bag m (StateBagT bag m item, item -> StateBagT bag m ())
{-# INLINE itemImpl #-}
itemImpl = do
    let i = elemIndex (Proxy :: Proxy item) (Proxy :: Proxy bag)
    (BagImpl vec) <- StateBagT ask
    let geti = fmap unsafeCoerce# $ V.read vec i
    let puti item = V.write vec i $ unsafeCoerce# item
    return (geti, puti)

-- | Gets the current value of @item@ from the bag.
getItem :: forall m item bag.
    (PrimMonad m, ElementIndex item bag) =>
    StateBagT bag m item
{-# INLINE getItem #-}
getItem = itemImpl >>= fst

-- | Stores a new value of @item@ in the bag.
putItem :: forall m item bag.
    (PrimMonad m, ElementIndex item bag) =>
    item -> StateBagT bag m ()
{-# INLINE putItem #-}
putItem item = itemImpl >>= flip snd item

-- | Applies a monadic function to an item in the bag and stores the result.
modifyItemM :: forall m item bag.
    (PrimMonad m, ElementIndex item bag) =>
    (item -> StateBagT bag m item) -> StateBagT bag m ()
{-# INLINE modifyItemM #-}
modifyItemM f = do
    (get, put) <- itemImpl
    item <- get
    item' <- f item
    put item'