{-# LANGUAGE Rank2Types #-}
{- |
Module      : Data.StorableVector.ST.Strict
License     : BSD-style
Maintainer  : haskell@henning-thielemann.de
Stability   : experimental
Portability : portable, requires ffi
Tested with : GHC 6.4.1

Interface for access to a mutable StorableVector.
-}
module Data.StorableVector.ST.Strict (
        Vector,
        new,
        new_,
        read,
        write,
        modify,
        freeze,
        thaw,
        length,
        runSTVector,
        mapST,
        mapSTLazy,
        ) where

import Data.StorableVector.ST.Private
          (Vector(SV), unsafeCreate, unsafeToVector, )
import qualified Data.StorableVector.Base as V
import qualified Data.StorableVector as VS
import qualified Data.StorableVector.Lazy as VL

import qualified Control.Monad.ST.Strict as ST
import Control.Monad.ST.Strict (ST, unsafeIOToST, runST, )  -- stToIO,

import Foreign.Ptr              (Ptr, )
import Foreign.ForeignPtr
          (withForeignPtr, touchForeignPtr, unsafeForeignPtrToPtr, )
import Foreign.Storable         (Storable(peek, poke))
import Foreign.Marshal.Array    (advancePtr, copyArray, )
-- import System.IO.Unsafe         (unsafePerformIO)

-- import Prelude (Int, ($), (+), return, const, )
import Prelude hiding (read, length, )


-- Ask GHC to inline every exported operation of this module at its call sites.
{-# INLINE new #-}
{-# INLINE new_ #-}
{-# INLINE read #-}
{-# INLINE write #-}
{-# INLINE modify #-}
{-# INLINE freeze #-}
{-# INLINE thaw #-}
{-# INLINE length #-}
{-# INLINE runSTVector #-}
{-# INLINE mapST #-}
{-# INLINE mapSTLazy #-}


-- * access to mutable storable vector

{- |
Create a mutable vector of length @n@ (via 'unsafeCreate')
and store @x@ in each of its @n@ cells.
For @n <= 0@ nothing is written.
-}
new :: (Storable e) =>
   Int -> e -> ST s (Vector s e)
new n x =
   unsafeCreate n (fill n)
   where
      -- Write @x@ into @remaining@ consecutive cells, starting at @ptr@.
      {-# INLINE fill #-}
      fill remaining ptr
         | remaining > 0 =
              poke ptr x >> fill (remaining - 1) (advancePtr ptr 1)
         | otherwise =
              return ()

{- |
Create a mutable vector of length @n@ without writing any initial values:
the initialisation callback passed to 'unsafeCreate' does nothing.
-}
new_ :: (Storable e) =>
   Int -> ST s (Vector s e)
new_ n =
   unsafeCreate n $ \_ -> return ()


{- |
Fetch the element stored at the given index.
The index is bounds-checked by 'access';
an index outside the vector raises an 'error'.

> Control.Monad.ST.runST (do arr <- new_ 10; Monad.zipWithM_ (write arr) [9,8..0] ['a'..]; read arr 3)
-}
read :: (Storable e) =>
   Vector s e -> Int -> ST s e
read vec ix =
   access "read" vec ix (\ptr -> peek ptr)

{- |
Store a value at the given index.
The index is bounds-checked by 'access';
an index outside the vector raises an 'error'.

> VS.unpack $ runSTVector (do arr <- new_ 10; Monad.zipWithM_ (write arr) [9,8..0] ['a'..]; return arr)
-}
write :: (Storable e) =>
   Vector s e -> Int -> e -> ST s ()
write vec ix val =
   access "write" vec ix (flip poke val)

{- |
Replace the element at the given index by the result of applying @f@ to it.
The index is bounds-checked by 'access';
an index outside the vector raises an 'error'.

> VS.unpack $ runSTVector (do arr <- new 10 'a'; Monad.mapM_ (\n -> modify arr (mod n 8) succ) [0..10]; return arr)
-}
modify :: (Storable e) =>
   Vector s e -> Int -> (e -> e) -> ST s ()
modify vec ix f =
   access "modify" vec ix $ \ptr ->
      do old <- peek ptr
         poke ptr (f old)

{-# INLINE access #-}
{- |
Shared helper for element access.
Checks that the index lies in @[0, length)@ and, if so,
runs the given 'IO' action on a pointer to that element.
Otherwise raises an 'error' whose message mentions the operation @name@.
-}
access :: (Storable e) =>
   String -> Vector s e -> Int -> (Ptr e -> IO a) -> ST s a
access name (SV fptr len) ix act
   | ix < 0 || ix >= len =
        error ("StorableVector.ST." ++ name ++ ": index out of range")
   | otherwise =
        unsafeIOToST $
        withForeignPtr fptr $ \base ->
        act (base `advancePtr` ix)

{- |
Produce an immutable vector holding a copy of the current contents
of the mutable vector.
The elements are copied into a vector created by 'V.create',
so later writes to the mutable vector do not touch the result.
-}
freeze :: (Storable e) =>
   Vector s e -> ST s (VS.Vector e)
freeze (SV fptr len) =
   unsafeIOToST (V.create len copyInto)
   where
      -- Copy all @len@ elements of the mutable buffer to @dst@.
      copyInto dst =
         withForeignPtr fptr $ \src -> copyArray dst src len


{- |
Create a mutable vector holding a copy of an immutable vector.
The source elements are read starting at the vector's offset into its buffer
and copied into a vector created by 'unsafeCreate'.
-}
thaw :: (Storable e) =>
   VS.Vector e -> ST s (Vector s e)
thaw (V.SV fptr off len) =
   unsafeCreate len copyInto
   where
      -- Copy @len@ elements, beginning @off@ elements into the source buffer.
      copyInto dst =
         withForeignPtr fptr $ \base ->
            copyArray dst (advancePtr base off) len


{- |
Number of elements of the mutable vector, as stored in its 'SV' constructor.
-}
length ::
   Vector s e -> Int
length vec =
   case vec of
      SV _ len -> len


{- |
Run an 'ST' computation that yields a mutable vector
and return that vector as an immutable one,
using 'unsafeToVector' for the conversion.
-}
runSTVector :: (Storable e) =>
   (forall s. ST s (Vector s e)) -> VS.Vector e
runSTVector m =
   runST (m >>= unsafeToVector)



-- * operations on immutable storable vector within ST monad

{- |
Apply an 'ST' action to every element of an immutable vector, from first to
last, and collect the results in a new immutable vector of the same length.

> :module + Data.STRef
> VS.unpack $ Control.Monad.ST.runST (do ref <- newSTRef 'a'; mapST (\ _n -> do c <- readSTRef ref; modifySTRef ref succ; return c) (VS.pack [1,2,3,4::Data.Int.Int16]))
-}
mapST :: (Storable a, Storable b) =>
   (a -> ST s b) -> VS.Vector a -> ST s (VS.Vector b)
mapST f (V.SV px sx n) =
   let {-# INLINE go #-}
       -- @l@ elements remain; @q@ is the read position, @p@ the write position.
       go l q p =
          if l>0
            then
               do unsafeIOToST . poke p =<< f =<< unsafeIOToST (peek q)
                  go (pred l) (advancePtr q 1) (advancePtr p 1)
            else return ()
   in  do ys@(SV py _) <- new_ n
          go n
              (unsafeForeignPtrToPtr px `advancePtr` sx)
              (unsafeForeignPtrToPtr py)
          -- The loop works on raw pointers obtained with
          -- 'unsafeForeignPtrToPtr', which do not keep the buffers alive.
          -- Touch both foreign pointers after the loop so that neither
          -- buffer can be finalised while the loop still accesses it.
          unsafeIOToST $ touchForeignPtr px >> touchForeignPtr py
          unsafeToVector ys

{-
mapST f xs@(V.SV v s l) =
   let go l q p =
          if l>0
            then
               do poke p =<< stToIO . f =<< peek q
                  go (pred l) (advancePtr q 1) (advancePtr p 1)
            else return ()
       n = VS.length xs
   in  return $ V.unsafeCreate n $ \p ->
          withForeignPtr v $ \q -> go n (advancePtr q s) p
-}


{- |
Apply 'mapST' to each chunk of a lazy storable vector, in order,
and wrap the resulting chunks into a lazy storable vector again.

> *Data.StorableVector.ST.Strict Data.STRef> VL.unpack $ Control.Monad.ST.runST (do ref <- newSTRef 'a'; mapSTLazy (\ _n -> do c <- readSTRef ref; modifySTRef ref succ; return c) (VL.pack VL.defaultChunkSize [1,2,3,4::Data.Int.Int16]))
> "abcd"

The following should not work on infinite streams,
since we are in 'ST' with strict '>>='.
But it works. Why?

> *Data.StorableVector.ST.Strict Data.STRef> VL.unpack $ Control.Monad.ST.runST (do ref <- newSTRef 'a'; mapSTLazy (\ _n -> do c <- readSTRef ref; modifySTRef ref succ; return c) (VL.pack VL.defaultChunkSize [0::Data.Int.Int16 ..]))
> "Interrupted.
-}
mapSTLazy :: (Storable a, Storable b) =>
   (a -> ST s b) -> VL.Vector a -> ST s (VL.Vector b)
mapSTLazy f (VL.SV chunks) =
   do mapped <- mapM (mapST f) chunks
      return (VL.SV mapped)