module Data.BitSyntax (
BitBlock(..),
makeBits,
ReadType(..), bitSyn,
decodeU8, decodeU16, decodeU32, decodeU16LE, decodeU32LE) where
import Language.Haskell.TH.Lib
import Language.Haskell.TH.Syntax
import qualified Data.ByteString as BS
import Data.Char (ord)
import Control.Monad
import Test.QuickCheck (Arbitrary(), arbitrary, Gen())
import Foreign
-- | C library @htonl@: convert a 32-bit value from host to network
-- (big-endian) byte order.  Used in this module both for encoding and,
-- through the test @htonl 1 == 1@, to detect a big-endian host.
foreign import ccall unsafe "htonl" htonl :: Word32 -> Word32
-- | C library @htons@: the 16-bit counterpart of 'htonl'.
foreign import ccall unsafe "htons" htons :: Word16 -> Word16
-- | Reverse the order of the four bytes in a 'Word32'.
endianSwitch32 :: Word32 -> Word32
endianSwitch32 w = foldr (.|.) 0 movedBytes
  where
    -- Each byte is isolated and shifted to its mirrored position.
    movedBytes =
      [ (w .&. 0x000000ff) `shiftL` 24
      , (w .&. 0x0000ff00) `shiftL` 8
      , (w .&. 0x00ff0000) `shiftR` 8
      , w `shiftR` 24
      ]
-- | Swap the two bytes of a 'Word16'.
endianSwitch16 :: Word16 -> Word16
-- Rotating a 16-bit word by 8 bits exchanges its high and low bytes.
endianSwitch16 w = w `rotateL` 8
-- | Convert a 'Word32' between host byte order and little-endian byte
-- order: a byte swap on big-endian hosts, the identity otherwise.
littleEndian32 :: Word32 -> Word32
littleEndian32 a
  | hostIsBigEndian = endianSwitch32 a
  | otherwise = a
  where
    -- htonl leaves its argument unchanged exactly when host order is
    -- network (big-endian) order.
    hostIsBigEndian = htonl 1 == 1
-- | Convert a 'Word16' between host byte order and little-endian byte
-- order: a byte swap on big-endian hosts, the identity otherwise.
littleEndian16 :: Word16 -> Word16
littleEndian16 a
  -- Probe host order with the 16-bit 'htons' to match the width being
  -- converted; it is the identity exactly on big-endian hosts.
  | htons 1 == 1 = endianSwitch16 a
  | otherwise = a
-- | One segment of output for 'makeBits'.
data BitBlock
  = U8 Int                       -- ^ one byte: the low 8 bits of the value
  | U16 Int                      -- ^ low 16 bits, network (big-endian) byte order
  | U32 Int                      -- ^ low 32 bits, network (big-endian) byte order
  | U16LE Int                    -- ^ low 16 bits, little-endian byte order
  | U32LE Int                    -- ^ low 32 bits, little-endian byte order
  | NullTerminated String        -- ^ each character truncated to one byte, then a 0 byte
  | RawString String             -- ^ each character truncated to one byte, no terminator
  | RawByteString BS.ByteString  -- ^ copied verbatim
  | PackBits [(Int, Int)]        -- ^ @(width in bits, value)@ fields packed most
                                 --   significant bit first; the widths must sum to
                                 --   a multiple of 8
  deriving (Show)
-- | Serialise a fixed-width integral value into its bytes, least
-- significant byte first (i.e. little-endian), computed arithmetically so
-- the result does not depend on the host byte order.
--
-- Calls 'error' if the type's bit size is not a multiple of 8.
getBytes :: (Integral a, Bounded a, Bits a) => a -> BS.ByteString
getBytes input
  | width `mod` 8 /= 0 = error "Input data bit size must be a multiple of 8"
  | otherwise = BS.pack (getByte input (width `div` 8))
  where
    width = bitSize input
    -- Emit the low byte, then continue with the value shifted right by 8
    -- until the requested number of bytes has been produced.
    getByte _ 0 = []
    getByte x remaining =
      fromIntegral (x .&. 0xff) : getByte (shiftR x 8) (remaining - 1)
