module Data.BitVector.Sized.Internal
(
BitVector(..)
, bv
, bvAnd, bvOr, bvXor
, bvComplement
, bvShift, bvRotate
, bvWidth
, bvTestBit
, bvPopCount
, bvAdd, bvMul
, bvAbs, bvNegate
, bvSignum
, bvConcat, (<:>)
, bvExtract, bvExtractWithRepr
, bvZext, bvZextWithRepr
, bvSext, bvSextWithRepr
, bvIntegerU
, bvIntegerS
) where
import Data.Bits
import Data.Parameterized.Classes
import Data.Parameterized.NatRepr
import GHC.TypeLits
import Text.Printf
import Unsafe.Coerce (unsafeCoerce)
-- | A bit vector whose width @w@ is tracked at the type level.
--
-- The 'NatRepr' carries the width at runtime and the 'Integer' holds the
-- bits. The smart constructor 'bv' and the arithmetic operations truncate
-- to the low @w@ bits; code that applies 'BV' directly is expected to keep
-- the 'Integer' within that range as well.
data BitVector (w :: Nat) = BV (NatRepr w) Integer
-- | Build a 'BitVector' whose width is taken from the 'KnownNat' context.
-- Only the low @w@ bits of the argument are kept, so negative inputs end up
-- in two's-complement form.
bv :: KnownNat w => Integer -> BitVector w
bv x = case knownNat of
  repr -> BV repr (truncBits (natValue repr) x)
-- | The stored 'Integer' of a 'BitVector', i.e. its unsigned value.
bvIntegerU :: BitVector w -> Integer
bvIntegerU bvec = case bvec of
  BV _ val -> val
-- | Interpret a 'BitVector' as a two's-complement signed 'Integer'.
--
-- If the top bit (bit @w - 1@) is set, the result is the unsigned value
-- minus @2^w@; otherwise it is the unsigned value unchanged. A zero-width
-- vector has no sign bit, so its unsigned value is returned as is.
bvIntegerS :: BitVector w -> Integer
bvIntegerS bvec
  | width > 0 && bvTestBit bvec (width - 1) = unsigned - (1 `shiftL` width)
  | otherwise = unsigned
  where
    width = bvWidth bvec
    unsigned = bvIntegerU bvec
-- | Bitwise AND of two vectors of the same width. The width representation
-- is taken from the first argument.
bvAnd :: BitVector w -> BitVector w -> BitVector w
bvAnd (BV repr lhs) (BV _ rhs) = BV repr $ lhs .&. rhs
-- | Bitwise OR of two vectors of the same width. The width representation
-- is taken from the first argument.
bvOr :: BitVector w -> BitVector w -> BitVector w
bvOr (BV repr lhs) (BV _ rhs) = BV repr $ lhs .|. rhs
-- | Bitwise exclusive OR of two vectors of the same width. The width
-- representation is taken from the first argument.
bvXor :: BitVector w -> BitVector w -> BitVector w
bvXor (BV repr lhs) (BV _ rhs) = BV repr $ xor lhs rhs
-- | Flip every bit of the vector. The 'Integer' complement is negative, so
-- it is masked back down to the low @w@ bits.
bvComplement :: BitVector w -> BitVector w
bvComplement (BV repr val) = BV repr flipped
  where
    flipped = truncBits (natValue repr) (complement val)
-- | Shift left by @shf@ bits, or right when @shf@ is negative.
--
-- The vector is first read as a two's-complement signed value
-- ('bvIntegerS'), so right shifts are arithmetic: they replicate the sign
-- bit. The result is truncated back to the vector's width.
bvShift :: BitVector w -> Int -> BitVector w
bvShift bvec@(BV repr _) shf =
  let signed = bvIntegerS bvec
      shifted = shift signed shf
  in BV repr (truncBits (natValue repr) shifted)
-- | Rotate left by @rot'@ positions, or right when @rot'@ is negative.
--
-- The amount is reduced modulo the width. Both halves are computed from the
-- unsigned value, so the bits wrapped in at the bottom are the bits shifted
-- out at the top, never copies of the sign bit. Rotating a zero-width
-- vector returns it unchanged (avoiding a @mod@ by zero).
bvRotate :: BitVector w -> Int -> BitVector w
bvRotate bvec@(BV wRepr x) rot'
  | width == 0 = bvec
  | otherwise = BV wRepr (truncBits width (leftChunk .|. rightChunk))
  where
    width = bvWidth bvec
    rot = rot' `mod` width
    -- Bits that move past the top are masked off by 'truncBits' above.
    leftChunk = x `shiftL` rot
    -- The top @rot@ bits, brought down to the bottom; logical because x >= 0.
    rightChunk = x `shiftR` (width - rot)
-- | The width of a bit vector as an 'Int', read from its 'NatRepr'.
bvWidth :: BitVector w -> Int
bvWidth bvec = case bvec of
  BV repr _ -> fromIntegral (natValue repr)
