{-# LANGUAGE BangPatterns #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} module HaskellWorks.Data.Vector.Storable ( padded , foldMap , mapAccumL , mmap , constructSI , construct2N , construct64UnzipN ) where import Control.Monad.ST (ST, runST) import Data.Monoid (Monoid (..), (<>)) import Data.Vector.Storable (Storable) import Data.Word import Foreign.ForeignPtr import HaskellWorks.Data.Vector.AsVector8 import Prelude hiding (foldMap) import qualified Data.ByteString as BS import qualified Data.Vector.Generic as DVG import qualified Data.Vector.Storable as DVS import qualified Data.Vector.Storable.Mutable as DVSM import qualified System.IO.MMap as IO {-# ANN module ("HLint: ignore Redundant do" :: String) #-} padded :: Int -> DVS.Vector Word8 -> DVS.Vector Word8 padded n v = v <> DVS.replicate ((n - DVS.length v) `max` 0) 0 {-# INLINE padded #-} foldMap :: (DVS.Storable a, Monoid m) => (a -> m) -> DVS.Vector a -> m foldMap f = DVS.foldl' (\a b -> a <> f b) mempty {-# INLINE foldMap #-} mapAccumL :: forall a b c. (Storable b, Storable c) => (a -> b -> (a, c)) -> a -> DVS.Vector b -> (a, DVS.Vector c) mapAccumL f a vb = DVS.createT $ do vc <- DVSM.unsafeNew (DVS.length vb) a' <- go 0 a vc return (a', vc) where go :: Int -> a -> DVS.MVector s c -> ST s a go i a0 vc = if i < DVS.length vb then do let (a1, c1) = f a0 (DVS.unsafeIndex vb i) DVSM.unsafeWrite vc i c1 go (i + 1) a1 vc else return a0 {-# INLINE mapAccumL #-} -- | MMap the file as a storable vector. If the size of the file is not a multiple of the element size -- in bytes, then the last few bytes of the file will not be included in the vector. mmap :: Storable a => FilePath -> IO (DVS.Vector a) mmap filepath = do (fptr :: ForeignPtr Word8, offset, size) <- IO.mmapFileForeignPtr filepath IO.ReadOnly Nothing let !v = DVS.unsafeFromForeignPtr fptr offset size return (DVS.unsafeCast v) -- | Construct a vector statefully with index constructSI :: forall a s. Storable a => Int -> (Int -> s -> (s, a)) -> s -> (s, DVS.Vector a) constructSI n f state = DVS.createT $ do mv <- DVSM.unsafeNew n state' <- go 0 state mv return (state', mv) where go :: Int -> s -> DVSM.MVector t a -> ST t s go i s mv = if i < DVSM.length mv then do let (s', a) = f i s DVSM.unsafeWrite mv i a go (i + 1) s' mv else return s construct2N :: (Storable b, Storable c) => Int -> (forall s. a -> DVSM.MVector s b -> ST s Int) -> Int -> (forall s. a -> DVSM.MVector s c -> ST s Int) -> [a] -> (DVS.Vector b, DVS.Vector c) construct2N nb fb nc fc as = runST $ do mbs <- DVSM.unsafeNew nb mcs <- DVSM.unsafeNew nc (mbs2, mcs2) <- go fb 0 mbs fc 0 mcs as bs <- DVG.unsafeFreeze mbs2 cs <- DVG.unsafeFreeze mcs2 return (bs, cs) where go :: (Storable b, Storable c) => (forall t. a -> DVSM.MVector t b -> ST t Int) -> Int -> DVSM.MVector s b -> (forall t. a -> DVSM.MVector t c -> ST t Int) -> Int -> DVSM.MVector s c -> [a] -> ST s (DVSM.MVector s b, DVSM.MVector s c) go _ bn mbs _ cn mcs [] = return (DVSM.take bn mbs, DVSM.take cn mcs) go fb' bn mbs fc' cn mcs (d:ds) = do bi <- fb' d (DVSM.drop bn mbs) ci <- fc' d (DVSM.drop cn mcs) go fb' (bn + bi) mbs fc' (cn + ci) mcs ds construct64UnzipN :: Int -> [(BS.ByteString, BS.ByteString)] -> (DVS.Vector Word64, DVS.Vector Word64) construct64UnzipN nBytes xs = (DVS.unsafeCast ibv, DVS.unsafeCast bpv) where [ibv, bpv] = DVS.createT $ do let nW64s = (nBytes + 7) `div` 8 ibmv <- DVSM.new nW64s bpmv <- DVSM.new nW64s (ibmvRemaining, bpmvRemaining) <- go ibmv bpmv xs return [ DVSM.take (((DVSM.length ibmv - ibmvRemaining) `div` 8) * 8) ibmv , DVSM.take (((DVSM.length bpmv - bpmvRemaining) `div` 8) * 8) bpmv ] go :: DVSM.MVector s Word8 -> DVSM.MVector s Word8 -> [(BS.ByteString, BS.ByteString)] -> ST s (Int, Int) go ibmv bpmv ((ib, bp):ys) = do DVS.copy (DVSM.take (BS.length ib) ibmv) (asVector8 ib) DVS.copy (DVSM.take (BS.length bp) bpmv) (asVector8 bp) go (DVSM.drop (BS.length ib) ibmv) (DVSM.drop (BS.length bp) bpmv) ys go ibmv bpmv [] = do DVSM.set (DVSM.take 8 ibmv) 0 DVSM.set (DVSM.take 8 bpmv) 0 return (DVSM.length (DVSM.drop 8 ibmv), DVSM.length (DVSM.drop 8 bpmv))