-- | Fold step used by 'bits' for 'PackBits': append one @(size, value)@
-- field to a bit stream packed most significant bit first.
--
-- The state is @(partial byte, bits already used in it, completed bytes in
-- reverse order)@.  A field wider than the space left in the partial byte
-- fills that byte, pushes it, and recurses with the remaining width.
--
-- Only the low @size@ bits of @value@ are used, so an out-of-range value
-- cannot spill into the bits of earlier fields.  Sizes must be
-- non-negative.
packBits :: (Word8, Int, [Word8])
         -> (Int, Int)
         -> (Word8, Int, [Word8])
packBits (current, used, bytes) (size, value)
  | bitsWritten < size =
      packBits (0, 0, current' : bytes) (size - bitsWritten, value)
  | used' == 8 = (0, 0, current' : bytes)
  | otherwise = (current', used', bytes)
  where
    -- Keep only the field's own bits.  With GHC, 'bit' yields 0 once size
    -- reaches the width of Int, which makes the mask all ones.
    field = value .&. (bit size - 1)
    top = size - 1
    topOfByte = 7 - used
    -- Line the field's most significant bit up with the first free bit of
    -- the current byte (a negative amount shifts right); the Word8
    -- conversion drops whatever does not fit in this byte.
    aligned = field `shift` (topOfByte - top)
    newBits = fromIntegral aligned :: Word8
    current' = current .|. newBits
    bitsWritten = min (8 - used) size
    used' = used + bitsWritten
-- | Serialise one 'BitBlock'.
--
-- Multi-byte integers are laid out arithmetically. 'getBytes' yields the
-- least significant byte first, which is little-endian order, and
-- reversing that gives network (big-endian) order.  This keeps the output
-- identical on every host regardless of its native byte order.  Values are
-- truncated to the field width by 'fromIntegral'.
--
-- Strings are written one byte per character using the low 8 bits of each
-- code point (no UTF-8 encoding is performed).  'PackBits' calls 'error'
-- unless its field widths sum to a multiple of 8.
bits :: BitBlock -> BS.ByteString
bits (U8 v) = BS.pack [fromIntegral v :: Word8]
bits (U16 v) = BS.reverse $ getBytes (fromIntegral v :: Word16)
bits (U32 v) = BS.reverse $ getBytes (fromIntegral v :: Word32)
bits (U16LE v) = getBytes (fromIntegral v :: Word16)
bits (U32LE v) = getBytes (fromIntegral v :: Word32)
bits (NullTerminated str) = BS.pack $ map (fromIntegral . ord) str ++ [0]
bits (RawString str) = BS.pack $ map (fromIntegral . ord) str
bits (RawByteString bs) = bs
bits (PackBits bitspec)
  | sum (map fst bitspec) `mod` 8 /= 0 =
      error "Sum of sizes of a bit spec must == 0 mod 8"
  | otherwise =
      -- packBits accumulates completed bytes in reverse order.
      let (_, _, packed) = foldl packBits (0, 0, []) bitspec
      in BS.pack (reverse packed)
-- | Serialise a sequence of 'BitBlock's, in order, into one
-- 'BS.ByteString'.
makeBits :: [BitBlock] -> BS.ByteString
makeBits blocks = BS.concat [bits block | block <- blocks]
-- | One element of a 'bitSyn' read specification.
data ReadType
  = Unsigned Integer      -- ^ unsigned big-endian integer of this many bytes
                          --   (1 uses 'decodeU8', 2 uses 'decodeU16', any
                          --   other count is handed to 'decodeU32')
  | UnsignedLE Integer    -- ^ unsigned little-endian integer (2 uses
                          --   'decodeU16LE', any other count 'decodeU32LE')
  | Variable Name         -- ^ run @f input@ monadically, binding @(value, rest)@
  | Skip Integer          -- ^ drop this many bytes; yields no value
  | Fixed Integer         -- ^ take this many bytes as a raw 'BS.ByteString'
  | Ignore ReadType       -- ^ read the inner element but do not return its value
  | Context Name          -- ^ like 'Variable', but @f@ also receives a tuple of
                          --   the values read so far
  | LengthPrefixed        -- ^ take as many bytes as the previously read value
  | PackedBits [Integer]  -- ^ bit fields of these widths (total must be a
                          --   multiple of 8)
  | Rest                  -- ^ the remaining input, left unconsumed