-- | Test a single bit of the stored value (bit 0 is least significant).
-- No bounds check against the width is performed.
bvTestBit :: BitVector w -> Int -> Bool
bvTestBit (BV _ val) idx = val `testBit` idx
-- | Number of set bits in the stored value.
bvPopCount :: BitVector w -> Int
bvPopCount bvec = case bvec of
  BV _ val -> popCount val
-- | Modular addition: the sum is truncated to the vector's width.
bvAdd :: BitVector w -> BitVector w -> BitVector w
bvAdd (BV repr a) (BV _ b) =
  let total = a + b
  in BV repr (truncBits (natValue repr) total)
-- | Modular multiplication: the product is truncated to the vector's width.
bvMul :: BitVector w -> BitVector w -> BitVector w
bvMul (BV repr a) (BV _ b) =
  let product' = a * b
  in BV repr (truncBits (natValue repr) product')
-- | Absolute value of the vector read as a two's-complement signed number.
-- The magnitude is truncated to the width, so the most negative value maps
-- to itself.
bvAbs :: BitVector w -> BitVector w
bvAbs bvec@(BV repr _) =
  let magnitude = abs (bvIntegerS bvec)
  in BV repr (truncBits (natValue repr) magnitude)
-- | Two's-complement negation: @-x@ truncated to the vector's width
-- (equivalently @2^w - x@ for non-zero @x@; the negation of 0 is 0).
bvNegate :: BitVector w -> BitVector w
bvNegate (BV wRepr x) = BV wRepr (truncBits (natValue wRepr) (negate x))
-- | Signed signum: 0 for zero, 1 for positive values, and -1 (all ones) for
-- values whose two's-complement reading is negative.
--
-- This matches the signed interpretation used by 'bvAbs', so
-- @bvAbs x `bvMul` bvSignum x@ gives back @x@.
bvSignum :: BitVector w -> BitVector w
bvSignum bvec@(BV wRepr _) =
  BV wRepr (truncBits (natValue wRepr) (signum (bvIntegerS bvec)))
-- | Concatenate two vectors; the first argument supplies the high-order bits
-- and the second the low-order bits.
bvConcat :: BitVector v -> BitVector w -> BitVector (v+w)
bvConcat (BV hiRepr hiVal) (BV loRepr loVal) = BV (addNat hiRepr loRepr) joined
  where
    loBits = fromIntegral (natValue loRepr) :: Int
    joined = shiftL hiVal loBits .|. loVal
-- | Infix form of 'bvConcat': @hi <:> lo@ places @hi@ above @lo@.
(<:>) :: BitVector v -> BitVector w -> BitVector (v+w)
hi <:> lo = bvConcat hi lo
infixl 6 <:>
-- | Extract the @w'@ bits starting at bit @pos@ (bit 0 is least
-- significant), with the result width taken from the 'KnownNat' context.
--
-- The source is shifted right logically and masked to its own width, so
-- (given the source holds only its low @w@ bits) positions at or above @w@
-- read as 0. A negative @pos@ shifts left instead, and bits pushed past the
-- source width are dropped.
bvExtract :: forall w w' . (KnownNat w')
          => Int
          -> BitVector w
          -> BitVector w'
bvExtract pos (BV wRepr x) = bv (truncBits (natValue wRepr) (shift x (negate pos)))
-- | Like 'bvExtract', but the result width is given by an explicit
-- 'NatRepr' rather than a 'KnownNat' constraint.
--
-- The shifted source is masked to the source width and then to the target
-- width, so the result never carries bits beyond @w'@.
bvExtractWithRepr :: NatRepr w'
                  -> Int
                  -> BitVector w
                  -> BitVector w'
bvExtractWithRepr repr pos (BV wRepr x) =
  BV repr (truncBits (natValue repr) shifted)
  where
    -- Logical right shift by @pos@ (left shift when @pos@ is negative),
    -- kept within the source width.
    shifted = truncBits (natValue wRepr) (shift x (negate pos))
-- | Zero-extend to the width in the 'KnownNat' context. Because 'bv'
-- truncates, a narrower target keeps only the low bits.
bvZext :: forall w w' . KnownNat w'
       => BitVector w
       -> BitVector w'
bvZext (BV _ val) = bv val
-- | Zero-extend to the width given by an explicit 'NatRepr'.
--
-- Like 'bvZext', the value is truncated to the target width, so a narrower
-- target keeps only the low bits instead of storing out-of-range bits.
bvZextWithRepr :: NatRepr w'
               -> BitVector w
               -> BitVector w'
bvZextWithRepr repr (BV _ x) = BV repr (truncBits (natValue repr) x)
-- | Sign-extend to the width in the 'KnownNat' context: the signed reading
-- of the source is re-encoded in two's complement by 'bv'.
bvSext :: forall w w' . KnownNat w'
       => BitVector w
       -> BitVector w'
bvSext = bv . bvIntegerS
-- | Sign-extend to the width given by an explicit 'NatRepr'.
--
-- The signed reading may be negative, so it is truncated to the target
-- width (as 'bv' does for 'bvSext') to store its two's-complement bits
-- rather than a negative 'Integer'.
bvSextWithRepr :: NatRepr w'
               -> BitVector w
               -> BitVector w'
