-- Data.BitSyntax: two halves of a small binary-format toolkit.
--
--   * Encoding: describe a message as a list of 'BitBlock's and turn it into
--     a strict ByteString with 'makeBits'.
--   * Decoding: describe a message as a list of 'ReadType's and let 'bitSyn'
--     build (via Template Haskell) a lambda that parses a ByteString into a
--     tuple. The generated code refers to the exported decodeU8 / decodeU16 /
--     decodeU32 / decodeBits helpers by name.
module Data.BitSyntax (
BitBlock(..),
makeBits,
ReadType(..), bitSyn,
decodeU8, decodeU16, decodeU32, decodeBits) where
import Language.Haskell.TH
import Language.Haskell.TH.Lib
import Language.Haskell.TH.Syntax
import qualified Data.ByteString as BS
import Data.Word
import Data.Bits
import Data.Char (chr, ord)
import Control.Monad
import Test.QuickCheck
import Foreign
import Foreign.C
-- C byte-order helpers, bound through the FFI as pure functions.
-- htonl handles 32-bit values and htons 16-bit values; both are imported
-- `unsafe`, i.e. as plain calls that must not call back into Haskell.
foreign import ccall unsafe "htonl" htonl :: Word32 -> Word32
foreign import ccall unsafe "htons" htons :: Word16 -> Word16
-- | One piece of a byte string under construction; a list of these is
-- serialised by 'makeBits'.
data BitBlock
  = U8 Int
    -- ^ A value written as a single byte (converted to 'Word8').
  | U16 Int
    -- ^ A value written as two bytes (converted to 'Word16').
  | U32 Int
    -- ^ A value written as four bytes (converted to 'Word32').
  | NullTerminated String
    -- ^ The characters of the string, one byte each, followed by a 0 byte.
  | RawString String
    -- ^ The characters of the string, one byte each, with no terminator.
  | RawByteString BS.ByteString
    -- ^ Bytes copied into the output unchanged.
  | PackBits [(Int, Int)]
    -- ^ A run of bit fields given as @(width in bits, value)@ pairs; the
    -- widths must add up to a multiple of 8.
  deriving (Show)
-- | Split a fixed-width integral value into its bytes, least-significant
-- byte first.
--
-- The number of bytes produced is @bitSize input `div` 8@. Calls 'error' if
-- the bit size of the type is not a multiple of 8.
getBytes :: (Integral a, Bounded a, Bits a) => a -> BS.ByteString
getBytes input
  | width `mod` 8 /= 0 = error "Input data bit size must be a multiple of 8"
  | otherwise = BS.pack $ getByte input (width `div` 8)
  where
    width = bitSize input
    -- Peel off the low byte, then continue with the value shifted right by
    -- one byte until the requested number of bytes has been produced.
    getByte _ 0 = []
    getByte x remaining =
      fromIntegral (x .&. 0xff) : getByte (x `shiftR` 8) (remaining - 1)
-- | Fold step used for 'PackBits': append one @(size, value)@ bit field to
-- the packing state.
--
-- The state is @(current, used, bytes)@:
--
--   * @current@ - the byte currently being filled,
--   * @used@    - how many of its bits (counted from the top) are taken,
--   * @bytes@   - the completed bytes, most recent first.
--
-- Fields are written most-significant bit first. A field that does not fit
-- in the space left in @current@ fills that space, pushes the byte, and
-- recurses with the remaining bit count; the conversion to 'Word8' discards
-- the bits that were already written.
packBits :: (Word8, Int, [Word8])
         -> (Int, Int)
         -> (Word8, Int, [Word8])
