{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- | State bag monad transformer which runs on any monad stack.
--
-- A bag is assembled item by item inside a 'StateBaggerT' (see 'addItem' and
-- 'stackItem') and then used through a 'StateBagT' started with 'makeBag'.
-- Items are stored type-erased as 'Any'; the type-level list @bag@ records
-- the type of every slot.
module Control.Monad.Trans.StateBag.Pure
  ( StateBaggerT
  , runBagger
  , addItem
  , topItem
  , stackItem
  , StateBagT
  , makeBag
  , getItem
  , putItem
  , modifyItemM
  , ElementCount()
  , ElementIndex()
  ) where

import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.State
import Control.Monad.Trans.StateBag.Internal
import Control.Monad.IO.Class
import Control.Monad.Primitive
import Data.Proxy
import GHC.Prim (Any, unsafeCoerce#)
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV

-- | Type-erased stack of items used while a bag is being built.
--
-- Invariant maintained by this module: the list holds one element per entry
-- of the type-level list @bag@, in the same order (head of the list is the
-- head of @bag@).
newtype BaggerImpl (bag :: [*]) = BaggerImpl [Any]

-- | Monad transformer for building state bags.
newtype StateBaggerT bag m a =
  StateBaggerT (StateT (BaggerImpl bag) m a)
  deriving (Functor, Applicative, Monad, MonadIO)

instance MonadTrans (StateBaggerT bag) where
  lift = StateBaggerT . lift

instance (PrimMonad m) => PrimMonad (StateBaggerT bag m) where
  type PrimState (StateBaggerT bag m) = PrimState m
  primitive = lift . primitive

-- | Run an empty state bagger on top of a monad stack.
--
-- The bagger starts with no items and its final (empty) state is discarded.
runBagger :: (Monad m) => StateBaggerT '[] m a -> m a
runBagger (StateBaggerT s) = fmap fst $ runStateT s $ BaggerImpl []

-- | Run a state bagger with one additional item.
--
-- The item is pushed on top of the current items for the duration of the
-- inner computation and popped again afterwards; changes the inner
-- computation made to the remaining items are kept.
addItem
  :: forall item bag m a. (Monad m)
  => item
  -> StateBaggerT (item ': bag) m a
  -> StateBaggerT bag m a
addItem item (StateBaggerT chain) = StateBaggerT $ do
  BaggerImpl list <- get
  -- The value is stored type-erased; the @item ': bag@ index of the inner
  -- computation remembers its real type.
  (ret, BaggerImpl list') <-
    lift $ runStateT chain $ BaggerImpl (unsafeCoerce# item : list)
  -- An explicit 'case' is used instead of a failable pattern in the bind, so
  -- that only 'Monad' (and not 'MonadFail') is required of @m@.
  case list' of
    _ : rest -> put $ BaggerImpl rest
    [] ->
      error
        "Control.Monad.Trans.StateBag.Pure.addItem: impossible, item stack is empty"
  return ret

-- | Get the value of the top item in a state bagger.
topItem :: forall item bag m. (Monad m) => StateBaggerT (item ': bag) m item
topItem = StateBaggerT $ do
  BaggerImpl list <- get
  -- An explicit 'case' is used instead of a failable pattern in the bind, so
  -- that only 'Monad' (and not 'MonadFail') is required of @m@.
  case list of
    -- The head of the list corresponds to the head of @item ': bag@, so it
    -- is coerced back to the type it was stored at.
    top : _ -> return $ unsafeCoerce# top
    [] ->
      error
        "Control.Monad.Trans.StateBag.Pure.topItem: impossible, item stack is empty"

-- | Run a state bagger with one additional item and capture the final value of
-- that item on return.
stackItem
  :: forall item bag m a. (Monad m)
  => item
  -> StateBaggerT (item ': bag) m a
  -> StateBaggerT bag m (a, item)
stackItem item chain = addItem item $ liftM2 (,) chain topItem

-- | Type-erased, immutable vector of items backing a running state bag.
newtype BagImpl (bag :: [*]) = BagImpl (V.Vector Any)

-- | State bag monad transformer where the state items are represented by the
-- type-level list @bag@.
newtype StateBagT bag m a =
  StateBagT (StateT (BagImpl bag) m a)
  deriving (Functor, Applicative, Monad, MonadIO)

instance MonadTrans (StateBagT bag) where
  lift = StateBagT . lift

instance (PrimMonad m) => PrimMonad (StateBagT bag m) where
  type PrimState (StateBagT bag m) = PrimState m
  primitive = lift . primitive

-- | Runs a state bag with the items prepared in a state bagger.
--
-- The bagger's items are copied into a vector for the duration of the bag
-- computation and written back to the bagger when it finishes.
makeBag
  :: forall bag m a. (Monad m, ElementCount bag)
  => StateBagT bag m a
  -> StateBaggerT bag m a
makeBag (StateBagT s) = StateBaggerT $ do
  BaggerImpl list <- get
  -- NOTE(review): 'elemCount' is presumably the length of @bag@ - confirm
  -- against "Control.Monad.Trans.StateBag.Internal".
  let vlen = elemCount (Proxy :: Proxy bag)
  (ret, BagImpl vec') <- lift $ runStateT s $ BagImpl $ V.fromListN vlen list
  put $ BaggerImpl $ V.toList vec'
  return ret

-- | Shared implementation of the item accessors: returns the current value of
-- @item@ together with an action that stores a replacement value.
--
-- The setter writes into a modified copy of the vector that was current when
-- 'itemImpl' ran, not the vector that is current when the setter runs.
itemImpl
  :: forall m item bag. (Monad m, ElementIndex item bag)
  => StateBagT bag m (item, item -> StateBagT bag m ())
{-# INLINE itemImpl #-}
itemImpl = do
  -- NOTE(review): 'elemIndex' is presumably the zero-based position of @item@
  -- within @bag@; the coercions below are only sound if that holds - confirm
  -- against "Control.Monad.Trans.StateBag.Internal".
  let i = elemIndex (Proxy :: Proxy item) (Proxy :: Proxy bag)
  BagImpl vec <- StateBagT get
  let item = unsafeCoerce# (vec V.! i)
  let puti item' =
        StateBagT $ put $ BagImpl $
          V.modify (\mvec -> MV.write mvec i $ unsafeCoerce# item') vec
  return (item, puti)

-- | Gets the current value of @item@ from the bag.
getItem
  :: forall m item bag. (Monad m, ElementIndex item bag)
  => StateBagT (bag :: [*]) m item
getItem = fmap fst itemImpl

-- | Stores a new value of @item@ in the bag.
putItem
  :: forall m item bag. (Monad m, ElementIndex item bag)
  => item
  -> StateBagT (bag :: [*]) m ()
putItem item = itemImpl >>= flip snd item

-- | Applies a monadic function to an item in the bag and stores the result.
--
-- The store is based on the bag as it was before the function ran, so any
-- changes the function itself makes to the bag are overwritten.
modifyItemM
  :: forall m item bag. (Monad m, ElementIndex item bag)
  => (item -> StateBagT bag m item)
  -> StateBagT bag m ()
modifyItemM f = do
  (item, puti) <- itemImpl
  item' <- f item
  puti item'