module Data.Binary.Strict.BitUtil
  ( topNBits
  , bottomNBits
  , leftShift
  , rightShift
  , leftTruncateBits
  , rightTruncateBits
  ) where

import Data.Word (Word8)
import qualified Data.ByteString as B
import Data.Bits (shiftL, shiftR, (.|.), (.&.))

-- | This is used for masking the last byte of a ByteString so that extra
--   bits don't leak in
topNBits :: Int -> Word8
topNBits 0 = 0
topNBits 1 = 0x80
topNBits 2 = 0xc0
topNBits 3 = 0xe0
topNBits 4 = 0xf0
topNBits 5 = 0xf8
topNBits 6 = 0xfc
topNBits 7 = 0xfe
topNBits 8 = 0xff
topNBits x = error ("topNBits undefined for " ++ show x)

-- | Return a Word8 with the bottom n bits set
bottomNBits :: Int -> Word8
bottomNBits 0 = 0
bottomNBits 1 = 0x01
bottomNBits 2 = 0x03
bottomNBits 3 = 0x07
bottomNBits 4 = 0x0f
bottomNBits 5 = 0x1f
bottomNBits 6 = 0x3f
bottomNBits 7 = 0x7f
bottomNBits 8 = 0xff
bottomNBits x = error ("bottomNBits undefined for " ++ show x)

-- | Shift the whole ByteString some number of bits left where 0 <= @n@ < 8
leftShift :: Int -> B.ByteString -> B.ByteString
leftShift 0 = id
leftShift n = snd . B.mapAccumR f 0 where
  f acc b = (b `shiftR` (8 - n), (b `shiftL` n) .|. acc)

-- | Shift the whole ByteString some number of bits right where 0 <= @n@ < 8
rightShift :: Int -> B.ByteString -> B.ByteString
rightShift 0 = id
rightShift n = snd . B.mapAccumL f 0 where
  f acc b = (b .&. (bottomNBits n), (b `shiftR` n) .|. (acc `shiftL` (8 - n)))

-- | Truncate a ByteString to a given number of bits (counting from the left)
--   by masking out extra bits in the last byte
leftTruncateBits :: Int -> B.ByteString -> B.ByteString
leftTruncateBits n = B.take ((n + 7) `div` 8) . snd . B.mapAccumL f n where
  f bits w | bits >= 8 = (bits - 8, w)
           | bits == 0 = (0, 0)
           | otherwise = (0, w .&. topNBits bits)

-- | Truncate a ByteString to a given number of bits (counting from the right)
--   by masking out extra bits in the first byte
rightTruncateBits :: Int -> B.ByteString -> B.ByteString
rightTruncateBits n bs = B.drop (B.length bs - ((n + 7) `div` 8)) $ snd $ B.mapAccumR f n bs where
  f bits w | bits >= 8 = (bits - 8, w)
           | bits == 0 = (0, 0)
           | otherwise = (0, w .&. bottomNBits bits)