packBits (current, used, bytes) (size, value)
  | bitsWritten < size =
      packBits (0, 0, current' : bytes) (size - bitsWritten, value)
  | used' == 8 = (0, 0, current' : bytes)
  | otherwise = (current', used', bytes)
  where
    -- Bit index (0 = least significant) of the field's top bit.
    top = size - 1
    -- Bit index of the highest free position in the current byte.
    topOfByte = 7 - used
    -- Line the field's top bit up with the highest free position; a
    -- negative amount shifts right, keeping only the field's upper part.
    aligned = value `shift` (topOfByte - top)
    newBits = fromIntegral aligned :: Word8
    current' = current .|. newBits
    bitsWritten = min (8 - used) size
    used' = used + bitsWritten
-- | Serialise a single 'BitBlock'.
--
-- 'U16' and 'U32' are written most-significant byte first (network byte
-- order). The bytes are produced with shifts rather than by combining
-- @htons@/@htonl@ with a least-significant-first byte split: that
-- combination only yields network order on little-endian hosts, because the
-- C functions are the identity on big-endian ones.
--
-- 'NullTerminated' and 'RawString' write one byte per character, keeping
-- only the low 8 bits of each code point.
--
-- Calls 'error' for a 'PackBits' whose field widths do not add up to a
-- multiple of 8.
bits :: BitBlock -> BS.ByteString
bits (U8 v) = BS.pack [fromIntegral v :: Word8]
bits (U16 v) = BS.pack [fromIntegral (v `shiftR` 8), fromIntegral v]
bits (U32 v) = BS.pack [fromIntegral (v `shiftR` s) | s <- [24, 16, 8, 0]]
bits (NullTerminated str) = BS.pack $ map (fromIntegral . ord) str ++ [0]
bits (RawString str) = BS.pack $ map (fromIntegral . ord) str
bits (RawByteString bs) = bs
bits (PackBits bitspec)
  | sum (map fst bitspec) `mod` 8 /= 0 =
      error "Sum of sizes of a bit spec must == 0 mod 8"
  | otherwise =
      -- packBits accumulates completed bytes newest-first, so reverse them.
      let (_, _, packed) = foldl packBits (0, 0, []) bitspec
      in BS.pack (reverse packed)
-- | Serialise every block in order and join the results into one strict
-- 'BS.ByteString'.
makeBits :: [BitBlock] -> BS.ByteString
makeBits blocks = BS.concat [bits block | block <- blocks]
-- | One step of a parser description consumed by 'bitSyn'.
data ReadType
  = Unsigned Integer
    -- ^ An unsigned integer occupying the given number of bytes
    -- (1, 2 and 4 are the sizes handled by the generator).
  | Variable String
    -- ^ Name of a function applied to the remaining input; it is expected to
    -- return a @(value, rest)@ pair.
  | Skip Integer
    -- ^ Drop the given number of bytes without producing a value.
  | Fixed Integer
    -- ^ Take the given number of bytes as a value.
  | Ignore ReadType
    -- ^ Parse the wrapped element but leave its value out of the result.
  | Context String
    -- ^ Name of a function applied to the remaining input and to a tuple of
    -- the values parsed so far; expected to return a @(value, rest)@ pair.
  | LengthPrefixed
    -- ^ Take as many bytes as the most recently parsed value says.
  | PackedBits [Integer]
    -- ^ A run of bit fields with the given widths; the widths must add up to
    -- a multiple of 8.
  | Rest
    -- ^ Yield all remaining input as a value.
-- | Combine a list of byte-sized values into one number, treating the first
-- element as the least-significant byte.
--
-- Elements beyond the width of the result type are shifted out and lost.
-- The 'Num' constraint is required for the zero seed: 'Bits' does not imply
-- 'Num' on current versions of base.
fromBytes :: (Num a, Bits a) => [a] -> a
fromBytes = foldr (\byte acc -> (acc `shiftL` 8) .|. byte) 0
-- | Read the first byte of the input as an unsigned 8-bit value.
--
-- Calls 'error' on empty input.
decodeU8 :: BS.ByteString -> Word8
decodeU8 input
  | BS.null input = error "decodeU8: empty input"
  | otherwise = BS.head input