-- | Combine a list of byte values into one number, treating the first
-- element as the least significant byte (little-endian order).
--
-- Bytes that do not fit in a fixed-width result are shifted out.
fromBytes :: (Num a, Bits a) => [a] -> a
fromBytes = foldr (\byte acc -> shiftL acc 8 .|. byte) 0
-- | Decode an unsigned 8-bit integer from the first byte of the input.
--
-- Bytes after the first are ignored.  Fails (via 'head') on empty input.
decodeU8 :: BS.ByteString -> Word8
decodeU8 input = fromIntegral (head (BS.unpack input))
-- | Decode a big-endian (network order) unsigned 16-bit integer from the
-- first two bytes of the input.
--
-- Bytes after the second are ignored, and missing trailing bytes count as
-- zero.  The bytes are combined arithmetically, so the result does not
-- depend on the host's native byte order.
decodeU16 :: BS.ByteString -> Word16
decodeU16 input = foldl pushByte 0 (take 2 (BS.unpack input ++ repeat 0))
  where
    -- Earlier bytes end up more significant.
    pushByte acc byte = (acc `shiftL` 8) .|. fromIntegral byte
-- | Decode a big-endian (network order) unsigned 32-bit integer from the
-- first four bytes of the input.
--
-- Bytes after the fourth are ignored, and missing trailing bytes count as
-- zero.  The bytes are combined arithmetically, so the result does not
-- depend on the host's native byte order.
decodeU32 :: BS.ByteString -> Word32
decodeU32 input = foldl pushByte 0 (take 4 (BS.unpack input ++ repeat 0))
  where
    -- Earlier bytes end up more significant.
    pushByte acc byte = (acc `shiftL` 8) .|. fromIntegral byte
-- | Decode a little-endian unsigned 16-bit integer from the first two
-- bytes of the input.
--
-- Bytes after the second are ignored, and missing high-order bytes count
-- as zero.  The bytes are combined arithmetically, so the result does not
-- depend on the host's native byte order.
decodeU16LE :: BS.ByteString -> Word16
decodeU16LE input = foldr pushByte 0 (take 2 (BS.unpack input))
  where
    -- foldr reaches the last byte first, so earlier bytes end up less
    -- significant.
    pushByte byte acc = (acc `shiftL` 8) .|. fromIntegral byte
-- | Decode a little-endian unsigned 32-bit integer from the first four
-- bytes of the input.
--
-- Bytes after the fourth are ignored, and missing high-order bytes count
-- as zero.  The bytes are combined arithmetically, so the result does not
-- depend on the host's native byte order.
decodeU32LE :: BS.ByteString -> Word32
decodeU32LE input = foldr pushByte 0 (take 4 (BS.unpack input))
  where
    -- foldr reaches the last byte first, so earlier bytes end up less
    -- significant.
    pushByte byte acc = (acc `shiftL` 8) .|. fromIntegral byte
-- | Unpack consecutive bit fields of the given widths, most significant
-- bit first, from the leading @sum sizes / 8@ bytes of the input.
--
-- Returns the field values in the same order as @sizes@.
decodeBits :: [Integer] -> BS.ByteString -> [Integer]
decodeBits sizes bs = reverse decoded
  where
    byteCount = fromIntegral (sum sizes `shiftR` 3)
    payload = BS.unpack (BS.take byteCount bs)
    -- unpackBits conses each finished value, so the list comes out reversed.
    (decoded, _, _) = foldl unpackBits ([], 0, payload) sizes
