{-# LANGUAGE ScopedTypeVariables #-}

module HaskellWorks.Data.Streams.Vector.Storable
  ( stream
  , unstream

  , map
  , zipWith

  , enumFromStepN
  , foldl

  , dotp
  , sum
  ) where

import Control.Monad
import Control.Monad.ST
import Data.Vector.Storable               (Storable)
import HaskellWorks.Data.Streams.Internal (inplace)
import HaskellWorks.Data.Streams.Size
import HaskellWorks.Data.Streams.Stream   (Step (..), Stream (..))
import Prelude                            hiding (foldl, map, sum, zipWith)

import qualified Data.Vector.Storable             as DVS
import qualified Data.Vector.Storable.Mutable     as DVSM
import qualified HaskellWorks.Data.Streams.Stream as S

-- | Materialise a 'S.Stream' into a storable vector.
--
-- The initial buffer is sized from the stream's size hint: @Exact n@ and
-- @Max n@ allocate @n@ elements, @Unknown@ allocates @32 * 1024@ elements.
-- If the stream yields more elements than the buffer can hold, the buffer is
-- grown and the grown buffer is used for all subsequent writes.  The
-- resulting vector contains exactly the elements that were yielded, in the
-- order they were yielded.
unstream :: forall a. Storable a => S.Stream a -> DVS.Vector a
unstream (S.Stream step initialState size) = runST $ do
  v <- case size of
    Exact n -> DVSM.unsafeNew n
    Max   n -> DVSM.unsafeNew n
    Unknown -> DVSM.unsafeNew (32 * 1024)
  loop step v 0 initialState
  where
    -- @loop g buf i s@ drives the step function @g@ from state @s@, where
    -- @i@ is the number of elements already written into @buf@.
    loop :: (s -> Step s a) -> DVSM.MVector t a -> Int -> s -> ST t (DVS.Vector a)
    loop g buf i s = case g s of
      Yield a s' -> do
        -- 'DVSM.unsafeGrow' returns a *new* buffer; it must replace the old
        -- one, otherwise the write below lands outside the allocation.
        -- 'max 1' guarantees progress when the buffer is empty (i == 0).
        buf' <- if i >= DVSM.length buf
          then DVSM.unsafeGrow buf (max 1 (i * 2))
          else return buf
        DVSM.unsafeWrite buf' i a
        loop g buf' (i + 1) s'
      Skip s0 -> loop g buf i s0
      -- Only the first @i@ slots were written; drop the unused capacity so
      -- uninitialised memory never reaches the caller.
      Done -> DVS.freeze (DVSM.take i buf)
{-# INLINE [1] unstream #-}

-- | View a storable vector as a 'Stream'.
--
-- The stream state is the index of the next element to read, starting at
-- @0@.  Each step yields the element at that index and advances by one;
-- the stream is 'Done' once the index reaches the vector's length.  The
-- size hint is @Exact@ the length of the vector.
stream :: forall a. Storable a => DVS.Vector a -> Stream a
stream v = Stream step 0 (Exact len)
  where
    len = DVS.length v
    step i
      | i >= len  = Done
        -- The guard above ensures @i < len@, so the unchecked read is
        -- within bounds for the non-negative indices produced here.
      | otherwise = Yield (DVS.unsafeIndex v i) (i + 1)
{-# INLINE [1] stream #-}

-- | Apply a function to every element of a storable vector.
--
-- Implemented by converting the vector to a stream with 'stream', mapping
-- over the stream with 'fmap', and materialising the result with
-- 'unstream'.
map :: (Storable a, Storable b)
  => (a -> b)
  -> DVS.Vector a
  -> DVS.Vector b
map f = unstream . inplace (fmap f) . stream
{-# INLINE map #-}

-- | Combine two storable vectors element-wise with the given function.
--
-- Both inputs are converted with 'stream', combined with 'S.zipWith', and
-- the result is materialised with 'unstream'.
--
-- NOTE(review): the behaviour for inputs of different lengths is whatever
-- 'S.zipWith' does; confirm against that definition.
zipWith :: (Storable a, Storable b, Storable c)
  => (a -> b -> c)
  -> DVS.Vector a
  -> DVS.Vector b
  -> DVS.Vector c
zipWith f v w = unstream (S.zipWith f (stream v) (stream w))
{-# INLINE zipWith #-}

-- | Build a storable vector by materialising @S.enumFromStepN x y n@ with
-- 'unstream'.
--
-- NOTE(review): presumably this yields @n@ values starting at @x@ and
-- advancing by @y@, mirroring the function of the same name in the
-- @vector@ package; confirm against 'S.enumFromStepN'.
enumFromStepN :: (Num a, Storable a) => a -> a -> Int -> DVS.Vector a
enumFromStepN x y = unstream . inplace (S.enumFromStepN x y)
{-# INLINE [1] enumFromStepN #-}

-- | Left fold over the elements of a storable vector.
--
-- The vector is converted with 'stream' and reduced with 'S.foldl', using
-- @f@ as the combining function and @z@ as the initial accumulator.
foldl :: Storable b => (a -> b -> a) -> a -> DVS.Vector b -> a
foldl f z = inplace (S.foldl f z) . stream
{-# INLINE [1] foldl #-}

-- | Sum of all elements of a storable vector, computed as a left fold with
-- @(+)@ starting from @0@.
sum :: (Storable a, Num a) => DVS.Vector a -> a
sum = foldl (+) 0

-- | Dot product of two storable vectors: the 'sum' of their element-wise
-- products, where the products are formed with 'zipWith'.
dotp :: (Storable a, Num a) => DVS.Vector a -> DVS.Vector a -> a
dotp v w = sum (zipWith (*) v w)

-- Fusion rule: materialising a stream into a vector with 'unstream' and
-- immediately converting that vector back with 'stream' is rewritten to the
-- original stream, so the intermediate vector is never built.  'stream' and
-- 'unstream' are marked @INLINE [1]@, which keeps them from being inlined
-- before phase 1 so that this rule can match on their names first.
{-# RULES
  "stream/unstream" forall f. stream (unstream f) = f
  #-}