-- | Decode a big-endian (network order) unsigned 16-bit value from the first
-- two bytes of the input.
--
-- The value is assembled with shifts instead of a least-significant-first
-- read followed by @htons@: that combination only decodes network order on
-- little-endian hosts, since @htons@ is the identity on big-endian ones.
--
-- Input shorter than two bytes is treated as if padded with zero bytes on
-- the right and extra bytes are ignored, which matches what the previous
-- implementation produced on little-endian hosts.
decodeU16 :: BS.ByteString -> Word16
decodeU16 input = BS.foldl' step 0 padded
  where
    padded = BS.take 2 (input `BS.append` BS.replicate 2 0)
    step acc byte = (acc `shiftL` 8) .|. fromIntegral byte
-- | Decode a big-endian (network order) unsigned 32-bit value from the first
-- four bytes of the input.
--
-- The value is assembled with shifts instead of a least-significant-first
-- read followed by @htonl@: that combination only decodes network order on
-- little-endian hosts, since @htonl@ is the identity on big-endian ones.
--
-- Input shorter than four bytes is treated as if padded with zero bytes on
-- the right and extra bytes are ignored, which matches what the previous
-- implementation produced on little-endian hosts.
decodeU32 :: BS.ByteString -> Word32
decodeU32 input = BS.foldl' step 0 padded
  where
    padded = BS.take 4 (input `BS.append` BS.replicate 4 0)
    step acc byte = (acc `shiftL` 8) .|. fromIntegral byte
-- | Extract a run of bit fields from the front of a byte string.
--
-- @sizes@ gives the width of each field in bits. Only the first
-- @sum sizes `shiftR` 3@ bytes of the input are examined. The decoded
-- values are returned in the same order as @sizes@.
decodeBits :: [Integer] -> BS.ByteString -> [Integer]
decodeBits sizes bs = reverse decoded
  where
    byteCount = fromIntegral (sum sizes `shiftR` 3)
    payload = BS.unpack (BS.take byteCount bs)
    -- The fold conses each decoded value onto the front, hence the reverse.
    (decoded, _, _) = foldl unpackBits ([], 0, payload) sizes
-- | Fold step for 'decodeBits': read one field of the given width, starting
-- with an empty (zero) value accumulator.
unpackBits :: ([Integer], Integer, [Word8]) -> Integer -> ([Integer], Integer, [Word8])
unpackBits = unpackBitsInner 0
-- | Read @bitsToGet@ bits from the front of the input, most-significant bit
-- first, accumulating them onto @val@.
--
-- The state is @(output, used, bytes)@: the values decoded so far (newest
-- first), the number of bits already consumed from the head byte, and the
-- remaining bytes. The head byte is kept shifted left so that its next
-- unread bit is always the top bit.
--
-- If the input runs out, the state is returned as is and the partially
-- accumulated value is dropped.
unpackBitsInner :: Integer ->
                   ([Integer], Integer, [Word8]) ->
                   Integer ->
                   ([Integer], Integer, [Word8])
unpackBitsInner _ (output, used, []) _ = (output, used, [])
unpackBitsInner val (output, used, current : input) bitsToGet
  -- The field continues into the next byte.
  | bitsToGet' > 0 = unpackBitsInner val'' (output, 0, input) bitsToGet'
  -- Field finished with bits to spare in this byte: keep the shifted byte.
  | used' < 8 = (val'' : output, used', current'' : input)
  -- Field finished exactly at the byte boundary.
  | otherwise = (val'' : output, 0, input)
  where
    bitsAv = 8 - used
    bitsTaken = min bitsAv bitsToGet
    -- Make room in the accumulator for the bits about to be taken.
    val' = val `shift` fromIntegral bitsTaken
    -- The top `bitsTaken` bits of the byte, moved down to the bottom.
    current' = current `shiftR` fromIntegral (8 - bitsTaken)
    -- The byte with the consumed bits shifted out.
    current'' = current `shiftL` fromIntegral bitsTaken
    val'' = val' .|. fromIntegral current'
    bitsToGet' = bitsToGet - bitsTaken
    used' = used + bitsTaken