-- | Fold step for 'decodeBits': read one field of the given width, starting
-- from an empty accumulator.
unpackBits :: ([Integer], Integer, [Word8]) -> Integer -> ([Integer], Integer, [Word8])
unpackBits = unpackBitsInner 0
-- | Read one field of @bitsToGet@ bits from a byte stream, most
-- significant bit first, accumulating into @val@.
--
-- The state is @(values decoded so far in reverse order, bits of the head
-- byte already consumed, remaining bytes)@.  The head byte is kept shifted
-- left so its unread bits are always at the top.  If the bytes run out,
-- the state is returned unchanged and any partially read value is dropped.
unpackBitsInner :: Integer ->
                   ([Integer], Integer, [Word8]) ->
                   Integer ->
                   ([Integer], Integer, [Word8])
unpackBitsInner _ (output, used, []) _ = (output, used, [])
unpackBitsInner val (output, used, current : input) bitsToGet
  -- The field continues past this byte: consume the rest and move on.
  | bitsToGet' > 0 = unpackBitsInner val'' (output, 0, input) bitsToGet'
  -- Field complete with unread bits left in this byte: keep the byte.
  | used' < 8 = (val'' : output, used', current'' : input)
  -- Field complete and this byte exactly used up.
  | otherwise = (val'' : output, 0, input)
  where
    bitsAv = 8 - used
    bitsTaken = min bitsAv bitsToGet
    val' = val `shift` fromIntegral bitsTaken
    -- The top bitsTaken bits of the byte.
    current' = current `shiftR` fromIntegral (8 - bitsTaken)
    -- Discard the consumed bits so the next read starts at the top again.
    current'' = current `shiftL` fromIntegral bitsTaken
    val'' = val' .|. fromIntegral current'
    bitsToGet' = bitsToGet - bitsTaken
    used' = used + bitsTaken
-- | Template Haskell fold step for 'bitSyn': extend the generated @do@
-- block with the statements that read one 'ReadType' element.
--
-- The threaded state is @(statements so far, newest first; name bound to
-- the still-unread input; names of values to return, newest first)@.
-- Every case binds fresh names via 'newName'.
readElement :: ([Stmt], Name, [Name]) -> ReadType -> Q ([Stmt], Name, [Name])
-- Context f: generates @(val, rest) <- f input (v1, ..., vn)@, passing a
-- tuple of every value read so far in reading order.
-- NOTE(review): newer template-haskell versions give 'TupE' a
-- @[Maybe Exp]@ argument; confirm against the TH version in use.
readElement (stmts, inputname, tuplenames) (Context funcname) = do
  valname <- newName "val"
  restname <- newName "rest"
  let stmt = BindS (TupP [VarP valname, VarP restname])
                   (AppE (AppE (VarE funcname)
                               (VarE inputname))
                         (TupE $ map VarE $ reverse tuplenames))
  return (stmt : stmts, restname, valname : tuplenames)
