{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
module Data.BitVector.Sized
(
BitVector(..)
, bitVector
, bv0
, bvAnd, bvOr, bvXor
, bvComplement
, bvShift, bvShiftL, bvShiftRA, bvShiftRL, bvRotate
, bvWidth
, bvTestBit
, bvPopCount
, bvTruncBits
, bvAdd, bvMul
, bvQuotU, bvQuotS
, bvRemU, bvRemS
, bvAbs, bvNegate
, bvSignum
, bvLTS, bvLTU
, bvConcat, (<:>), bvConcatMany, bvConcatManyWithRepr
, bvExtract, bvExtractWithRepr
, bvZext, bvZextWithRepr
, bvSext, bvSextWithRepr
, bvMulFU, bvMulFS, bvMulFSU
, bvIntegerU
, bvIntegerS
, bvGetBytesU
) where
import Data.Bits
import Data.Ix
import Data.Parameterized
import GHC.TypeLits
import Numeric
import System.Random
import Test.QuickCheck (Arbitrary(..), choose)
import Text.PrettyPrint.HughesPJClass
import Text.Printf
import Unsafe.Coerce (unsafeCoerce)
-- | A bit vector whose width @w@ is tracked at the type level.
--
-- The 'NatRepr' is a runtime witness of the width @w@; the 'Integer' holds
-- the bits. The operations in this module intend the payload to stay in the
-- range @[0, 2^w)@ (they mask results with 'truncBits'), but the exported
-- raw 'BV' constructor does not itself enforce that invariant.
data BitVector (w :: Nat) :: * where
  BV :: NatRepr w -> Integer -> BitVector w
-- | Build a bit vector of statically known width from an 'Integer'.
--
-- Only the low @w@ bits of the argument are kept, so negative inputs are
-- taken in two's complement (for example @-1@ becomes all ones).
bitVector :: KnownNat w => Integer -> BitVector w
bitVector x = BV repr (truncBits (natValue repr) x)
  where
    repr = knownNat
-- | The zero-width bit vector (its only possible value is 0).
bv0 :: BitVector 0
bv0 = BV knownNat 0
-- | The stored bits read as an unsigned 'Integer'.
bvIntegerU :: BitVector w -> Integer
bvIntegerU bv = case bv of
  BV _ bits -> bits
-- | The stored bits read as a two's-complement signed 'Integer'.
--
-- If the most significant bit (index @w - 1@) is set, @2^w@ is subtracted
-- from the unsigned value.
bvIntegerS :: BitVector w -> Integer
bvIntegerS bv
  | msbSet    = unsigned - (1 `shiftL` w)
  | otherwise = unsigned
  where
    w        = bvWidth bv
    unsigned = bvIntegerU bv
    msbSet   = bvTestBit bv (w - 1)
-- | Bitwise AND of two vectors of the same width. Inputs that are already
-- in range give an in-range result, so no masking is needed.
bvAnd :: BitVector w -> BitVector w -> BitVector w
bvAnd lhs rhs = case (lhs, rhs) of
  (BV repr a, BV _ b) -> BV repr (a .&. b)
-- | Bitwise OR of two vectors of the same width.
bvOr :: BitVector w -> BitVector w -> BitVector w
bvOr lhs rhs = case (lhs, rhs) of
  (BV repr a, BV _ b) -> BV repr (a .|. b)
-- | Bitwise exclusive OR of two vectors of the same width.
bvXor :: BitVector w -> BitVector w -> BitVector w
bvXor lhs rhs = case (lhs, rhs) of
  (BV repr a, BV _ b) -> BV repr (xor a b)
-- | Flip every bit within the vector's width.
bvComplement :: BitVector w -> BitVector w
bvComplement (BV repr bits) = BV repr flipped
  where
    -- 'complement' on an Integer is negative; masking brings it back in range
    flipped = truncBits (natValue repr) (complement bits)
-- | Shift left for positive amounts and right for negative amounts.
--
-- The vector is read as a signed value first, so right shifts are
-- arithmetic (the sign bit is replicated). The result is truncated to the
-- vector's width.
bvShift :: BitVector w -> Int -> BitVector w
bvShift bv@(BV repr _) amount =
  BV repr . truncBits (natValue repr) $ shift (bvIntegerS bv) amount
-- | Clamp a shift amount so that negative values become zero.
toPos :: Int -> Int
toPos = max 0
-- | Left shift, truncated to the width. Negative amounts are treated as 0.
bvShiftL :: BitVector w -> Int -> BitVector w
bvShiftL bv = bvShift bv . toPos
-- | Arithmetic (sign-extending) right shift. Negative amounts are treated
-- as 0.
bvShiftRA :: BitVector w -> Int -> BitVector w
bvShiftRA bv = bvShift bv . negate . toPos
-- | Logical (zero-filling) right shift of the unsigned value. Negative
-- amounts are treated as 0.
bvShiftRL :: BitVector w -> Int -> BitVector w
bvShiftRL bv@(BV repr _) amount = BV repr (truncBits (natValue repr) shifted)
  where
    shifted = bvIntegerU bv `shiftR` toPos amount
-- | Rotate left by the given amount; negative amounts rotate right.
--
-- The amount is reduced modulo the width. A zero-width vector is returned
-- unchanged, which avoids a ``mod`` by zero.
bvRotate :: BitVector w -> Int -> BitVector w
bvRotate bv rot'
  | width == 0 = bv
  | otherwise  = leftChunk `bvOr` rightChunk
  where
    width      = bvWidth bv
    rot        = rot' `mod` width
    -- low bits moved up by 'rot'; bits pushed past the top are truncated
    leftChunk  = bvShiftL bv rot
    -- high bits that wrap around to the bottom. This must be a *logical*
    -- shift: an arithmetic shift would smear the sign bit across the
    -- result (and turn a rotate-by-0 of a vector with its MSB set into all
    -- ones).
    rightChunk = bvShiftRL bv (width - rot)
-- | The width of a bit vector, taken from its 'NatRepr', as an 'Int'.
bvWidth :: BitVector w -> Int
bvWidth (BV repr _) = fromIntegral $ natValue repr
-- | Test the bit at the given index (index 0 is the least significant bit).
bvTestBit :: BitVector w -> Int -> Bool
bvTestBit (BV _ bits) = testBit bits
-- | Number of set bits in the stored value.
bvPopCount :: BitVector w -> Int
bvPopCount = popCount . bvIntegerU
-- | Keep only the low @b@ bits and clear everything above them. The type
-- level width is unchanged.
bvTruncBits :: BitVector w -> Int -> BitVector w
bvTruncBits (BV repr bits) keep = BV repr $ truncBits keep bits
-- | Addition modulo @2^w@.
bvAdd :: BitVector w -> BitVector w -> BitVector w
bvAdd (BV repr a) (BV _ b) = BV repr . truncBits (natValue repr) $ a + b
-- | Multiplication modulo @2^w@.
bvMul :: BitVector w -> BitVector w -> BitVector w
bvMul (BV repr a) (BV _ b) = BV repr . truncBits (natValue repr) $ a * b
-- | Unsigned division, rounding toward zero. A zero divisor raises the same
-- exception as 'quot' on 'Integer'.
bvQuotU :: BitVector w -> BitVector w -> BitVector w
bvQuotU (BV repr n) (BV _ d) = BV repr (quot n d)
-- | Signed (two's complement) division, rounding toward zero. The quotient
-- is truncated back to the width.
bvQuotS :: BitVector w -> BitVector w -> BitVector w
bvQuotS bv1@(BV repr _) bv2 = BV repr (truncBits (natValue repr) q)
  where
    q = quot (bvIntegerS bv1) (bvIntegerS bv2)
-- | Unsigned remainder. A zero divisor raises the same exception as 'rem'
-- on 'Integer'.
bvRemU :: BitVector w -> BitVector w -> BitVector w
bvRemU (BV repr n) (BV _ d) = BV repr (rem n d)
-- | Signed (two's complement) remainder. As with 'rem', the result has the
-- sign of the dividend before it is truncated back to the width.
bvRemS :: BitVector w -> BitVector w -> BitVector w
bvRemS bv1@(BV repr _) bv2 = BV repr (truncBits (natValue repr) r)
  where
    r = rem (bvIntegerS bv1) (bvIntegerS bv2)
-- | Absolute value of the signed (two's complement) reading. The most
-- negative value maps to itself after truncation.
bvAbs :: BitVector w -> BitVector w
bvAbs bv@(BV repr _) =
  BV repr (truncBits (natValue repr) (abs (bvIntegerS bv)))
-- | Two's-complement negation modulo @2^w@.
bvNegate :: BitVector w -> BitVector w
bvNegate (BV repr bits) = BV repr (truncBits (natValue repr) (negate bits))
-- | Signum of the two's-complement reading of a bit vector.
--
-- Returns 1 for positive vectors, all ones (i.e. -1) for negative vectors,
-- and 0 for zero. Together with 'bvAbs' and 'bvMul' this satisfies the
-- 'Num' law @abs x * signum x == x@; the previous version returned only the
-- sign bit, which broke that law for every positive input.
bvSignum :: BitVector w -> BitVector w
bvSignum bv@(BV wRepr _) = BV wRepr (truncBits width (signum (bvIntegerS bv)))
  where
    width = natValue wRepr
-- | Signed (two's complement) less-than.
bvLTS :: BitVector w -> BitVector w -> Bool
bvLTS bv1 bv2 = compare (bvIntegerS bv1) (bvIntegerS bv2) == LT
-- | Unsigned less-than.
bvLTU :: BitVector w -> BitVector w -> Bool
bvLTU bv1 bv2 = bvIntegerU bv2 > bvIntegerU bv1
-- | Concatenate two vectors. The first argument supplies the high bits and
-- the second the low bits.
bvConcat :: BitVector v -> BitVector w -> BitVector (v+w)
bvConcat (BV hiRepr hiBits) (BV loRepr loBits) = BV catRepr catBits
  where
    catRepr = hiRepr `addNat` loRepr
    catBits = shiftL hiBits (fromIntegral (natValue loRepr)) .|. loBits
-- | Infix form of 'bvConcat': @hi <:> lo@.
(<:>) :: BitVector v -> BitVector w -> BitVector (v+w)
hi <:> lo = bvConcat hi lo
-- | Concatenate two existentially sized vectors. The /second/ argument
-- becomes the high bits and the first the low bits.
bvConcatSome :: Some BitVector -> Some BitVector -> Some BitVector
bvConcatSome (Some lo) (Some hi) = Some (bvConcat hi lo)
-- | Concatenate a list of vectors into a vector of statically known width.
-- The first list element ends up in the least significant position. The
-- combined value is zero-extended or truncated to @w'@.
bvConcatMany :: KnownNat w' => [BitVector w] -> BitVector w'
bvConcatMany bvs = bvConcatManyWithRepr knownNat bvs
-- | Concatenate a list of vectors into a vector of the given width.
--
-- Each new element is placed above the accumulated result, so the first
-- list element lands in the least significant bits and the last in the
-- most significant. The combined value is then zero-extended or truncated
-- to the requested width.
bvConcatManyWithRepr :: NatRepr w' -> [BitVector w] -> BitVector w'
bvConcatManyWithRepr wRepr bvs = viewSome (bvZextWithRepr wRepr) combined
  where
    combined    = foldl step (Some bv0) bvs
    step acc bv = bvConcatSome acc (Some bv)
-- Fixity of '<:>': left-associative at the same level as '+', so
-- @a <:> b <:> c@ parses as @(a <:> b) <:> c@.
infixl 6 <:>
-- | Extract @w'@ bits starting at bit position @pos@ (0 = least
-- significant).
--
-- NOTE(review): the source is shifted with the sign-extending 'bvShift'.
-- If the window reaches past the top of the source, the bits between the
-- source's top and the window's top are copies of the source MSB rather
-- than zeros.
bvExtract :: forall w w' . (KnownNat w')
          => Int
          -> BitVector w
          -> BitVector w'
bvExtract pos bv = bitVector (bvIntegerU (bvShift bv (negate pos)))
-- | Like 'bvExtract', but the result width comes from an explicit
-- 'NatRepr' instead of a 'KnownNat' constraint.
bvExtractWithRepr :: NatRepr w'
                  -> Int
                  -> BitVector w
                  -> BitVector w'
bvExtractWithRepr repr pos bv = BV repr (truncBits (natValue repr) shifted)
  where
    shifted = bvIntegerU (bvShift bv (negate pos))
-- | Zero-extend (or truncate, if @w'@ is smaller) to a statically known
-- width.
bvZext :: forall w w' . KnownNat w'
       => BitVector w
       -> BitVector w'
bvZext bv = case bv of
  BV _ bits -> bitVector bits
-- | Zero-extend (or truncate) to the width given by the 'NatRepr'.
bvZextWithRepr :: NatRepr w'
               -> BitVector w
               -> BitVector w'
bvZextWithRepr repr bv = case bv of
  BV _ bits -> BV repr (truncBits (natValue repr) bits)
-- | Sign-extend (or truncate) to a statically known width, using the
-- two's-complement reading of the source.
bvSext :: forall w w' . KnownNat w'
       => BitVector w
       -> BitVector w'
bvSext = bitVector . bvIntegerS
-- | Sign-extend (or truncate) to the width given by the 'NatRepr'.
bvSextWithRepr :: NatRepr w'
               -> BitVector w
               -> BitVector w'
bvSextWithRepr repr = BV repr . truncBits (natValue repr) . bvIntegerS
-- | Full-width unsigned multiplication. The product of a @w@-bit and a
-- @w'@-bit unsigned value always fits in @w+w'@ bits, so no masking is
-- needed.
bvMulFU :: BitVector w -> BitVector w' -> BitVector (w+w')
bvMulFU (BV reprA a) (BV reprB b) = BV prodRepr (a * b)
  where
    prodRepr = addNat reprA reprB
-- | Full-width signed multiplication. Both operands are read in two's
-- complement, and the product is encoded in @w+w'@ bits.
bvMulFS :: BitVector w -> BitVector w' -> BitVector (w+w')
bvMulFS bv1@(BV repr1 _) bv2@(BV repr2 _) =
  BV outRepr (truncBits (natValue outRepr) (bvIntegerS bv1 * bvIntegerS bv2))
  where
    outRepr = addNat repr1 repr2
-- | Full-width mixed multiplication. The first operand is read as signed
-- and the second as unsigned; the product is encoded in @w+w'@ bits.
bvMulFSU :: BitVector w -> BitVector w' -> BitVector (w+w')
bvMulFSU bv1@(BV repr1 _) bv2@(BV repr2 _) =
  BV outRepr (truncBits (natValue outRepr) (bvIntegerS bv1 * bvIntegerU bv2))
  where
    outRepr = addNat repr1 repr2
-- | The low @n@ bytes of a vector, least significant byte first. Bytes
-- above the vector's width come out as zero; a non-positive @n@ gives @[]@.
bvGetBytesU :: Int -> BitVector w -> [BitVector 8]
bvGetBytesU n bv = map (bvExtract 0) (take n (iterate (`bvShiftRL` 8) bv))
-- | Shows the stored bits as @0x@-prefixed lowercase hexadecimal, without
-- padding or width information.
instance Show (BitVector w) where
  show (BV _ bits) = showString "0x" (showHex bits "")
-- | Parses any 'Integer' literal and truncates it to the width with
-- 'bitVector'.
instance KnownNat w => Read (BitVector w) where
  readsPrec d input =
    [ (bitVector n, rest)
    | (n, rest) <- (readsPrec d input :: [(Integer, String)])
    ]
-- | No methods are overridden; this relies entirely on the class defaults.
-- NOTE(review): in parameterized-utils the default 'showF' is expected to
-- fall back to the 'Show' instance. Confirm this for the pinned version.
instance ShowF BitVector
-- | Equality on the stored bits. The widths already agree by type.
instance Eq (BitVector w) where
  bv1 == bv2 = bvIntegerU bv1 == bvIntegerU bv2
-- | Width-indexed equality; reuses the 'Eq' instance, which compares the
-- stored bits.
instance EqF BitVector where
  eqF = (==)
-- | Orders vectors by their unsigned value.
instance Ord (BitVector w) where
  compare bv1 bv2 = compare (bvIntegerU bv1) (bvIntegerU bv2)
-- | Two bit vectors are propositionally equal when their widths agree and
-- their bits agree.
--
-- The type-equality proof comes from 'NatRepr''s own 'TestEquality'
-- instance instead of an 'unsafeCoerce'd 'Refl'. The check is the same
-- (width first, then bits), but the compiler now verifies that the proof
-- is sound.
instance TestEquality BitVector where
  testEquality (BV wRepr x) (BV wRepr' y) =
    case testEquality wRepr wRepr' of
      Just Refl | x == y -> Just Refl
      _                  -> Nothing
-- | 'Bits' instance that treats a 'BitVector' as an /unsigned/ value
-- ('isSigned' is 'False').
instance KnownNat w => Bits (BitVector w) where
  (.&.)        = bvAnd
  (.|.)        = bvOr
  xor          = bvXor
  complement   = bvComplement
  -- Data.Bits says right shifts sign-extend only for signed types. Since
  -- this instance reports 'isSigned' = False, negative amounts (and hence
  -- the default 'shiftR') use the zero-filling logical shift instead of
  -- the sign-extending 'bvShift'. Positive amounts shift left, truncated
  -- to the width.
  shift bv n
    | n >= 0    = bvShiftL bv n
    | otherwise = bvShiftRL bv (negate n)
  rotate       = bvRotate
  bitSize      = bvWidth
  bitSizeMaybe = Just . bvWidth
  isSigned     = const False
  testBit      = bvTestBit
  bit          = bitVector . bit
  popCount     = bvPopCount
-- | The finite bit size is the type-level width.
instance KnownNat w => FiniteBits (BitVector w) where
  finiteBitSize bv = bvWidth bv
-- | Modular (two's complement) arithmetic at width @w@. Integer literals
-- are truncated to the width, and '(-)' uses the class default
-- @x + negate y@.
instance KnownNat w => Num (BitVector w) where
  fromInteger = bitVector
  negate      = bvNegate
  (+)         = bvAdd
  (*)         = bvMul
  abs         = bvAbs
  signum      = bvSignum
-- | Enumeration goes through the unsigned value.
-- NOTE(review): 'fromEnum' narrows the 'Integer' to 'Int', so vectors
-- wider than 'Int' wrap around silently.
instance KnownNat w => Enum (BitVector w) where
  toEnum n    = bitVector (toInteger n)
  fromEnum bv = fromInteger (bvIntegerU bv)
-- | Indexing is delegated to the 'Ix' instance of the unsigned 'Integer'
-- values.
instance KnownNat w => Ix (BitVector w) where
  range (lo, hi) = map bitVector (range (bvIntegerU lo, bvIntegerU hi))
  index (lo, hi) bv = index (toU lo, toU hi) (toU bv)
    where toU = bvIntegerU
  inRange (lo, hi) bv = inRange (toU lo, toU hi) (toU bv)
    where toU = bvIntegerU
-- | Bounds of the unsigned reading: all zeros up to all ones.
instance KnownNat w => Bounded (BitVector w) where
  minBound = bitVector 0
  maxBound = bvComplement minBound
-- | Generates a vector across the full unsigned range, using this module's
-- 'Random' instance via 'choose'.
instance KnownNat w => Arbitrary (BitVector w) where
  arbitrary = choose (bitVector 0, bitVector (-1))
-- | Random generation over the unsigned values of the vector.
instance KnownNat w => Random (BitVector w) where
  randomR (bvLo, bvHi) gen =
    let (x, gen') = randomR (bvIntegerU bvLo, bvIntegerU bvHi) gen
    in (bitVector x, gen')
  -- Draw uniformly from the whole width. 'random' at type 'Integer' only
  -- covers the range of 'Int', so for vectors wider than 'Int' the high
  -- bits would always be all zeros or all ones.
  random = randomR (minBound, maxBound)
-- | Render @val@ as zero-padded hexadecimal followed by the width in angle
-- brackets, e.g. @prettyHex 8 10 == "0x0a<8>"@. The padding is one hex
-- digit per 4 bits of @width@, rounded up.
prettyHex :: (Integral a, PrintfArg a, Show a) => a -> Integer -> String
prettyHex width val = printf fmt val width
  where
    digits = (width + 3) `quot` 4
    fmt    = concat ["0x%.", show digits, "x<%d>"]
-- | Pretty-prints as zero-padded hex followed by the width, e.g.
-- @0x0a<8>@.
instance Pretty (BitVector w) where
  pPrint (BV repr bits) = text (prettyHex (natValue repr) bits)
-- | A mask with the low @numBits@ bits set and every higher bit clear.
lowMask :: (Integral a, Bits b) => a -> b
lowMask numBits = complement onesAbove
  where
    -- all ones, shifted so that the low 'numBits' positions are zero
    onesAbove = complement zeroBits `shiftL` fromIntegral numBits
-- | Keep only the low @width@ bits of a value. For an 'Integer' this maps
-- any input, including negative ones, into @[0, 2^width)@.
truncBits :: (Integral a, Bits b) => a -> b -> b
truncBits width = (.&. lowMask width)