-- | Fold step for 'bitSyn': generate the let-bindings that parse one
-- 'ReadType'.
--
-- The state is @(decs, inputname, tuplenames)@: the bindings generated so
-- far (newest first), the name bound to the input not yet consumed, and the
-- names of the values that will appear in the result tuple (newest first).
--
-- The generated code refers to @BS.splitAt@, @BS.drop@, @fromIntegral@ and
-- the @decode*@ helpers through 'mkName', so those names are resolved where
-- the generated expression is spliced.
--
-- Invalid descriptions (an 'Unsigned' size other than 1, 2 or 4, a
-- 'LengthPrefixed' with no preceding value, 'PackedBits' widths that do not
-- fill whole bytes) are reported with 'fail'.
readElement :: ([Dec], Name, [Name]) -> ReadType -> Q ([Dec], Name, [Name])
-- (val, rest) = funcname input (values parsed so far, oldest first)
readElement (decs, inputname, tuplenames) (Context funcname) = do
  valname <- newName "val"
  restname <- newName "rest"
  let dec = ValD (TupP [VarP valname, VarP restname])
                 (NormalB $ AppE (AppE (VarE $ mkName funcname)
                                       (VarE inputname))
                                 (TupE $ map VarE $ reverse tuplenames))
                 []
  return (dec : decs, restname, valname : tuplenames)
-- (val, rest) = BS.splitAt n input
readElement (decs, inputname, tuplenames) (Fixed n) = do
  valname <- newName "val"
  restname <- newName "rest"
  let dec1 = ValD (TupP [VarP valname, VarP restname])
                  (NormalB $ AppE (AppE (VarE $ mkName "BS.splitAt")
                                        (LitE (IntegerL n)))
                                  (VarE inputname))
                  []
  return (dec1 : decs, restname, valname : tuplenames)