-- Fixed n: generates @let (val, rest) = BS.splitAt n input@.
readElement (stmts, inputname, tuplenames) (Fixed n) = do
  valname <- newName "val"
  restname <- newName "rest"
  let dec1 = ValD (TupP [VarP valname, VarP restname])
                  (NormalB $ AppE (AppE (VarE 'BS.splitAt)
                                        (LitE (IntegerL n)))
                                  (VarE inputname))
                  []
  return (LetS [dec1] : stmts, restname, valname : tuplenames)
-- Ignore t: emit t's statements and advance the input, but restore the
-- previous value names so t's value is not returned.
readElement state@(_, _, tuplenames) (Ignore n) = do
  (a, b, _) <- readElement state n
  return (a, b, tuplenames)
-- LengthPrefixed: split off as many bytes as the most recently read value.
-- 'head' fails at splice time if no value has been read yet.
readElement (stmts, inputname, tuplenames) LengthPrefixed = do
  valname <- newName "val"
  restname <- newName "rest"
  let sourcename = head tuplenames
      dec = ValD (TupP [VarP valname, VarP restname])
                 (NormalB $ AppE (AppE (VarE 'BS.splitAt)
                                       (AppE (VarE 'fromIntegral)
                                             (VarE sourcename)))
                                 (VarE inputname))
                 []
  return (LetS [dec] : stmts, restname, valname : tuplenames)
-- Variable f: generates @(val, rest) <- f input@.
readElement (stmts, inputname, tuplenames) (Variable funcname) = do
  valname <- newName "val"
  restname <- newName "rest"
  let stmt = BindS (TupP [VarP valname, VarP restname])
                   (AppE (VarE funcname) (VarE inputname))
  return (stmt : stmts, restname, valname : tuplenames)
-- Rest: binds the remaining input as a value; the input name passed on is
-- unchanged, so later elements still see the same bytes.
readElement (stmts, inputname, tuplenames) Rest = do
  restname <- newName "rest"
  let dec = ValD (VarP restname)
                 (NormalB $ VarE inputname)
                 []
  return (LetS [dec] : stmts, inputname, restname : tuplenames)
-- Skip n: generates @let rest = BS.drop n input@; contributes no value.
readElement (stmts, inputname, tuplenames) (Skip n) = do
  restname <- newName "rest"
  let dec = ValD (VarP restname)
                 (NormalB $ AppE (AppE (VarE 'BS.drop)
                                       (LitE (IntegerL n)))
                                 (VarE inputname))
                 []
  return (LetS [dec] : stmts, restname, tuplenames)
-- Unsigned n: sizes 1 and 2 have dedicated decoders; every other size
-- (including e.g. 3) is passed to decodeU32.
readElement state (Unsigned size) = do
  let decodefunc = case size of
        1 -> 'decodeU8
        2 -> 'decodeU16
        _ -> 'decodeU32
  decodeHelper state (VarE decodefunc) size
-- UnsignedLE n: size 2 uses decodeU16LE; every other size uses decodeU32LE.
readElement state (UnsignedLE size) = do
  let decodefunc = case size of
        2 -> 'decodeU16LE
        _ -> 'decodeU32LE
  decodeHelper state (VarE decodefunc) size
-- PackedBits sizes: rejected at splice time unless the widths total a
-- whole number of bytes; reads sum sizes / 8 bytes through decodeBits.
readElement state (PackedBits sizes) =
  if sum sizes `mod` 8 /= 0
    then error "Sizes of packed bits must == 0 mod 8"
    else decodeHelper state
                      (AppE (VarE 'decodeBits)
                            (ListE $ map (LitE . IntegerL) sizes))
                      ((sum sizes) `shiftR` 3)
-- | Shared code generator for fixed-width numeric elements.  It emits
--
-- > let (val, rest) = BS.splitAt size input
-- >     tup = decodefunc val
--
-- then records @tup@ as the next value to return and @rest@ as the
-- still-unread input.
decodeHelper :: ([Stmt], Name, [Name]) -> Exp
             -> Integer
             -> Q ([Stmt], Name, [Name])
decodeHelper (stmts, inputname, tuplenames) decodefunc size = do
  chunk <- newName "val"
  remainder <- newName "rest"
  decoded <- newName "tup"
  let splitCall = AppE (AppE (VarE 'BS.splitAt) (LitE (IntegerL size)))
                       (VarE inputname)
      splitDec = ValD (TupP [VarP chunk, VarP remainder]) (NormalB splitCall) []
      decodeDec = ValD (VarP decoded) (NormalB (AppE decodefunc (VarE chunk))) []
  return (LetS [splitDec, decodeDec] : stmts, remainder, decoded : tuplenames)
-- | Extract the name bound by a simple variable declaration @name = ...@.
--
-- Any other kind of declaration is rejected with a descriptive 'error'
-- instead of a bare 'undefined'.
decGetName :: Dec -> Name
decGetName (ValD (VarP name) _ _) = name
decGetName other =
  error ("Data.BitSyntax.decGetName: expected a simple variable binding, got "
         ++ show other)
-- | Build, at compile time, a one-argument lambda whose @do@ block reads
-- each 'ReadType' element in order from its argument and finally
-- @return@s a tuple of the values read.  'Skip' and 'Ignore' elements
-- contribute no tuple component.
bitSyn :: [ReadType] -> Q Exp
bitSyn elements = do
  input <- newName "input"
  (revStmts, _, revNames) <- foldM readElement ([], input, []) elements
  -- readElement prepends, so both lists are newest-first.
  let resultTuple = tupE (map varE (reverse revNames))
  finalStmt <- fmap NoBindS [| return $(resultTuple) |]
  let body = DoE (reverse (finalStmt : revStmts))
  return (LamE [VarP input] body)
-- | Round-trip property: packing fields with 'PackBits' and unpacking them
-- with 'decodeBits' recovers the original values.
--
-- Inputs the encoder does not support are accepted vacuously, and are
-- checked before any packing is attempted. These are non-positive field
-- sizes, negative values, and values too wide for their field.  When the
-- sizes do not sum to a whole number of bytes, a padding field holding 1
-- is prepended so the total is byte aligned.
prop_bitPacking :: [(Int, Int)] -> Bool
prop_bitPacking fields =
  unsupported || prevalues == map fromIntegral postvalues
  where
    unsupported =
      any ((< 1) . fst) fields
        || any ((< 0) . snd) fields
        -- Checked last so the exponent is only evaluated once every size
        -- is known to be positive.
        || any (\(size, value) -> toInteger value >= 2 ^ size) fields
    undershoot = sum (map fst fields) `mod` 8
    fields'
      | undershoot > 0 = (8 - undershoot, 1) : fields
      | otherwise = fields
    prevalues = map snd fields'
    packed = bits $ PackBits fields'
    postvalues = decodeBits (map (fromIntegral . fst) fields') packed
-- Orphan 'Arbitrary' instances for Word16/Word32, compiled only with
-- QuickCheck older than 2.1.2 (NOTE(review): presumably those versions did
-- not provide them; confirm).  Values are generated as 'Int' and truncated
-- with 'fromIntegral'.
#if !MIN_VERSION_QuickCheck(2,1,2)
instance Arbitrary Word16 where
  arbitrary = (arbitrary :: Gen Int) >>= return . fromIntegral
instance Arbitrary Word32 where
  arbitrary = (arbitrary :: Gen Int) >>= return . fromIntegral
#endif
-- | 'endianSwitch32' is an involution, and the C library's 'htonl' agrees
-- with it on little-endian hosts and is the identity on big-endian hosts.
-- This makes the property hold on either kind of host.
prop_nativeByteShuffle32 :: Word32 -> Bool
prop_nativeByteShuffle32 x =
  endianSwitch32 (endianSwitch32 x) == x && htonl x == expected
  where
    expected
      | htonl 1 == 1 = x                 -- big-endian host: already network order
      | otherwise = endianSwitch32 x     -- little-endian host: bytes reversed
-- | 'endianSwitch16' is an involution, and the C library's 'htons' agrees
-- with it on little-endian hosts and is the identity on big-endian hosts.
-- This makes the property hold on either kind of host.
prop_nativeByteShuffle16 :: Word16 -> Bool
prop_nativeByteShuffle16 x =
  endianSwitch16 (endianSwitch16 x) == x && htons x == expected
  where
    expected
      | htons 1 == 1 = x                 -- big-endian host: already network order
      | otherwise = endianSwitch16 x     -- little-endian host: bytes swapped
-- | 'littleEndian16' is the identity on little-endian hosts and a byte swap
-- on big-endian hosts.  The original property assumed a little-endian host.
prop_littleEndian16 :: Word16 -> Bool
prop_littleEndian16 x
  | htons 1 == 1 = littleEndian16 x == endianSwitch16 x
  | otherwise = littleEndian16 x == x
-- | 'littleEndian32' is the identity on little-endian hosts and a byte swap
-- on big-endian hosts.  The original property assumed a little-endian host.
prop_littleEndian32 :: Word32 -> Bool
prop_littleEndian32 x
  | htonl 1 == 1 = littleEndian32 x == endianSwitch32 x
  | otherwise = littleEndian32 x == x