-- | Interface with chain fusion. module Data.Repa.Eval.Chain ( chainOfArray , unchainToArray , unchainToArrayIO) where import Data.Repa.Chain (Chain(..), Step(..)) import Data.Repa.Array.Generic.Index as A import Data.Repa.Array.Internals.Bulk as A import Data.Repa.Array.Internals.Target as A import qualified Data.Vector.Fusion.Stream.Monadic as S import qualified Data.Vector.Fusion.Stream.Size as S import qualified Data.Vector.Fusion.Util as S import System.IO.Unsafe #include "repa-array.h" ------------------------------------------------------------------------------- -- | Produce a `Chain` for the elements of the given array. -- The order in which the elements appear in the chain is -- determined by the layout of the array. chainOfArray :: (Monad m, Bulk l a) => Array l a -> Chain m Int a chainOfArray !arr = Chain (S.Exact len) 0 step where !len = A.length arr step !i | i >= len = return $ Done i | otherwise = return $ Yield (A.index arr $ A.fromIndex (A.layout arr) i) (i + 1) {-# INLINE_INNER step #-} {-# INLINE_STREAM chainOfArray #-} -- | Lift a pure chain to a monadic chain. liftChain :: Monad m => Chain S.Id s a -> Chain m s a liftChain (Chain sz s step) = Chain sz s (return . S.unId . step) {-# INLINE_STREAM liftChain #-} ------------------------------------------------------------------------------- -- | Compute the elements of a pure `Chain`, -- writing them into a new array `Array`. unchainToArray :: Target l a => Name l -> Chain S.Id s a -> (Array l a, s) unchainToArray nDst c = unsafePerformIO $ unchainToArrayIO nDst $ liftChain c {-# INLINE_STREAM unchainToArray #-} -- | Compute the elements of an `IO` `Chain`, -- writing them to a new `Array`. unchainToArrayIO :: Target l a => Name l -> Chain IO s a -> IO (Array l a, s) unchainToArrayIO nDst (Chain sz s0 step) = case sz of S.Exact i -> unchainToArrayIO_max i S.Max i -> unchainToArrayIO_max i S.Unknown -> unchainToArrayIO_unknown 32 -- unchain when we known the maximum size of the vector. where unchainToArrayIO_max !nMax = do !vec0 <- unsafeNewBuffer (create nDst zeroDim) !vec <- unsafeGrowBuffer vec0 nMax let go_unchainIO_max !sPEC !i !s = step s >>= \m -> case m of Yield e s' -> do unsafeWriteBuffer vec i e go_unchainIO_max sPEC (i + 1) s' Skip s' -> go_unchainIO_max sPEC i s' Done s' -> do buf' <- unsafeSliceBuffer 0 i vec arr <- unsafeFreezeBuffer buf' return (arr, s') {-# INLINE_INNER go_unchainIO_max #-} go_unchainIO_max S.SPEC 0 s0 {-# INLINE_INNER unchainToArrayIO_max #-} -- unchain when we don't know the maximum size of the vector. unchainToArrayIO_unknown !nStart = do !vec0 <- unsafeNewBuffer (create nDst zeroDim) !vec1 <- unsafeGrowBuffer vec0 nStart let go_unchainIO_unknown !uvec !i !n !s = go_unchainIO_unknown1 uvec i n s (\vec' i' n' s' -> go_unchainIO_unknown vec' i' n' s') (\result -> return result) go_unchainIO_unknown1 !vec !i !n !s cont done = step s >>= \r -> case r of Yield e s' -> do (vec', n') <- if i >= n then do vec' <- unsafeGrowBuffer vec n return (vec', n + n) else return (vec, n) unsafeWriteBuffer vec' i e cont vec' (i + 1) n' s' Skip s' -> cont vec i n s' Done s' -> do vec' <- unsafeSliceBuffer 0 i vec arr <- unsafeFreezeBuffer vec' done (arr, s') go_unchainIO_unknown vec1 0 nStart s0 {-# INLINE_INNER unchainToArrayIO_unknown #-} {-# INLINE_STREAM unchainToArrayIO #-}