{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}

{-|
Module      : Data.BitVector.Sized
Copyright   : (c) Benjamin Selfridge, 2018
                  Galois Inc.
License     : BSD3
Maintainer  : benselfridge@galois.com
Stability   : experimental
Portability : portable

This module defines a width-parameterized 'BitVector' type and various associated
operations that assume a two's complement representation.
-}

module Data.BitVector.Sized
  ( -- * BitVector type
    BitVector(..)
  , bitVector
    -- * Bitwise operations (width-preserving)
    -- | These are alternative versions of some of the 'Bits' functions where we do
    -- not need to know the width at compile time. They are all width-preserving.
  , bvAnd, bvOr, bvXor
  , bvComplement
  , bvShift, bvShiftL, bvShiftRA, bvShiftRL,  bvRotate
  , bvWidth
  , bvTestBit
  , bvPopCount
  , bvTruncBits
    -- * Arithmetic operations (width-preserving)
  , bvAdd, bvMul
  , bvAbs, bvNegate
  , bvSignum
  , bvLTS, bvLTU
    -- * Variable-width operations
    -- | These are functions that involve bit vectors of different lengths.
  , bvConcat, (<:>)
  , bvExtract, bvExtractWithRepr
  , bvZext, bvZextWithRepr
  , bvSext, bvSextWithRepr
  , bvMulFU, bvMulFS
    -- * Conversions to Integer
  , bvIntegerU
  , bvIntegerS
  ) where

import Data.Bits
import Data.Parameterized.Classes
import Data.Parameterized.NatRepr
import GHC.TypeLits
import System.Random
import Test.QuickCheck (Arbitrary(..), choose)
import Text.Printf
import Unsafe.Coerce (unsafeCoerce)
----------------------------------------
-- BitVector data type definitions

-- | A bit vector whose width @w@ is tracked at the type level. The 'NatRepr'
-- field carries that same width at runtime, and the 'Integer' field holds the
-- bits as an unsigned value. The operations in this module truncate their
-- results so that the 'Integer' stays within @[0, 2^w - 1]@; note that the 'BV'
-- constructor itself is exported and performs no such check.
data BitVector (w :: Nat) = BV (NatRepr w) Integer

