module Data.Array.BitArray.ByteString
  (
  
    toByteString
  , fromByteString
  
  , toByteStringIO
  , fromByteStringIO
  ) where
import Data.Bits (shiftR, (.&.))
import Data.ByteString (ByteString)
import Data.Ix (Ix, rangeSize)
import Data.Word (Word8)
import Control.Monad (when)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Marshal.Utils (copyBytes)
import Foreign.Ptr (castPtr)
import Foreign.Storable (peekByteOff, pokeByteOff)
import System.IO.Unsafe (unsafePerformIO)
import Compat (packCStringLen, unsafeUseAsCStringLen)
import Data.Bits.Bitwise (mask)
import Data.Array.BitArray (BitArray)
import Data.Array.BitArray.IO (IOBitArray)
import qualified Data.Array.BitArray.IO as IO
import Data.Array.BitArray.Internal (iobData)
-- | Copy the elements of a 'BitArray' into a 'ByteString'.
--
-- Pure wrapper around 'toByteStringIO', which receives the array through
-- 'IO.unsafeThaw'. The output holds one bit per element, with
-- @(rangeSize bounds + 7) \`div\` 8@ bytes in total.
--
-- NOTE(review): 'toByteStringIO' may overwrite the last byte of the buffer
-- in place, masking off the bits that lie beyond the index range. If
-- 'IO.unsafeThaw' shares storage with its argument (as the name suggests),
-- that write reaches this immutable array's buffer. Confirm that nothing
-- else relies on those padding bits.
toByteString :: Ix i => BitArray i -> ByteString
toByteString arr = unsafePerformIO (IO.unsafeThaw arr >>= toByteStringIO)
-- | Build a 'BitArray' with the given bounds from packed bytes.
--
-- Pure wrapper around 'fromByteStringIO'. The mutable array is allocated by
-- 'fromByteStringIO' and handed straight to 'IO.unsafeFreeze' without being
-- shared elsewhere. That is the justification for using 'unsafePerformIO'
-- here.
fromByteString
  :: Ix i
  => (i, i)      -- ^ bounds of the resulting array
  -> ByteString  -- ^ packed elements
  -> BitArray i
fromByteString bnds packed = unsafePerformIO $ do
  marr <- fromByteStringIO bnds packed
  IO.unsafeFreeze marr
-- | Copy the contents of an 'IOBitArray' into a 'ByteString'.
--
-- The result is @(rangeSize bounds + 7) \`div\` 8@ bytes long and is
-- produced by 'packCStringLen' over the array's buffer.
--
-- When the range size is not a multiple of 8, the final byte is first
-- masked in place with @'mask' (rangeSize bounds .&. 7)@. The intent is to
-- zero the padding bits beyond the index range, so the output does not
-- depend on whatever those bits held before. Note that this writes to the
-- array's own storage.
toByteStringIO :: Ix i => IOBitArray i -> IO ByteString
toByteStringIO a = do
  bs <- IO.getBounds a
  let rs = rangeSize bs
      bytes = (rs + 7) `shiftR` 3   -- number of bytes, rounded up
      bits = rs .&. 7               -- used bits in the final byte (0 = full)
      -- Index of the final byte. It is only used when bits /= 0, which
      -- implies rs > 0, so this is never -1 when it is actually read.
      lastByte = bytes - 1
  withForeignPtr (iobData a) $ \p -> do
    when (bits /= 0) $ do
      b <- peekByteOff p lastByte
      pokeByteOff p lastByte (b .&. mask bits :: Word8)
    packCStringLen (castPtr p, bytes)
-- | Allocate a new 'IOBitArray' with the given bounds and fill it from
-- packed bytes.
--
-- The array starts out all 'False'. Then @min needed (length input)@ bytes
-- are copied from the input, where @needed = (rangeSize bounds + 7) \`div\` 8@.
-- Surplus input bytes are ignored. If the input is too short, the remaining
-- elements stay 'False'. Padding bits in the final copied byte are copied
-- verbatim and not cleared here.
fromByteStringIO
  :: Ix i
  => (i, i)      -- ^ bounds of the new array
  -> ByteString  -- ^ packed elements
  -> IO (IOBitArray i)
fromByteStringIO bnds packed = do
  arr <- IO.newArray bnds False
  let needed = (rangeSize bnds + 7) `shiftR` 3
  -- Pin both buffers for the duration of the raw copy.
  withForeignPtr (iobData arr) $ \dst ->
    unsafeUseAsCStringLen packed $ \(src, srcLen) ->
      copyBytes dst (castPtr src) (min needed srcLen)
  pure arr