{-# LANGUAGE ScopedTypeVariables #-} module HaskellWorks.Data.Streams.Vector.Storable ( stream , unstream , map , zipWith , enumFromStepN , foldl , dotp , sum ) where import Control.Monad import Control.Monad.ST import Data.Vector.Storable (Storable) import HaskellWorks.Data.Streams.Internal (inplace) import HaskellWorks.Data.Streams.Size import HaskellWorks.Data.Streams.Stream (Step (..), Stream (..)) import Prelude hiding (foldl, map, sum, zipWith) import qualified Data.Vector.Storable as DVS import qualified Data.Vector.Storable.Mutable as DVSM import qualified HaskellWorks.Data.Streams.Stream as S unstream :: forall a. Storable a => S.Stream a -> DVS.Vector a unstream (S.Stream step initialState size) = runST $ do v <- case size of Exact n -> DVSM.unsafeNew n Max n -> DVSM.unsafeNew n Unknown -> DVSM.unsafeNew (32 * 1024) loop step v 0 initialState where loop :: (s -> Step s a) -> DVSM.MVector t a -> Int -> s -> ST t (DVS.Vector a) loop g v i s = case g s of Yield a s' -> do when (i >= DVSM.length v) $ void $ DVSM.unsafeGrow v (i * 2) DVSM.unsafeWrite v i a loop g v (i + 1) s' Skip s0 -> loop g v i s0 Done -> DVS.freeze v {-# INLINE [1] unstream #-} stream :: forall a. Storable a => DVS.Vector a -> Stream a stream v = Stream step 0 (Exact len) where len = DVS.length v step i = if i >= len then Done else Yield (DVS.unsafeIndex v i) (i + 1) {-# INLINE [1] stream #-} map :: (Storable a, Storable b) => (a -> b) -> DVS.Vector a -> DVS.Vector b map f = unstream . inplace (fmap f) . stream {-# INLINE map #-} zipWith :: (Storable a, Storable b, Storable c) => (a -> b -> c) -> DVS.Vector a -> DVS.Vector b -> DVS.Vector c zipWith f v w = unstream (S.zipWith f (stream v) (stream w)) enumFromStepN :: (Num a, Storable a) => a -> a -> Int -> DVS.Vector a enumFromStepN x y = unstream . inplace (S.enumFromStepN x y) {-# INLINE [1] enumFromStepN #-} foldl :: Storable b => (a -> b -> a) -> a -> DVS.Vector b -> a foldl f z = inplace (S.foldl f z) . stream {-# INLINE [1] foldl #-} sum :: (Storable a, Num a) => DVS.Vector a -> a sum = foldl (+) 0 dotp :: (Storable a, Num a) => DVS.Vector a -> DVS.Vector a -> a dotp v w = sum (zipWith (*) v w) {-# RULES "stream/unstream" forall f. stream (unstream f) = f #-}