{-# LANGUAGE UndecidableInstances #-}
module Data.Repa.Flow.States
  ( Next   (..)
  , States (..)
  , Refs   (..)
  , foldRefsM
  , toListM
  ) where

import Control.Monad
import qualified Data.Vector.Mutable as VM
#include "repa-flow.h"


-------------------------------------------------------------------------------
-- | Index types that have a starting index and can be stepped through
--   up to a given arity.
class (Ord i, Eq i) => Next i where
  -- | Get the zero for this index type.
  first :: i

  -- | Given an index and an arity, get the next index after this one,
  --   or `Nothing` if there aren't any more.
  next  :: i -> i -> Maybe i

  -- | Check if an index is valid for this arity.
  check :: i -> i -> Bool


-- | Unit indices.
--
--   There is only one unit index, so there is never a next one and
--   every index is valid.
instance Next () where
  first     = ()
  {-# INLINE first #-}

  next _ _  = Nothing
  {-# INLINE next #-}

  check _ _ = True
  {-# INLINE check #-}


-- | Integer indices.
--
--   Valid indices for an arity @len@ are @0 .. len - 1@.
instance Next Int where
  first = 0
  {-# INLINE first #-}

  next i len
    | i + 1 >= len = Nothing
    | otherwise    = Just (i + 1)
  {-# INLINE next #-}

  check i len = i >= 0 && len >= 0 && i < len
  {-# INLINE check #-}


-- | Tuple indices.
--
--   The second component is advanced first. Once it is exhausted it is
--   reset to zero and the first component is advanced instead.
instance Next (Int, Int) where
  first = (0, 0)
  {-# INLINE first #-}

  next (ix1, ix0) (a1, a0)
    | ix0 + 1 >= a0
    = if ix1 + 1 >= a1
        then Nothing
        else Just (ix1 + 1, 0)

    | otherwise
    = Just (ix1, ix0 + 1)
  {-# INLINE next #-}

  -- Both components must be valid for their respective arities.
  check (ix1, ix2) (len1, len2) = check ix1 len1 && check ix2 len2
  {-# INLINE check #-}


-------------------------------------------------------------------------------
-- | Collections of mutable references with index type @i@ that can be
--   manipulated in the monad @m@.
class (Ord i, Next i, Monad m) => States i m where
  -- | A collection of mutable references.
  data Refs i m a

  -- | Get the extent of the collection.
  extentRefs :: Refs i m a -> i

  -- | Allocate a new state of the given arity, initialised from the
  --   given value.
  newRefs    :: i -> a -> m (Refs i m a)

  -- | Read an element of the state.
  readRefs   :: Refs i m a -> i -> m a

  -- | Write an element of the state.
  writeRefs  :: Refs i m a -> i -> a -> m ()


-- | Fold all the elements in a collection of refs.
foldRefsM :: States i m => (a -> b -> b) -> b -> Refs i m a -> m b
foldRefsM f z refs
  -- An empty collection has no valid first index, so reading at `first`
  -- would be out of bounds. Return the initial accumulator in that case.
  | check first extent = loop_foldsRefsM first z
  | otherwise          = return z
  where
    -- The extent is looked up once rather than on every iteration.
    extent = extentRefs refs

    -- Visit the indices in the order produced by `next`, combining each
    -- element with the accumulator built from the elements before it.
    loop_foldsRefsM i acc = do
      x <- readRefs refs i
      let acc' = f x acc
      case next i extent of
        Nothing -> return acc'
        Just i' -> loop_foldsRefsM i' acc'
    {-# INLINE loop_foldsRefsM #-}
{-# INLINE foldRefsM #-}


-- | Collect the elements of a collection of refs into a list.
--
--   Each element is consed onto the accumulator as it is visited, so the
--   result lists the elements in reverse index order (last index first).
--   An empty collection yields the empty list.
--
--   NOTE(review): confirm that callers expect the reversed ordering.
toListM :: States i m => Refs i m a -> m [a]
toListM refs = foldRefsM (:) [] refs
{-# NOINLINE toListM #-}


-- | Refs indexed by `Int` in the `IO` monad, stored in a boxed mutable
--   vector.
instance States Int IO where
  data Refs Int IO a = Refs !(VM.IOVector a)

  extentRefs (Refs !refs) = VM.length refs
  {-# INLINE extentRefs #-}

  newRefs !n !x = liftM Refs $ unsafeNewWithVM n x
  {-# NOINLINE newRefs #-}

  -- Reads and writes use the unchecked vector primitives, so the index
  -- must already be valid for the extent (see `check`).
  readRefs (Refs !refs) !i = VM.unsafeRead refs i
  {-# INLINE readRefs #-}

  writeRefs (Refs !refs) !i !x = VM.unsafeWrite refs i x
  {-# INLINE writeRefs #-}


-- | A single ref, stored as a one-element `Int`-indexed collection in the
--   same monad. The unit index always refers to element 0 of that
--   underlying collection.
instance States Int m => States () m where
  data Refs () m a = URefs !(Refs Int m a)

  extentRefs _ = ()
  {-# INLINE extentRefs #-}

  newRefs _ !x = do
    refs <- newRefs (1 :: Int) x
    return $ URefs refs
  {-# NOINLINE newRefs #-}

  readRefs (URefs !refs) _ = readRefs refs 0
  {-# INLINE readRefs #-}

  writeRefs (URefs !refs) _ !x = writeRefs refs 0 x
  {-# INLINE writeRefs #-}


-------------------------------------------------------------------------------
-- | Allocate a mutable vector of the given length and write the given
--   value into every element.
--
--   The vector is allocated with `VM.unsafeNew`, which does not validate
--   the length, so the caller must pass a non-negative length.
unsafeNewWithVM :: Int -> a -> IO (VM.IOVector a)
unsafeNewWithVM n x = do
  vec <- VM.unsafeNew n
  -- Fill indices 0 .. n - 1 so that no element is left uninitialised.
  let loop_newRefs !i
        | i >= n    = return ()
        | otherwise = do
            VM.unsafeWrite vec i x
            loop_newRefs (i + 1)
      {-# INLINE loop_newRefs #-}
  loop_newRefs 0
  return vec
{-# INLINE unsafeNewWithVM #-}