bvSextWithRepr repr bvec = BV repr (truncBits (natValue repr) (bvIntegerS bvec))
-- | Render as zero-padded hexadecimal followed by the width, via 'prettyHex'.
instance Show (BitVector w) where
  show (BV repr val) = prettyHex (natValue repr) val
-- | Uses the class's default methods, which presumably fall back on the
-- 'Show' instance above — confirm against the parameterized-utils version.
instance ShowF BitVector
-- | Vectors of the same type are equal when their stored 'Integer's are.
-- The width is fixed by the type, so only the values are compared.
instance Eq (BitVector w) where
  BV _ a == BV _ b = a == b
-- | Index-polymorphic equality: compares the stored 'Integer's.
instance EqF BitVector where
  eqF (BV _ a) (BV _ b) = a == b
-- | Two vectors are 'Refl'-equal when their widths and their values match.
--
-- The width proof comes from the 'NatRepr' 'TestEquality' instance, so no
-- 'unsafeCoerce' is needed to produce the 'Refl'.
instance TestEquality BitVector where
  testEquality (BV wRepr x) (BV wRepr' y) = case testEquality wRepr wRepr' of
    Just Refl | x == y -> Just Refl
    _ -> Nothing
-- | 'Bits' operations delegate to the width-preserving @bv*@ functions.
--
-- NOTE(review): 'isSigned' is 'False', yet 'shift' is 'bvShift', whose right
-- shifts are arithmetic (it shifts the 'bvIntegerS' reading). Because the
-- default @shiftR x n@ is @shift x (-n)@, @shiftR@ copies the top bit,
-- unlike unsigned types such as 'Word8'. Confirm this is intended.
instance KnownNat w => Bits (BitVector w) where
  x .&. y = bvAnd x y
  x .|. y = bvOr x y
  xor x y = bvXor x y
  complement x = bvComplement x
  shift x n = bvShift x n
  rotate x n = bvRotate x n
  testBit x i = bvTestBit x i
  popCount x = bvPopCount x
  bit i = bv (bit i)
  isSigned _ = False
  bitSize x = bvWidth x
  bitSizeMaybe x = Just (bvWidth x)
-- | The finite bit size is the vector's width.
instance KnownNat w => FiniteBits (BitVector w) where
  finiteBitSize x = bvWidth x
-- | Modular arithmetic at width @w@. Literals are truncated by 'bv';
-- subtraction uses the class default, @x + negate y@.
instance KnownNat w => Num (BitVector w) where
  fromInteger = bv
  x + y = bvAdd x y
  x * y = bvMul x y
  negate x = bvNegate x
  abs x = bvAbs x
  signum x = bvSignum x
-- | Enumeration over the unsigned value.
--
-- 'toEnum' truncates through 'bv'. 'fromEnum' converts via 'fromIntegral',
-- so it wraps for values that do not fit in an 'Int'. The range methods
-- enumerate the underlying 'Integer's instead of going through 'Int'. As a
-- result, open-ended enumerations stop at 'maxBound' (or 'minBound' when
-- descending) instead of wrapping around indefinitely.
instance KnownNat w => Enum (BitVector w) where
  toEnum = bv . fromIntegral
  fromEnum = fromIntegral . bvIntegerU
  enumFrom x = enumFromTo x maxBound
  enumFromThen x y = enumFromThenTo x y bound
    where
      bound | bvIntegerU y >= bvIntegerU x = maxBound
            | otherwise = minBound
  enumFromTo x y = map bv [bvIntegerU x .. bvIntegerU y]
  enumFromThenTo x y z = map bv [bvIntegerU x, bvIntegerU y .. bvIntegerU z]
-- | Unsigned bounds: all zeros through all ones.
instance KnownNat w => Bounded (BitVector w) where
  minBound = bv 0
  -- -1 truncated to w bits is the all-ones vector, i.e. 2^w - 1.
  maxBound = bv (-1)
-- | Format a value as @0x@ followed by hexadecimal digits, zero-padded to
-- one digit per 4 bits of @width@ (rounded up), then the width in angle
-- brackets, e.g. @0xff<8>@.
prettyHex :: (Integral a, PrintfArg a, Show a) => a -> Integer -> String
prettyHex width val = printf ("0x%." ++ show digits ++ "x<%d>") val width
  where
    -- ceiling (width / 4): the number of hex digits needed for `width` bits
    digits = (width + 3) `div` 4
-- | A mask with the low @numBits@ bits set: start from all ones, shift them
-- up by @numBits@, and complement the result.
lowMask :: (Integral a, Bits b) => a -> b
lowMask numBits =
  let allOnes = complement zeroBits
  in complement (allOnes `shiftL` fromIntegral numBits)
-- | Keep only the low @width@ bits of a value by AND-ing with 'lowMask'.
truncBits :: (Integral a, Bits b) => a -> b -> b
truncBits width = (.&. lowMask width)