-- | Construct a bit vector in a context where the width is inferrable from the
-- type context. The 'Integer' input (an unbounded type, hence with an
-- infinite-width two's complement representation), whether positive or
-- negative, is silently truncated to fit into the number of bits demanded by
-- the return type.
--
-- >>> bitVector 0xA :: BitVector 4
-- 0xa<4>
-- >>> 0xA :: BitVector 4
-- 0xa<4>
-- >>> 0xA :: BitVector 3
-- 0x2<3>
-- >>> (-1) :: BitVector 8
-- 0xff<8>
-- >>> (-1) :: BitVector 32
-- 0xffffffff<32>
bitVector :: KnownNat w => Integer -> BitVector w
bitVector x = BV wRepr (truncBits width x)
  where wRepr = knownNat
        width = natValue wRepr

----------------------------------------
-- BitVector -> Integer functions

-- | Unsigned interpretation of a bit vector as a nonnegative 'Integer': the
-- stored bits are returned unchanged.
bvIntegerU :: BitVector w -> Integer
bvIntegerU bv = case bv of
  BV _ bits -> bits

-- | Signed (two's complement) interpretation of a bit vector as an 'Integer'.
-- When the most significant bit is set, the result is the unsigned value
-- minus @2^w@; otherwise it is the unsigned value itself.
bvIntegerS :: BitVector w -> Integer
bvIntegerS bv
  | bvTestBit bv (n - 1) = u - (1 `shiftL` n)
  | otherwise            = u
  where n = bvWidth bv
        u = bvIntegerU bv

----------------------------------------
-- BitVector w operations (fixed width)

-- | Bitwise AND of two vectors of the same width.
bvAnd :: BitVector w -> BitVector w -> BitVector w
bvAnd (BV repr lhs) (BV _ rhs) = BV repr ((.&.) lhs rhs)

-- | Bitwise OR of two vectors of the same width.
bvOr :: BitVector w -> BitVector w -> BitVector w
bvOr (BV repr lhs) (BV _ rhs) = BV repr ((.|.) lhs rhs)

-- | Bitwise exclusive OR of two vectors of the same width.
bvXor :: BitVector w -> BitVector w -> BitVector w
bvXor (BV repr lhs) (BV _ rhs) = BV repr (xor lhs rhs)

-- | Bitwise complement: flip each of the @w@ bits.
bvComplement :: BitVector w -> BitVector w
bvComplement (BV repr bits) =
  -- 'complement' on an Integer yields a negative number (infinitely many
  -- leading ones), so re-truncate to the vector's width.
  let flipped = complement bits
  in BV repr (truncBits (natValue repr) flipped)

-- | Bitwise shift: a positive amount shifts left and a negative amount shifts
-- right. The shift is applied to the signed interpretation of the vector, so
-- right shifts are arithmetic (the sign bit is replicated). The result is
-- truncated to the vector's width.
bvShift :: BitVector w -> Int -> BitVector w
bvShift bv@(BV repr _) amount =
  let signed  = bvIntegerS bv
      shifted = signed `shift` amount
  in BV repr (truncBits (natValue repr) shifted)

-- | Clamp a shift amount to be nonnegative: negative values become 0.
toPos :: Int -> Int
toPos = max 0

-- | Left shift, filling the vacated low bits with zeros. A negative shift
-- amount is treated as 0.
bvShiftL :: BitVector w -> Int -> BitVector w
bvShiftL bv = bvShift bv . toPos

-- | Right arithmetic shift: the sign bit is replicated into the vacated high
-- bits. A negative shift amount is treated as 0.
bvShiftRA :: BitVector w -> Int -> BitVector w
bvShiftRA bv = bvShift bv . negate . toPos

-- | Right logical shift: the vacated high bits are filled with zeros. A
-- negative shift amount is treated as 0.
--
-- >>> bvShiftRL (0x80 :: BitVector 8) 4
-- 0x08<8>
bvShiftRL :: BitVector w -> Int -> BitVector w
bvShiftRL (BV wRepr x) shf = BV wRepr (truncBits width (x `shiftR` toPos shf))
  -- The stored value is the unsigned interpretation, so shifting it right
  -- never brings in sign bits. The truncation keeps the result within the
  -- width even if the constructor was given an out-of-range value.
  where width = natValue wRepr

-- | Bitwise rotate. A positive amount rotates left (towards the high bits) and
-- a negative amount rotates right. The amount is taken modulo the width.
-- Rotating a zero-width vector is the identity.
--
-- >>> bvRotate (0x81 :: BitVector 8) 1
-- 0x03<8>
-- >>> bvRotate (0x81 :: BitVector 8) (-1)
-- 0xc0<8>
bvRotate :: BitVector w -> Int -> BitVector w
bvRotate bv@(BV wRepr x) rot'
  | width == 0 = bv   -- avoids `mod` 0
  | otherwise  = BV wRepr (truncBits width (hiPart .|. loPart))
  where width  = fromIntegral (natValue wRepr) :: Int
        rot    = rot' `mod` width
        -- Bits moved up by 'rot'; anything pushed past the width is dropped by
        -- the final truncation.
        hiPart = x `shiftL` rot
        -- Bits that wrap around to the bottom. Shift the unsigned value so no
        -- sign bits are shifted in (an arithmetic shift would OR in ones for
        -- vectors whose top bit is set).
        loPart = x `shiftR` (width - rot)

-- | Get the width of a 'BitVector', read from its runtime 'NatRepr' and
-- converted to an 'Int'.
bvWidth :: BitVector w -> Int
bvWidth bv = case bv of
  BV repr _ -> fromIntegral (natValue repr)

-- | Test whether bit @i@ (0 = least significant) of the stored unsigned value
-- is set.
bvTestBit :: BitVector w -> Int -> Bool
bvTestBit bv i = testBit (bvIntegerU bv) i

-- | Count the 1 bits in the stored unsigned value of a 'BitVector'.
bvPopCount :: BitVector w -> Int
bvPopCount = popCount . bvIntegerU

-- | Clear every bit at position @n@ or above, where @n@ is given at runtime.
-- The type-level width is unchanged.
bvTruncBits :: BitVector w -> Int -> BitVector w
bvTruncBits bv n = case bv of
  BV repr bits -> BV repr (truncBits n bits)

----------------------------------------
-- BitVector w arithmetic operations (fixed width)

-- | Modular addition: the sum wraps around modulo @2^w@.
bvAdd :: BitVector w -> BitVector w -> BitVector w
bvAdd (BV repr a) (BV _ b) =
  let total = a + b
  in BV repr (truncBits (natValue repr) total)

-- | Modular multiplication: the product wraps around modulo @2^w@.
bvMul :: BitVector w -> BitVector w -> BitVector w
bvMul (BV repr a) (BV _ b) =
  let prod = a * b
  in BV repr (truncBits (natValue repr) prod)

-- | Absolute value of the signed (two's complement) interpretation, truncated
-- back to @w@ bits. As in two's complement hardware, the most negative value
-- @-2^(w-1)@ maps to itself.
bvAbs :: BitVector w -> BitVector w
bvAbs bv = case bv of
  BV repr _ -> BV repr (truncBits (natValue repr) (abs (bvIntegerS bv)))

-- | Two's complement negation, modulo @2^w@ (so negating 0 gives 0).
bvNegate :: BitVector w -> BitVector w
bvNegate (BV repr bits) = BV repr (truncBits w (negate bits))
  where w = toInteger (natValue repr)

-- | Get the sign (most significant) bit as a 'BitVector': 1 if it is set,
-- 0 otherwise. This differs from the usual 'signum' convention, which would
-- give -1 (all ones) for negative values.
bvSignum :: BitVector w -> BitVector w
bvSignum bv@(BV repr _) = bvAnd shifted one
  where n       = fromIntegral (natValue repr)
        -- Move the sign bit down to bit 0; the mask then discards the copies
        -- the arithmetic shift replicated above it.
        shifted = bvShift bv (1 - n)
        one     = BV repr 0x1

-- | Signed less-than: compares the two's complement interpretations.
bvLTS :: BitVector w -> BitVector w -> Bool
bvLTS bv1 bv2 = compare (bvIntegerS bv1) (bvIntegerS bv2) == LT

-- | Unsigned less-than: compares the stored unsigned values.
bvLTU :: BitVector w -> BitVector w -> Bool
bvLTU bv1 bv2 = compare (bvIntegerU bv1) (bvIntegerU bv2) == LT

----------------------------------------
-- Width-changing operations

-- | Concatenate two bit vectors.
--
-- >>> (0xAA :: BitVector 8) `bvConcat` (0xBCDEF0 :: BitVector 24)
-- 0xaabcdef0<32>
-- >>> :type it
-- it :: BitVector 32
--
-- The first argument occupies the high-order bits of the result and the
-- second argument the low-order bits.
bvConcat :: BitVector v -> BitVector w -> BitVector (v+w)
bvConcat (BV hiWRepr hi) (BV loWRepr lo) =
  -- Move the high part clear of the low part's width; assuming 'lo' fits in
  -- that width, the two parts do not overlap and OR-ing them concatenates.
  BV (hiWRepr `addNat` loWRepr) ((hi `shiftL` loWidth) .|. lo)
  where loWidth = fromIntegral (natValue loWRepr)

-- | Infix synonym for 'bvConcat'; the left operand supplies the high-order
-- bits.
(<:>) :: BitVector v -> BitVector w -> BitVector (v+w)
hi <:> lo = bvConcat hi lo

infixl 6 <:>

-- | Slice out a smaller bit vector from a larger one. The least significant bit
-- of the slice is given explicitly as an argument of type 'Int', and the length
-- of the slice is inferred from a type-level context.
--
-- >>> bvExtract 12 (0xAABCDEF0 :: BitVector 32) :: BitVector 8
-- 0xcd<8>
-- >>> bvExtract 28 (0xF0000000 :: BitVector 32) :: BitVector 8
-- 0x0f<8>
--
-- Note that 'bvExtract' does not do any bounds checking whatsoever; if you try
-- to extract bits that aren't present in the input, you will get 0's.
bvExtract :: forall w w' . (KnownNat w')
          => Int
          -> BitVector w
          -> BitVector w'
bvExtract pos (BV wRepr x) = bitVector xShf
  -- Shift the unsigned value (a logical shift) so that positions above the
  -- input's width read as 0 even when its top bit is set. For a negative
  -- 'pos' this is a left shift, re-truncated to the input's width.
  where xShf = truncBits (natValue wRepr) (x `shift` negate pos)

-- | Unconstrained variant of 'bvExtract' with an explicit 'NatRepr' argument
-- giving the width of the slice. Bits that aren't present in the input read
-- as 0.
bvExtractWithRepr :: NatRepr w'
                  -> Int
                  -> BitVector w
                  -> BitVector w'
bvExtractWithRepr repr pos (BV wRepr x) = BV repr (truncBits width xShf)
  -- Logical shift of the unsigned value: no sign bits are shifted in.
  where xShf  = truncBits (natValue wRepr) (x `shift` negate pos)
        width = natValue repr

-- | Zero-extend a vector to one of greater length (the new high bits are 0).
-- If the output type is narrower than the input, this truncates instead.
bvZext :: forall w w' . KnownNat w'
       => BitVector w
       -> BitVector w'
bvZext bv = case bv of
  BV _ bits -> bitVector bits

-- | Unconstrained variant of 'bvZext' with an explicit 'NatRepr' argument
-- giving the output width.
bvZextWithRepr :: NatRepr w'
               -> BitVector w
               -> BitVector w'
bvZextWithRepr repr bv = case bv of
  BV _ bits -> BV repr (truncBits (natValue repr) bits)

-- | Sign-extend a vector to one of greater length (the new high bits copy the
-- input's sign bit). If the output type is narrower than the input, this
-- truncates instead.
bvSext :: forall w w' . KnownNat w'
       => BitVector w
       -> BitVector w'
bvSext = bitVector . bvIntegerS

-- | Unconstrained variant of 'bvSext' with an explicit 'NatRepr' argument
-- giving the output width.
bvSextWithRepr :: NatRepr w'
               -> BitVector w
               -> BitVector w'
bvSextWithRepr repr = BV repr . truncBits (natValue repr) . bvIntegerS

-- | Fully multiply two bit vectors as unsigned integers. The result width is
-- the sum of the input widths; for in-range inputs that always holds the exact
-- product, so no truncation is performed.
bvMulFU :: BitVector w -> BitVector w' -> BitVector (w+w')
bvMulFU (BV r1 a) (BV r2 b) = BV (addNat r1 r2) (a * b)

-- | Fully multiply two bit vectors as signed (two's complement) integers. The
-- result width is the sum of the input widths. The signed product is stored in
-- two's complement form by truncating it to that width.
bvMulFS :: BitVector w -> BitVector w' -> BitVector (w+w')
bvMulFS bv1@(BV r1 _) bv2@(BV r2 _) =
  let r    = addNat r1 r2
      prod = bvIntegerS bv1 * bvIntegerS bv2
  in BV r (truncBits (natValue r) prod)

----------------------------------------
-- Class instances

-- | Renders as zero-padded hex followed by the width, e.g. @0x0a<8>@.
instance Show (BitVector w) where
  show bv = case bv of
    BV repr bits -> prettyHex (natValue repr) bits

-- | Uses the default methods, which defer to the 'Show' instance above.
instance ShowF BitVector

-- Equality and ordering compare the stored unsigned values only. Both operands
-- share the type index @w@, so their widths are not compared.
instance Eq (BitVector w) where
  a == b = bvIntegerU a == bvIntegerU b

instance EqF BitVector where
  eqF a b = bvIntegerU a == bvIntegerU b

instance Ord (BitVector w) where
  compare a b = compare (bvIntegerU a) (bvIntegerU b)

-- | Two bit vectors are considered equal (with a proof that their widths
-- coincide) when their 'NatRepr' widths agree and their stored values are
-- equal. The width proof comes from 'NatRepr's own 'TestEquality' instance,
-- so no 'unsafeCoerce' is needed.
instance TestEquality BitVector where
  testEquality (BV wRepr x) (BV wRepr' y) =
    case testEquality wRepr wRepr' of
      Just Refl | x == y -> Just Refl
      _                  -> Nothing

-- | 'Bits' instance treating a 'BitVector' as an unsigned @w@-bit word.
instance KnownNat w => Bits (BitVector w) where
  (.&.)        = bvAnd
  (.|.)        = bvOr
  xor          = bvXor
  complement   = bvComplement
  -- 'isSigned' is 'False', and the 'Bits' contract only sign-extends right
  -- shifts for signed types. So shift the unsigned value (zero-filling right
  -- shifts) rather than delegating to 'bvShift', which shifts arithmetically.
  -- 'shiftL'/'shiftR' use the class defaults in terms of 'shift'.
  shift (BV wRepr x) n = BV wRepr (truncBits (natValue wRepr) (x `shift` n))
  rotate       = bvRotate
  bitSize      = bvWidth
  bitSizeMaybe = Just . bvWidth
  isSigned     = const False
  testBit      = bvTestBit
  bit          = bitVector . bit
  popCount     = bvPopCount

instance KnownNat w => FiniteBits (BitVector w) where
  finiteBitSize = bvWidth

-- | Modular (wrap-around) arithmetic on @w@-bit vectors. Subtraction uses the
-- class default, @x + negate y@.
instance KnownNat w => Num (BitVector w) where
  fromInteger = bitVector
  x + y       = bvAdd x y
  x * y       = bvMul x y
  negate      = bvNegate
  abs         = bvAbs
  -- NOTE: returns the sign bit (0 or 1) rather than -1/0/1; see 'bvSignum'.
  signum      = bvSignum

-- | Enumeration over the unsigned values of the vector. 'toEnum' truncates
-- its argument to @w@ bits. 'fromEnum' goes through 'Int', so it wraps for
-- values that do not fit in an 'Int'.
--
-- The range methods work on the unsigned 'Integer' values directly, so they
-- are correct for vectors of any width (e.g. @[minBound .. maxBound]@ at width
-- 64). As the Haskell Report requires for types that are also 'Bounded',
-- 'enumFrom' and 'enumFromThen' stop at the appropriate bound instead of
-- wrapping around.
instance KnownNat w => Enum (BitVector w) where
  toEnum   = bitVector . fromIntegral
  fromEnum = fromIntegral . bvIntegerU

  enumFrom x = enumFromTo x maxBound
  enumFromThen x y = enumFromThenTo x y bound
    where bound | y >= x    = maxBound
                | otherwise = minBound
  enumFromTo x y = map bitVector [bvIntegerU x .. bvIntegerU y]
  enumFromThenTo x y z =
    map bitVector [bvIntegerU x, bvIntegerU y .. bvIntegerU z]

-- | Bounds of the unsigned interpretation: all bits clear to all bits set.
instance KnownNat w => Bounded (BitVector w) where
  minBound = bitVector 0
  -- Complementing the all-zeros vector (and truncating to w bits) sets every
  -- bit, the same as truncating -1.
  maxBound = bvComplement minBound

-- | Generates vectors anywhere in the unsigned range, using the 'Random'
-- instance's 'randomR' via 'choose'.
instance KnownNat w => Arbitrary (BitVector w) where
  arbitrary = choose fullRange
    where fullRange = (minBound, maxBound)

-- | Random generation over the unsigned interpretation of the vector.
instance KnownNat w => Random (BitVector w) where
  -- Draw an unsigned value in the requested range, then rebuild the vector.
  randomR (bvLo, bvHi) gen =
    let (x, gen') = randomR (bvIntegerU bvLo, bvIntegerU bvHi) gen
    in (bitVector x, gen')
  -- 'random' for 'Integer' only samples the 'Int' range. Truncating that would
  -- leave the upper bits of vectors wider than 'Int' as copies of the 'Int's
  -- sign bit, so sample the vector's full unsigned range instead.
  random = randomR (minBound, maxBound)

----------------------------------------
-- UTILITIES

----------------------------------------
-- Pretty Printing

-- | Render a value in hexadecimal with a leading "0x", zero-padded to the
-- number of hex digits a value of the given bit width needs, followed by the
-- width in angle brackets (e.g. @0x0a<8>@).
prettyHex :: (Integral a, PrintfArg a, Show a) => a -> Integer -> String
prettyHex width val =
  let digits = (width + 3) `div` 4
      fmt    = "0x%." ++ show digits ++ "x<%d>"
  in printf fmt val width

----------------------------------------
-- Bits

-- | Mask with the lowest @numBits@ bits set and all others clear. It is built
-- by shifting an all-ones word left and complementing the result.
lowMask :: (Integral a, Bits b) => a -> b
lowMask numBits =
  let allOnes = complement zeroBits
      shifted = allOnes `shiftL` fromIntegral numBits
  in complement shifted

-- | Keep only the lowest @width@ bits of a value, clearing all higher bits.
truncBits :: (Integral a, Bits b) => a -> b -> b
truncBits width = (.&. lowMask width)