{-# LANGUAGE ScopedTypeVariables #-}

module HaskellWorks.Data.Streams.Vector where

import Control.Monad
import Control.Monad.ST
import HaskellWorks.Data.Streams.Internal
import HaskellWorks.Data.Streams.Size
import HaskellWorks.Data.Streams.Stream

import qualified Data.Vector                      as DV
import qualified Data.Vector.Mutable              as DVM
import qualified HaskellWorks.Data.Streams.Stream as HW

-- | Materialise a stream into an immutable boxed vector.
--
-- A mutable buffer is pre-allocated from the stream's 'Size' hint (or
-- 32K elements when the size is 'Unknown'), filled left to right, and
-- doubled whenever it runs out of room. Only the prefix that was actually
-- written is frozen, so the result length always equals the number of
-- yielded elements, even when the size hint was an over- or
-- under-estimate.
unstream :: forall a. HW.Stream a -> DV.Vector a
unstream (HW.Stream step initialState size) = runST $ do
  buffer <- DVM.unsafeNew initialCapacity
  loop step buffer 0 initialState
  where -- Starting buffer size taken from the stream's size hint; clamped so
        -- a bogus negative hint never reaches 'DVM.unsafeNew'.
        initialCapacity :: Int
        initialCapacity = max 0 $ case size of
          Exact n -> n
          Max   n -> n
          Unknown -> 32 * 1024

        -- @loop g v i s@: @v@ is the current buffer, @i@ the number of
        -- elements written so far (also the next write index), and @s@ the
        -- current stream state.
        loop :: (s -> Step s a) -> DVM.MVector t a -> Int -> s -> ST t (DV.Vector a)
        loop g v i s = case g s of
          Yield x s' -> do
            -- 'DVM.unsafeGrow' returns a NEW vector, so the grown buffer
            -- must be threaded through the loop. Growing by at least 1
            -- handles an initial zero-length buffer.
            v' <- if i >= DVM.length v
              then DVM.unsafeGrow v (max 1 i)
              else return v
            DVM.unsafeWrite v' i x
            loop g v' (i + 1) s'
          Skip s0 -> loop g v i s0
          -- Freeze only the written prefix. Slots beyond @i@ were never
          -- initialised.
          Done -> DV.freeze (DVM.take i v)
{-# INLINE [1] unstream #-}

-- | View a boxed vector as a stream.
--
-- The stream state is the index of the next element to yield, starting at
-- 0. The size is reported as 'Exact' (the vector's length).
stream :: forall a. DV.Vector a -> Stream a
stream v = Stream next 0 (Exact n)
  where n :: Int
        n = DV.length v

        -- Yield the element at index @i@ and advance, or stop once every
        -- index has been visited. The guard keeps 'DV.unsafeIndex' in bounds.
        next :: Int -> Step Int a
        next i
          | i < n     = Yield (DV.unsafeIndex v i) (i + 1)
          | otherwise = Done
{-# INLINE [1] stream #-}

-- | Apply a function to every element of a vector.
--
-- The vector is converted to a stream, mapped element-wise via the
-- stream's 'Functor' instance under 'inplace', and converted back.
-- Written with one explicit argument so that the INLINE pragma fires as
-- soon as 'map' is applied to the function.
map :: (a -> b) -> DV.Vector a -> DV.Vector b
map f = unstream . inplace (f <$>) . stream
{-# INLINE map #-}

-- Fusion rule: turning a freshly built vector straight back into a stream
-- is replaced by the original stream, so the intermediate vector is never
-- allocated. The rewritten stream yields the same elements. Its 'Size' hint
-- may be less precise than the 'Exact' length 'stream' would report.
-- NOTE(review): 'stream' and 'unstream' are marked INLINE [1], presumably so
-- they are not inlined before this rule has had a chance to match.
{-# RULES
  "stream/unstream" forall f. stream (unstream f) = f
  #-}