-- Parse the wrapped element, then restore the previous list of result names
-- so that its value does not show up in the final tuple.
readElement state@(_, _, tuplenames) (Ignore n) = do
  (decs', restname, _) <- readElement state n
  return (decs', restname, tuplenames)
-- (val, rest) = BS.splitAt (fromIntegral <most recent value>) input
readElement (decs, inputname, tuplenames) LengthPrefixed =
  case tuplenames of
    [] -> fail "LengthPrefixed must follow an element that yields a length"
    sourcename : _ -> do
      valname <- newName "val"
      restname <- newName "rest"
      let dec = ValD (TupP [VarP valname, VarP restname])
                     (NormalB $ AppE (AppE (VarE $ mkName "BS.splitAt")
                                           (AppE (VarE $ mkName "fromIntegral")
                                                 (VarE sourcename)))
                                     (VarE inputname))
                     []
      return (dec : decs, restname, valname : tuplenames)
-- (val, rest) = funcname input
readElement (decs, inputname, tuplenames) (Variable funcname) = do
  valname <- newName "val"
  restname <- newName "rest"
  let dec = ValD (TupP [VarP valname, VarP restname])
                 (NormalB $ AppE (VarE $ mkName funcname)
                                 (VarE inputname))
                 []
  return (dec : decs, restname, valname : tuplenames)
-- rest = input; the input name itself is left unchanged.
readElement (decs, inputname, tuplenames) Rest = do
  restname <- newName "rest"
  let dec = ValD (VarP restname)
                 (NormalB $ VarE inputname)
                 []
  return (dec : decs, inputname, restname : tuplenames)
-- rest = BS.drop n input; no value is produced.
readElement (decs, inputname, tuplenames) (Skip n) = do
  restname <- newName "rest"
  let dec = ValD (VarP restname)
                 (NormalB $ AppE (AppE (VarE $ mkName "BS.drop")
                                       (LitE (IntegerL n)))
                                 (VarE inputname))
                 []
  return (dec : decs, restname, tuplenames)
-- Split off `size` bytes and run the matching fixed-width decoder on them.
readElement state (Unsigned size) =
  case size of
    1 -> decodeHelper state (VarE $ mkName "decodeU8") size
    2 -> decodeHelper state (VarE $ mkName "decodeU16") size
    4 -> decodeHelper state (VarE $ mkName "decodeU32") size
    _ -> fail $ "Unsigned size must be 1, 2 or 4 bytes, got " ++ show size
-- Split off the bytes covered by the bit fields and run decodeBits on them.
readElement state (PackedBits sizes)
  | sum sizes `mod` 8 /= 0 = fail "Sizes of packed bits must == 0 mod 8"
  | otherwise =
      decodeHelper state
                   (AppE (VarE $ mkName "decodeBits")
                         (ListE $ map (LitE . IntegerL) sizes))
                   (sum sizes `shiftR` 3)
-- | Shared code generator for fixed-size elements.
--
-- Emits two bindings: one that splits @size@ bytes off the current input
-- with @BS.splitAt@, and one that applies @decodefunc@ to the bytes split
-- off. The decoded value's name is added to the result names and the
-- remainder becomes the new input name.
decodeHelper (decs, inputname, tuplenames) decodefunc size = do
  chunk <- newName "val"
  remaining <- newName "rest"
  decoded <- newName "tup"
  let splitCall = AppE (AppE (VarE $ mkName "BS.splitAt")
                             (LitE (IntegerL size)))
                       (VarE inputname)
      splitDec = ValD (TupP [VarP chunk, VarP remaining]) (NormalB splitCall) []
      decodeDec = ValD (VarP decoded)
                       (NormalB $ AppE decodefunc (VarE chunk))
                       []
  return (splitDec : decodeDec : decs, remaining, decoded : tuplenames)
-- | Name bound by a simple variable binding (@name = ...@).
--
-- Calls 'error' for any other kind of declaration.
decGetName :: Dec -> Name
decGetName (ValD (VarP name) _ _) = name
decGetName _ = error "decGetName: expected a simple variable binding"
-- | Build a parser expression from a list of 'ReadType's.
--
-- The result is a lambda taking the input; its body is a @let@ holding the
-- bindings generated for every element and evaluates to a tuple of the
-- parsed values in description order.
bitSyn :: [ReadType] -> Q Exp
bitSyn elements = do
  inputname <- newName "input"
  (lets, _, tuplenames) <- foldM readElement ([], inputname, []) elements
  -- Names are collected newest-first; flip them back into description order.
  let result = TupE (map VarE (reverse tuplenames))
  return (LamE [VarP inputname] (LetE lets result))
-- | QuickCheck property: packing a list of @(width, value)@ fields with
-- 'PackBits' and decoding the result with 'decodeBits' gives the values
-- back.
--
-- Inputs containing a width below 1 or a negative value are accepted
-- without checking the round trip. When the widths do not fill whole bytes,
-- a padding field holding the value 1 is put in front to make up the
-- difference.
--
-- NOTE(review): values that need more bits than their field width are not
-- filtered out here - confirm whether the generator is expected to avoid
-- them.
prop_bitPacking :: [(Int, Int)] -> Bool
prop_bitPacking fields =
  prevalues == map fromIntegral postvalues ||
  any (< 1) (map fst fields) ||
  any (< 0) (map snd fields)
  where
    -- Number of bits by which the fields overrun the last whole byte.
    undershoot = sum (map fst fields) `mod` 8
    fields'
      | undershoot > 0 = (8 - undershoot, 1) : fields
      | otherwise = fields
    prevalues = map snd fields'
    packed = bits $ PackBits fields'
    postvalues = decodeBits (map (fromIntegral . fst) fields') packed