```{-# OPTIONS_GHC -funbox-strict-fields #-}

{-# LANGUAGE BangPatterns #-}

-- |
-- Module    : Data.BitVector
-- Copyright : (c) Iago Abal, 2012
-- Maintainer: Iago Abal <iago.abal@gmail.com>
--
-- Implementation of bit-vectors as wrappers over 'Integer'.
--
-- * Bit-vectors are interpreted as unsigned integers
--   (i.e. natural numbers) except for some very specific cases.
--
-- * Bit-vectors are /size-polymorphic/ insofar as most operations treat
--   a bit-vector of size /k/ as of size /n/ for /n >= k/ if required.
--
-- For documentation purposes we will write @[n]k@ to denote a bit-vector
-- of size @n@ representing the natural number @k@.
--
module Data.BitVector
( -- * Bit-vectors
BitVector
, BV
, size, width
, nat, uint, int
-- * Creation
, bitVec
, ones, zeros
-- * Comparison
, (==.), (/=.)
-- * Indexing
, (@.), (@@)
, (!.)
, least, most
, msb, lsb, msb1
-- * Arithmetic
, sdiv, srem, smod
, lg2
-- * List-like operations
, (#)
, zeroExtend, signExtend
, foldl_, foldr_
, reverse_
, replicate_
-- * Bitwise operations
, module Data.Bits
, not_, nand, nor, xnor
, (<<.), (>>.), ashr, (<<<.), (>>>.)
-- * Conversion
, fromBits
, toBits
-- * Utilities
, maxNat
, integerWidth
) where

import Control.Exception ( assert )

import Data.Bits
import Data.Ord

----------------------------------------------------------------------
--- Bit-vectors

-- | Big-endian /pseudo size-polymorphic/ bit-vectors.
--
data BV
= BV {
size :: !Int      -- ^ The /size/ of a bit-vector.
, nat  :: !Integer  -- ^ The value of a bit-vector, as a natural number.
}

-- | An alias for 'BV'.
--
type BitVector = BV

-- | An alias for 'size'.
--
width :: BV -> Int
width = size

-- | An alias for 'nat'.
--
uint :: BV -> Integer
uint = nat

-- | 2's complement value of a bit-vector.
int :: BV -> Integer
int u | msb u     = - nat(-u)
| otherwise = nat u

instance Show BV where
show (BV n a) = "[" ++ show n ++ "]" ++ show a

----------------------------------------------------------------------
--- Construction

-- | Create a bit-vector given a size and an integer value.
--
-- >>> bitVec 4 3
-- [4]3
--
-- This function also handles negative values.
--
-- >>> bitVec 4 (-1)
-- [4]15
--
bitVec :: Integral a => Int -> a -> BV
bitVec n a | a >= 0    = BV n \$ fromIntegral a
| otherwise = negate \$ BV n \$ fromIntegral (-a)
{-# INLINE bitVec #-}

-- | Create a mask of ones.
--
ones :: Int -> BV
ones n = BV n \$ 2^n - 1
{-# INLINE ones #-}

-- | Create a mask of zeros.
--
zeros :: Int -> BV
zeros n = BV n 0
{-# INLINE zeros #-}

----------------------------------------------------------------------
--- Comparison

instance Eq BV where
(BV _ a) == (BV _ b) = a == b

instance Ord BV where
compare = comparing nat

-- | Fixed-size equality.
--
-- In contrast with '==', which is /size-polymorphic/, this equality
-- requires both bit-vectors to be of equal size.
--
-- >>> [n]k ==. [m]k
-- False
--
-- >>> [n]k == [n]k
-- True
--
(==.) :: BV -> BV -> Bool
(BV n a) ==. (BV m b) = n == m && a == b

-- | Fixed-size inequality.
--
-- The negated version of '==.'.
--
(/=.) :: BV -> BV -> Bool
u /=. v = not \$ u ==. v

----------------------------------------------------------------------
--- Indexing

-- | Bit indexing.
--
-- @u \@. i@ stands for the /i/-th bit of /u/.
--
-- >>> [4]2 @. 0
-- False
--
-- >>> [4]2 @. 1
-- True
--
(@.) :: BV -> Int -> Bool
(BV _ a) @. i = testBit a i
{-# INLINE (@.) #-}

-- | Bit-string extraction.
--
-- @u \@\@ (j,i) == fromBits (map (u \@.) [j,j-1..i])@
--
-- >>> [4]7 @@ (3,1)
-- [3]3
--
(@@) :: BV -> (Int,Int) -> BV
(BV _ a) @@ (j,i) = assert (j >= i) \$
BV m \$ (a `shiftR` i) `mod` 2^m
where m = j - i + 1

-- | Reverse bit-indexing.
--
-- Index from the end of the sequenc
--
-- @u !. i == u \@. (size u - i - 1) @
--
-- >>> [3]3 !. 0
-- False
--
(!.) :: BV -> Int -> Bool
(BV n a) !. i = assert (i < n) \$ testBit a (n-i-1)
{-# INLINE (!.) #-}

-- | Take least significant bits.
--
-- @least m u == u \@\@ (m-1,0)@
--
least :: Int -> BV -> BV
least m (BV _ a) = assert (m >= 1) \$
BV m \$ a `mod` 2^m

-- | Take most significant bits.
--
-- @most m u == u \@\@ (n-1,n-m)@
--
most :: Int -> BV -> BV
most m (BV n a) = assert (m >= 1 && m <= n) \$
BV m \$ a `shiftR` (n-m)

-- | Most significant bit.
--
-- @msb u == u !. 0@
--
msb :: BV -> Bool
msb = (!. 0)
{-# INLINE msb #-}

-- | Least significant bit.
--
-- @lsb u == u \@. 0@
--
lsb :: BV -> Bool
lsb = (@. 0)
{-# INLINE lsb #-}

-- | Most significant 1-bit.
--
-- /Pre/: input must be non-zero.
--
-- >>> msb1 [4]2
-- 1
--
-- >>> msb1 [4]4
-- 2
--
msb1 :: BV -> Int
msb1 (BV _ 0) = error "Data.BitVector.msb1: zero bit-vector"
msb1 (BV n a) = go (n-1)
where go i | testBit a i = i
| otherwise   = go (i-1)

----------------------------------------------------------------------
--- Arithmetic

instance Num BV where
(BV n1 a) + (BV n2 b) = BV n \$ (a + b) `mod` 2^n
where n = max n1 n2
(BV n1 a) * (BV n2 b) = BV n \$ (a * b) `mod` 2^n
where n = max n1 n2
negate (BV n a) = BV n \$ 2^n - a
abs u | msb u     = negate u
| otherwise = u
signum u = bitVec 2 \$ int u
fromInteger i = bitVec (integerWidth i) i

instance Real BV where
toRational = toRational . nat

instance Enum BV where
toEnum = fromIntegral
fromEnum (BV _ a) = assert (a < max_int) \$ fromIntegral a
where max_int = toInteger (maxBound::Int)

instance Integral BV where
quotRem (BV n1 a) (BV n2 b) = (BV n q,BV n r)
where n = max n1 n2
(q,r) = quotRem a b
divMod = quotRem
toInteger = nat

-- | 2's complement signed division.
--
sdiv :: BV -> BV -> BV
sdiv u@(BV n1 _) v@(BV n2 _) = bitVec n q
where n = max n1 n2
q = int u `quot` int v

-- | 2's complement signed remainder (sign follows dividend).
--
srem :: BV -> BV -> BV
srem u@(BV n1 _) v@(BV n2 _) = bitVec n r
where n = max n1 n2
r = int u `rem` int v

-- | 2's complement signed remainder (sign follows divisor).
--
smod :: BV -> BV -> BV
smod u@(BV n1 _) v@(BV n2 _) = bitVec n r
where n = max n1 n2
r = int u `mod` int v

-- | Ceiling logarithm base 2.
--
-- /Pre/: input bit-vector must be non-zero.
--
lg2 :: BV -> BV
lg2 (BV _ 0) = error "Data.BitVector.lg2: zero bit-vector"
lg2 (BV n 1) = BV n 0
lg2 (BV n a) = BV n \$ toInteger \$ integerWidth (a-1)

----------------------------------------------------------------------
--- List-like operations

-- | Concatenation of two bit-vectors.
--
(#) :: BV -> BV -> BV
(BV n a) # (BV m b) = BV (n + m) ((a `shiftL` m) + b)
{-# INLINABLE (#) #-}

-- | Logical extension.
--
-- >>> zeroExtend 3 [1]1
-- [4]1
--
zeroExtend :: Int -> BV -> BV
zeroExtend d (BV n a) = BV (n+d) a
{-# INLINE zeroExtend #-}

-- | Arithmetic extension.
--
-- >>> signExtend 2 [2]1
-- [4]1
--
-- >>> signExtend 2 [2]3
-- [4]15
--
signExtend :: Int -> BV -> BV
signExtend d (BV n a)
| testBit a (n-1) = BV (n+d) \$ (maxNat d `shiftL` n) + a
| otherwise       = BV (n+d) a

-- |
-- @foldl_ f z (fromBits [un, ..., u1, u0]) == ((((z \`f\` un) \`f\` ...) \`f\` u1) \`f\` u0)@
--
-- @foldl_ f e = fromBits . foldl f e . toBits@
--
foldl_ :: (a -> Bool -> a) -> a -> BV -> a
foldl_ f e (BV n a) = go (n-1) e
where go i !x | i >= 0    = let !b = testBit a i in go (i-1) \$ f x b
| otherwise = x
{-# INLINE foldl_ #-}

-- |
-- @foldr_ f z (fromBits [un, ..., u1, u0]) == un `f` (... `f` (u1 \`f\` (u0 \`f\` z)))@
--
-- @foldr_ f e = fromBits . foldr f e . toBits@
--
foldr_ :: (Bool -> a -> a) -> a -> BV -> a
foldr_ f e (BV n a) = go (n-1) e
where go i !x | i >= 0    = let !b = testBit a i in f b (go (i-1) x)
| otherwise = x
{-# INLINE foldr_ #-}

-- |
-- @reverse_ == fromBits . reverse . toBits@
--
reverse_ :: BV -> BV
reverse_ bv@(BV n _) = BV n \$ snd \$ foldl_ go (1,0) bv
where go (v,acc) b | b         = (v',acc+v)
| otherwise = (v',acc)
where v' = 2*v

-- |
-- /Pre/: if @replicate_ n u@ then @n > 0@ must hold.
--
-- @replicate_ n == fromBits . concat . replicate n . toBits @
--
replicate_ :: Int -> BV -> BV
replicate_ 0 _ = error "Data.BitVector.replicate_: cannot replicate 0-times"
replicate_ n u = go (n-1) u
where go 0 !acc = acc
go k !acc = go (k-1) (u # acc)

----------------------------------------------------------------------
--- Bitwise operations

instance Bits BV where
(BV n1 a) .&. (BV n2 b) = BV n \$ a .&. b
where n = max n1 n2
(BV n1 a) .|. (BV n2 b) = BV n \$ a .|. b
where n = max n1 n2
(BV n1 a) `xor` (BV n2 b) = BV n \$ a `xor` b
where n = max n1 n2
complement (BV n a) = BV n \$ 2^n - 1 - a
bit i = BV (i+1) (2^i)
testBit (BV n a) i | i < n     = testBit a i
| otherwise = False
bitSize = undefined
isSigned = const False
shiftL (BV n a) k
| k > n     = BV n 0
| otherwise = BV n \$ shiftL a k `mod` 2^n
shiftR (BV n a) k
| k > n     = BV n 0
| otherwise = BV n \$ shiftR a k
rotateL bv       0 = bv
rotateL (BV n a) k
| k == n    = BV n a
| k > n     = rotateL (BV n a) (k `mod` n)
| otherwise = BV n \$ h + l
where s = n - k
l = a `shiftR` s
h = (a `shiftL` k) `mod` 2^n
rotateR bv       0 = bv
rotateR (BV n a) k
| k == n    = BV n a
| k > n     = rotateR (BV n a) (k `mod` n)
| otherwise = BV n \$ h + l
where s = n - k
l = a `shiftR` k
h = (a `shiftL` s) `mod` 2^n

-- | An alias for 'complement'.
--
not_ :: BV -> BV
not_ = complement
{-# INLINE not_ #-}

-- | Negated '.&.'.
--
nand :: BV -> BV -> BV
nand u v = not_ \$ u .&. v
{-# INLINE nand #-}

-- | Negated '.|.'.
--
nor :: BV -> BV -> BV
nor u v = not_ \$ u .|. v
{-# INLINE nor #-}

-- | Negated 'xor'.
--
xnor :: BV -> BV -> BV
xnor u v = not_ \$ u `xor` v
{-# INLINE xnor #-}

-- | Left shift.
--
(<<.) :: BV -> BV -> BV
bv@BV{size=n} <<. (BV _ k)
| k >= fromIntegral n  = BV n 0
| otherwise            = bv `shiftL` (fromIntegral k)
{-# INLINE (<<.) #-}

-- | Logical right shift.
--
(>>.) :: BV -> BV -> BV
bv@BV{size=n} >>. (BV _ k)
| k >= fromIntegral n  = BV n 0
| otherwise            = bv `shiftR` (fromIntegral k)
{-# INLINE (>>.) #-}

-- | Arithmetic right shift
--
ashr :: BV -> BV -> BV
ashr u v | msb u     = not_ ((not_ u) >>. v)
| otherwise = u >>. v

-- | Rotate left.
--
(<<<.) :: BV -> BV -> BV
bv@BV{size=n} <<<. (BV _ k)
| k >= n'   = bv `rotateL` (fromIntegral \$ k `mod` n')
| otherwise = bv `rotateL` (fromIntegral k)
where n' = fromIntegral n
{-# INLINE (<<<.) #-}

-- | Rotate right.
--
(>>>.) :: BV -> BV -> BV
bv@BV{size=n} >>>. (BV _ k)
| k >= n'   = bv `rotateR` (fromIntegral \$ k `mod` n')
| otherwise = bv `rotateR` (fromIntegral k)
where n' = fromIntegral n
{-# INLINE (>>>.) #-}

----------------------------------------------------------------------
--- Conversion

-- | Create a bit-vector from a big-endian list of bits.
--
-- >>> fromBits [False, False, True]
-- [3]1
--
fromBits :: [Bool] -> BV
fromBits bs = BV n \$ snd \$ foldr go (1,0) bs
where n = length bs
go b (!v,!acc) | b         = (v',acc+v)
| otherwise = (v',acc)
where v' = 2*v

-- | Create a big-endian list of bits from a bit-vector.
--
-- >>> toBits [4]11
-- [True, False, True, True]
--
toBits :: BV -> [Bool]
toBits (BV n a) = map (testBit a) [n-1,n-2..0]

----------------------------------------------------------------------
--- Utilities

-- | Greatest natural number representable with /n/ bits.
--
maxNat :: Integral a => Int -> a
maxNat n = 2^n - 1
{-# INLINE maxNat #-}

-- | Minimum width of a bit-vector to represent a given integer number.
--
-- >>> integerWith 4
-- 3
--
-- >>> integerWith (-4)
-- 4
--
integerWidth :: Integer -> Int
integerWidth !n
| n >= 0    = go 1 1
| otherwise = 1 + integerWidth (abs n)
where go !k !k_max | k_max >= n = k
| otherwise  = go (k+1) (2*k_max+1)
{-# INLINE integerWidth #-}
```