{-# LANGUAGE ConstraintKinds     #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE ViewPatterns        #-}
module Data.Array.Accelerate.Data.Bits (
  Bits(..),
  FiniteBits(..),
) where
import Data.Array.Accelerate.Array.Sugar
import Data.Array.Accelerate.Language
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Classes.Eq
import Data.Array.Accelerate.Classes.Ord
import Data.Array.Accelerate.Classes.Num
import Data.Array.Accelerate.Classes.Integral                       ()
import Prelude                                                      ( ($), undefined, otherwise )
import qualified Data.Bits                                          as B
-- Operator fixities: shifts and rotations bind tightest (8), followed by
-- '.&.' (7), 'xor' (6) and '.|.' (5); all of them are left-associative.
infixl 8 `shift`, `rotate`, `shiftL`, `shiftR`, `rotateL`, `rotateR`
infixl 7 .&.
infixl 6 `xor`
infixl 5 .|.
-- | Bitwise operations over embedded scalar expressions. This is the 'Exp'
-- counterpart of the 'B.Bits' class from "Data.Bits": every method consumes
-- and produces 'Exp' terms, and bit indices and shift amounts are given as
-- @'Exp' 'Int'@.
class Eq a => Bits a where
  {-# MINIMAL (.&.), (.|.), xor, complement,
              (shift | (shiftL, shiftR)),
              (rotate | (rotateL, rotateR)),
              isSigned, testBit, bit, popCount #-}

  -- | Bitwise \"and\".
  (.&.) :: Exp a -> Exp a -> Exp a

  -- | Bitwise \"or\".
  (.|.) :: Exp a -> Exp a -> Exp a

  -- | Bitwise \"exclusive or\".
  xor :: Exp a -> Exp a -> Exp a

  -- | Invert every bit of the argument.
  complement :: Exp a -> Exp a

  -- | @'shift' x i@ dispatches on the sign of @i@: a negative amount is
  -- forwarded to 'shiftR' with @-i@, a positive amount to 'shiftL', and a
  -- zero amount returns @x@ unchanged.
  shift :: Exp a -> Exp Int -> Exp a
  shift e n =
    cond (n < 0) (shiftR e (-n))
                 (cond (n > 0) (shiftL e n) e)

  -- | @'rotate' x i@ dispatches on the sign of @i@: a negative amount is
  -- forwarded to 'rotateR' with @-i@, a positive amount to 'rotateL', and a
  -- zero amount returns @x@ unchanged.
  rotate :: Exp a -> Exp Int -> Exp a
  rotate e n =
    cond (n < 0) (rotateR e (-n))
                 (cond (n > 0) (rotateL e n) e)

  -- | The value with no bits set. The default builds it by clearing bit 0 of
  -- @'bit' 0@.
  zeroBits :: Exp a
  zeroBits = bit 0 `clearBit` 0

  -- | @'bit' i@ is the value with only bit @i@ set.
  bit :: Exp Int -> Exp a

  -- | @'setBit' x i@ is @x@ or-ed with @'bit' i@.
  setBit :: Exp a -> Exp Int -> Exp a
  setBit e n = (.|.) e (bit n)

  -- | @'clearBit' x i@ is @x@ and-ed with the complement of @'bit' i@.
  clearBit :: Exp a -> Exp Int -> Exp a
  clearBit e n = e .&. mask
    where
      mask = complement (bit n)

  -- | @'complementBit' x i@ is @x@ xor-ed with @'bit' i@.
  complementBit :: Exp a -> Exp Int -> Exp a
  complementBit e n = xor e (bit n)

  -- | Test whether the bit at the given index is set.
  testBit :: Exp a -> Exp Int -> Exp Bool

  -- | Whether values of the type are signed.
  isSigned :: Exp a -> Exp Bool

  -- | Shift left by the given amount. Defaults to 'shift' with the same
  -- amount.
  shiftL :: Exp a -> Exp Int -> Exp a
  shiftL e n = shift e n

  -- | Variant of 'shiftL'; the default simply calls 'shiftL'.
  --
  -- NOTE(review): presumably, as in "Data.Bits", instances may skip range
  -- checks on the shift amount here -- confirm.
  unsafeShiftL :: Exp a -> Exp Int -> Exp a
  unsafeShiftL e n = shiftL e n

  -- | Shift right by the given amount. Defaults to 'shift' with the negated
  -- amount.
  shiftR :: Exp a -> Exp Int -> Exp a
  shiftR e n = shift e (-n)

  -- | Variant of 'shiftR'; the default simply calls 'shiftR'.
  --
  -- NOTE(review): presumably, as in "Data.Bits", instances may skip range
  -- checks on the shift amount here -- confirm.
  unsafeShiftR :: Exp a -> Exp Int -> Exp a
  unsafeShiftR e n = shiftR e n

  -- | Rotate left by the given amount. Defaults to 'rotate' with the same
  -- amount.
  rotateL :: Exp a -> Exp Int -> Exp a
  rotateL e n = rotate e n

  -- | Rotate right by the given amount. Defaults to 'rotate' with the negated
  -- amount.
  rotateR :: Exp a -> Exp Int -> Exp a
  rotateR e n = rotate e (-n)

  -- | The number of set bits in the argument.
  popCount :: Exp a -> Exp Int
-- | Extension of 'Bits' for types with a fixed number of bits; the 'Exp'
-- counterpart of 'B.FiniteBits' from "Data.Bits".
class Bits b => FiniteBits b where
  -- | The number of bits in the type, as an expression.
  finiteBitSize      :: Exp b -> Exp Int

  -- | The number of leading zero bits of the argument.
  countLeadingZeros  :: Exp b -> Exp Int

  -- | The number of trailing zero bits of the argument.
  countTrailingZeros :: Exp b -> Exp Int
-- | A 'Bool' is treated as a single bit located at index 0.
instance Bits Bool where
  -- The logical connectives double as the bitwise operators.
  p .&. q = p && q
  p .|. q = p || q
  xor p q = p /= q
  complement p = not p
  -- Only index 0 is meaningful: any other index yields 'False'.
  bit n = n == 0
  testBit p n = cond (n == 0) p (constant False)
  -- Shifting by anything other than zero yields 'False'; rotating is the
  -- identity.
  shift p n = cond (n == 0) p (constant False)
  rotate p _ = p
  isSigned = isSignedDefault
  popCount = mkBoolToInt
-- | Bitwise operations for 'Int' expressions.
instance Bits Int where
  -- Logical operations map straight onto the bitwise primitives.
  complement = mkBNot
  xor = mkBXor
  (.|.) = mkBOr
  (.&.) = mkBAnd
  -- Single-bit helpers, signedness and population count.
  bit = bitDefault
  testBit = testBitDefault
  isSigned = isSignedDefault
  popCount = mkPopCount
  -- The plain shifts use the guarded *Default helpers, whereas the unsafe
  -- variants are bound directly to the shift primitives.
  shift = shiftDefault
  shiftL = shiftLDefault
  shiftR = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR
  -- Rotations.
  rotate = rotateDefault
  rotateL = rotateLDefault
  rotateR = rotateRDefault
-- | Bitwise operations for 'Int8' expressions.
instance Bits Int8 where
  -- Logical operations map straight onto the bitwise primitives.
  complement = mkBNot
  xor = mkBXor
  (.|.) = mkBOr
  (.&.) = mkBAnd
  -- Single-bit helpers, signedness and population count.
  bit = bitDefault
  testBit = testBitDefault
  isSigned = isSignedDefault
  popCount = mkPopCount
  -- The plain shifts use the guarded *Default helpers, whereas the unsafe
  -- variants are bound directly to the shift primitives.
  shift = shiftDefault
  shiftL = shiftLDefault
  shiftR = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR
  -- Rotations.
  rotate = rotateDefault
  rotateL = rotateLDefault
  rotateR = rotateRDefault
-- | Bitwise operations for 'Int16' expressions.
instance Bits Int16 where
  -- Logical operations map straight onto the bitwise primitives.
  complement = mkBNot
  xor = mkBXor
  (.|.) = mkBOr
  (.&.) = mkBAnd
  -- Single-bit helpers, signedness and population count.
  bit = bitDefault
  testBit = testBitDefault
  isSigned = isSignedDefault
  popCount = mkPopCount
  -- The plain shifts use the guarded *Default helpers, whereas the unsafe
  -- variants are bound directly to the shift primitives.
  shift = shiftDefault
  shiftL = shiftLDefault
  shiftR = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR
  -- Rotations.
  rotate = rotateDefault
  rotateL = rotateLDefault
  rotateR = rotateRDefault
-- | Bitwise operations for 'Int32' expressions.
instance Bits Int32 where
  -- Logical operations map straight onto the bitwise primitives.
  complement = mkBNot
  xor = mkBXor
  (.|.) = mkBOr
  (.&.) = mkBAnd
  -- Single-bit helpers, signedness and population count.
  bit = bitDefault
  testBit = testBitDefault
  isSigned = isSignedDefault
  popCount = mkPopCount
  -- The plain shifts use the guarded *Default helpers, whereas the unsafe
  -- variants are bound directly to the shift primitives.
  shift = shiftDefault
  shiftL = shiftLDefault
  shiftR = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR
  -- Rotations.
  rotate = rotateDefault
  rotateL = rotateLDefault
  rotateR = rotateRDefault
-- | Bitwise operations for 'Int64' expressions.
instance Bits Int64 where
  -- Logical operations map straight onto the bitwise primitives.
  complement = mkBNot
  xor = mkBXor
  (.|.) = mkBOr
  (.&.) = mkBAnd
  -- Single-bit helpers, signedness and population count.
  bit = bitDefault
  testBit = testBitDefault
  isSigned = isSignedDefault
  popCount = mkPopCount
  -- The plain shifts use the guarded *Default helpers, whereas the unsafe
  -- variants are bound directly to the shift primitives.
  shift = shiftDefault
  shiftL = shiftLDefault
  shiftR = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR
  -- Rotations.
  rotate = rotateDefault
  rotateL = rotateLDefault
  rotateR = rotateRDefault
-- | Bitwise operations for 'Word' expressions.
instance Bits Word where
  -- Logical operations map straight onto the bitwise primitives.
  complement = mkBNot
  xor = mkBXor
  (.|.) = mkBOr
  (.&.) = mkBAnd
  -- Single-bit helpers, signedness and population count.
  bit = bitDefault
  testBit = testBitDefault
  isSigned = isSignedDefault
  popCount = mkPopCount
  -- The plain shifts use the guarded *Default helpers, whereas the unsafe
  -- variants are bound directly to the shift primitives.
  shift = shiftDefault
  shiftL = shiftLDefault
  shiftR = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR
  -- Rotations.
  rotate = rotateDefault
  rotateL = rotateLDefault
  rotateR = rotateRDefault
-- | Bitwise operations for 'Word8' expressions.
instance Bits Word8 where
  -- Logical operations map straight onto the bitwise primitives.
  complement = mkBNot
  xor = mkBXor
  (.|.) = mkBOr
  (.&.) = mkBAnd
  -- Single-bit helpers, signedness and population count.
  bit = bitDefault
  testBit = testBitDefault
  isSigned = isSignedDefault
  popCount = mkPopCount
  -- The plain shifts use the guarded *Default helpers, whereas the unsafe
  -- variants are bound directly to the shift primitives.
  shift = shiftDefault
  shiftL = shiftLDefault
  shiftR = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR
  -- Rotations.
  rotate = rotateDefault
  rotateL = rotateLDefault
  rotateR = rotateRDefault
-- | Bitwise operations for 'Word16' expressions.
instance Bits Word16 where
  -- Logical operations map straight onto the bitwise primitives.
  complement = mkBNot
  xor = mkBXor
  (.|.) = mkBOr
  (.&.) = mkBAnd
  -- Single-bit helpers, signedness and population count.
  bit = bitDefault
  testBit = testBitDefault
  isSigned = isSignedDefault
  popCount = mkPopCount
  -- The plain shifts use the guarded *Default helpers, whereas the unsafe
  -- variants are bound directly to the shift primitives.
  shift = shiftDefault
  shiftL = shiftLDefault
  shiftR = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR
  -- Rotations.
  rotate = rotateDefault
  rotateL = rotateLDefault
  rotateR = rotateRDefault
-- | Bitwise operations for 'Word32' expressions.
instance Bits Word32 where
  -- Logical operations map straight onto the bitwise primitives.
  complement = mkBNot
  xor = mkBXor
  (.|.) = mkBOr
  (.&.) = mkBAnd
  -- Single-bit helpers, signedness and population count.
  bit = bitDefault
  testBit = testBitDefault
  isSigned = isSignedDefault
  popCount = mkPopCount
  -- The plain shifts use the guarded *Default helpers, whereas the unsafe
  -- variants are bound directly to the shift primitives.
  shift = shiftDefault
  shiftL = shiftLDefault
  shiftR = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR
  -- Rotations.
  rotate = rotateDefault
  rotateL = rotateLDefault
  rotateR = rotateRDefault
-- | Bitwise operations for 'Word64' expressions.
instance Bits Word64 where
  -- Logical operations map straight onto the bitwise primitives.
  complement = mkBNot
  xor = mkBXor
  (.|.) = mkBOr
  (.&.) = mkBAnd
  -- Single-bit helpers, signedness and population count.
  bit = bitDefault
  testBit = testBitDefault
  isSigned = isSignedDefault
  popCount = mkPopCount
  -- The plain shifts use the guarded *Default helpers, whereas the unsafe
  -- variants are bound directly to the shift primitives.
  shift = shiftDefault
  shiftL = shiftLDefault
  shiftR = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR
  -- Rotations.
  rotate = rotateDefault
  rotateL = rotateLDefault
  rotateR = rotateRDefault
-- | Bitwise operations for 'CInt' expressions.
instance Bits CInt where
  -- Logical operations map straight onto the bitwise primitives.
  complement = mkBNot
  xor = mkBXor
  (.|.) = mkBOr
  (.&.) = mkBAnd
  -- Single-bit helpers, signedness and population count.
  bit = bitDefault
  testBit = testBitDefault
  isSigned = isSignedDefault
  popCount = mkPopCount
  -- The plain shifts use the guarded *Default helpers, whereas the unsafe
  -- variants are bound directly to the shift primitives.
  shift = shiftDefault
  shiftL = shiftLDefault
  shiftR = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR
  -- Rotations.
  rotate = rotateDefault
  rotateL = rotateLDefault
  rotateR = rotateRDefault
-- | Bitwise operations for 'CUInt' expressions.
instance Bits CUInt where
  -- Logical operations map straight onto the bitwise primitives.
  complement = mkBNot
  xor = mkBXor
  (.|.) = mkBOr
  (.&.) = mkBAnd
  -- Single-bit helpers, signedness and population count.
  bit = bitDefault
  testBit = testBitDefault
  isSigned = isSignedDefault
  popCount = mkPopCount
  -- The plain shifts use the guarded *Default helpers, whereas the unsafe
  -- variants are bound directly to the shift primitives.
  shift = shiftDefault
  shiftL = shiftLDefault
  shiftR = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR
  -- Rotations.
  rotate = rotateDefault
  rotateL = rotateLDefault
  rotateR = rotateRDefault
-- | Bitwise operations for 'CLong' expressions.
instance Bits CLong where
  -- Logical operations map straight onto the bitwise primitives.
  complement = mkBNot
  xor = mkBXor
  (.|.) = mkBOr
  (.&.) = mkBAnd
  -- Single-bit helpers, signedness and population count.
  bit = bitDefault
  testBit = testBitDefault
  isSigned = isSignedDefault
  popCount = mkPopCount
  -- The plain shifts use the guarded *Default helpers, whereas the unsafe
  -- variants are bound directly to the shift primitives.
  shift = shiftDefault
  shiftL = shiftLDefault
  shiftR = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR
  -- Rotations.
  rotate = rotateDefault
  rotateL = rotateLDefault
  rotateR = rotateRDefault
-- | Bitwise operations for 'CULong' expressions.
instance Bits CULong where
  -- Logical operations map straight onto the bitwise primitives.
  complement = mkBNot
  xor = mkBXor
  (.|.) = mkBOr
  (.&.) = mkBAnd
  -- Single-bit helpers, signedness and population count.
  bit = bitDefault
  testBit = testBitDefault
  isSigned = isSignedDefault
  popCount = mkPopCount
  -- The plain shifts use the guarded *Default helpers, whereas the unsafe
  -- variants are bound directly to the shift primitives.
  shift = shiftDefault
  shiftL = shiftLDefault
  shiftR = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR
  -- Rotations.
  rotate = rotateDefault
  rotateL = rotateLDefault
  rotateR = rotateRDefault
-- | Bitwise operations for 'CLLong' expressions.
instance Bits CLLong where
  -- Logical operations map straight onto the bitwise primitives.
  complement = mkBNot
  xor = mkBXor
  (.|.) = mkBOr
  (.&.) = mkBAnd
  -- Single-bit helpers, signedness and population count.
  bit = bitDefault
  testBit = testBitDefault
  isSigned = isSignedDefault
  popCount = mkPopCount
  -- The plain shifts use the guarded *Default helpers, whereas the unsafe
  -- variants are bound directly to the shift primitives.
  shift = shiftDefault
  shiftL = shiftLDefault
  shiftR = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR
  -- Rotations.
  rotate = rotateDefault
  rotateL = rotateLDefault
  rotateR = rotateRDefault
-- | Bitwise operations for 'CULLong' expressions.
instance Bits CULLong where
  -- Logical operations map straight onto the bitwise primitives.
  complement = mkBNot
  xor = mkBXor
  (.|.) = mkBOr
  (.&.) = mkBAnd
  -- Single-bit helpers, signedness and population count.
  bit = bitDefault
  testBit = testBitDefault
  isSigned = isSignedDefault
  popCount = mkPopCount
  -- The plain shifts use the guarded *Default helpers, whereas the unsafe
  -- variants are bound directly to the shift primitives.
  shift = shiftDefault
  shiftL = shiftLDefault
  shiftR = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR
  -- Rotations.
  rotate = rotateDefault
  rotateL = rotateLDefault
  rotateR = rotateRDefault
-- | Bitwise operations for 'CShort' expressions.
instance Bits CShort where
  -- Logical operations map straight onto the bitwise primitives.
  complement = mkBNot
  xor = mkBXor
  (.|.) = mkBOr
  (.&.) = mkBAnd
  -- Single-bit helpers, signedness and population count.
  bit = bitDefault
  testBit = testBitDefault
  isSigned = isSignedDefault
  popCount = mkPopCount
  -- The plain shifts use the guarded *Default helpers, whereas the unsafe
  -- variants are bound directly to the shift primitives.
  shift = shiftDefault
  shiftL = shiftLDefault
  shiftR = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR
  -- Rotations.
  rotate = rotateDefault
  rotateL = rotateLDefault
  rotateR = rotateRDefault
-- | Bitwise operations for 'CUShort' expressions.
instance Bits CUShort where
  -- Logical operations map straight onto the bitwise primitives.
  complement = mkBNot
  xor = mkBXor
  (.|.) = mkBOr
  (.&.) = mkBAnd
  -- Single-bit helpers, signedness and population count.
  bit = bitDefault
  testBit = testBitDefault
  isSigned = isSignedDefault
  popCount = mkPopCount
  -- The plain shifts use the guarded *Default helpers, whereas the unsafe
  -- variants are bound directly to the shift primitives.
  shift = shiftDefault
  shiftL = shiftLDefault
  shiftR = shiftRDefault
  unsafeShiftL = mkBShiftL
  unsafeShiftR = mkBShiftR
  -- Rotations.
  rotate = rotateDefault
  rotateL = rotateLDefault
  rotateR = rotateRDefault
-- | 'Bool' reports a fixed size of 8 bits, while both zero counts are 0 for
-- a true argument and 1 otherwise.
--
-- NOTE(review): the size 8 presumably reflects the storage width of 'Bool'
-- rather than its single logical bit -- confirm.
instance FiniteBits Bool where
  finiteBitSize _ = constant 8
  countLeadingZeros b = cond b 0 1
  countTrailingZeros b = cond b 0 1
-- | The bit size of 'Int' is obtained from 'B.finiteBitSize' and embedded as
-- a constant; the zero counts are bound to the counting primitives.
instance FiniteBits Int where
  countTrailingZeros = mkCountTrailingZeros
  countLeadingZeros = mkCountLeadingZeros
  finiteBitSize _ = constant $ B.finiteBitSize (undefined :: Int)
-- | The bit size of 'Int8' is obtained from 'B.finiteBitSize' and embedded as
-- a constant; the zero counts are bound to the counting primitives.
instance FiniteBits Int8 where
  countTrailingZeros = mkCountTrailingZeros
  countLeadingZeros = mkCountLeadingZeros
  finiteBitSize _ = constant $ B.finiteBitSize (undefined :: Int8)
-- | The bit size of 'Int16' is obtained from 'B.finiteBitSize' and embedded
-- as a constant; the zero counts are bound to the counting primitives.
instance FiniteBits Int16 where
  countTrailingZeros = mkCountTrailingZeros
  countLeadingZeros = mkCountLeadingZeros
  finiteBitSize _ = constant $ B.finiteBitSize (undefined :: Int16)
-- | The bit size of 'Int32' is obtained from 'B.finiteBitSize' and embedded
-- as a constant; the zero counts are bound to the counting primitives.
instance FiniteBits Int32 where
  countTrailingZeros = mkCountTrailingZeros
  countLeadingZeros = mkCountLeadingZeros
  finiteBitSize _ = constant $ B.finiteBitSize (undefined :: Int32)
-- | The bit size of 'Int64' is obtained from 'B.finiteBitSize' and embedded
-- as a constant; the zero counts are bound to the counting primitives.
instance FiniteBits Int64 where
  countTrailingZeros = mkCountTrailingZeros
  countLeadingZeros = mkCountLeadingZeros
  finiteBitSize _ = constant $ B.finiteBitSize (undefined :: Int64)
-- | The bit size of 'Word' is obtained from 'B.finiteBitSize' and embedded as
-- a constant; the zero counts are bound to the counting primitives.
instance FiniteBits Word where
  countTrailingZeros = mkCountTrailingZeros
  countLeadingZeros = mkCountLeadingZeros
  finiteBitSize _ = constant $ B.finiteBitSize (undefined :: Word)
-- | The bit size of 'Word8' is obtained from 'B.finiteBitSize' and embedded
-- as a constant; the zero counts are bound to the counting primitives.
instance FiniteBits Word8 where
  countTrailingZeros = mkCountTrailingZeros
  countLeadingZeros = mkCountLeadingZeros
  finiteBitSize _ = constant $ B.finiteBitSize (undefined :: Word8)
-- | The bit size of 'Word16' is obtained from 'B.finiteBitSize' and embedded
-- as a constant; the zero counts are bound to the counting primitives.
instance FiniteBits Word16 where
  countTrailingZeros = mkCountTrailingZeros
  countLeadingZeros = mkCountLeadingZeros
  finiteBitSize _ = constant $ B.finiteBitSize (undefined :: Word16)
-- | The bit size of 'Word32' is obtained from 'B.finiteBitSize' and embedded
-- as a constant; the zero counts are bound to the counting primitives.
instance FiniteBits Word32 where
  countTrailingZeros = mkCountTrailingZeros
  countLeadingZeros = mkCountLeadingZeros
  finiteBitSize _ = constant $ B.finiteBitSize (undefined :: Word32)
-- | The bit size of 'Word64' is obtained from 'B.finiteBitSize' and embedded
-- as a constant; the zero counts are bound to the counting primitives.
instance FiniteBits Word64 where
  countTrailingZeros = mkCountTrailingZeros
  countLeadingZeros = mkCountLeadingZeros
  finiteBitSize _ = constant $ B.finiteBitSize (undefined :: Word64)
-- | The bit size of 'CInt' is obtained from 'B.finiteBitSize' and embedded as
-- a constant; the zero counts are bound to the counting primitives.
instance FiniteBits CInt where
  countTrailingZeros = mkCountTrailingZeros
  countLeadingZeros = mkCountLeadingZeros
  finiteBitSize _ = constant $ B.finiteBitSize (undefined :: CInt)
-- | The bit size of 'CUInt' is obtained from 'B.finiteBitSize' and embedded
-- as a constant; the zero counts are bound to the counting primitives.
instance FiniteBits CUInt where
  countTrailingZeros = mkCountTrailingZeros
  countLeadingZeros = mkCountLeadingZeros
  finiteBitSize _ = constant $ B.finiteBitSize (undefined :: CUInt)
-- | The bit size of 'CLong' is obtained from 'B.finiteBitSize' and embedded
-- as a constant; the zero counts are bound to the counting primitives.
instance FiniteBits CLong where
  countTrailingZeros = mkCountTrailingZeros
  countLeadingZeros = mkCountLeadingZeros
  finiteBitSize _ = constant $ B.finiteBitSize (undefined :: CLong)
-- | The bit size of 'CULong' is obtained from 'B.finiteBitSize' and embedded
-- as a constant; the zero counts are bound to the counting primitives.
instance FiniteBits CULong where
  countTrailingZeros = mkCountTrailingZeros
  countLeadingZeros = mkCountLeadingZeros
  finiteBitSize _ = constant $ B.finiteBitSize (undefined :: CULong)
-- | The bit size of 'CLLong' is obtained from 'B.finiteBitSize' and embedded
-- as a constant; the zero counts are bound to the counting primitives.
instance FiniteBits CLLong where
  countTrailingZeros = mkCountTrailingZeros
  countLeadingZeros = mkCountLeadingZeros
  finiteBitSize _ = constant $ B.finiteBitSize (undefined :: CLLong)
-- | The bit size of 'CULLong' is obtained from 'B.finiteBitSize' and embedded
-- as a constant; the zero counts are bound to the counting primitives.
instance FiniteBits CULLong where
  countTrailingZeros = mkCountTrailingZeros
  countLeadingZeros = mkCountLeadingZeros
  finiteBitSize _ = constant $ B.finiteBitSize (undefined :: CULLong)
-- | The bit size of 'CShort' is obtained from 'B.finiteBitSize' and embedded
-- as a constant; the zero counts are bound to the counting primitives.
instance FiniteBits CShort where
  countTrailingZeros = mkCountTrailingZeros
  countLeadingZeros = mkCountLeadingZeros
  finiteBitSize _ = constant $ B.finiteBitSize (undefined :: CShort)
-- | The bit size of 'CUShort' is obtained from 'B.finiteBitSize' and embedded
-- as a constant; the zero counts are bound to the counting primitives.
instance FiniteBits CUShort where
  countTrailingZeros = mkCountTrailingZeros
  countLeadingZeros = mkCountLeadingZeros
  finiteBitSize _ = constant $ B.finiteBitSize (undefined :: CUShort)
-- | Default 'bit' for integral types: the constant 1 shifted left (via
-- 'shiftL') by the requested bit index.
bitDefault :: (IsIntegral t, Bits t) => Exp Int -> Exp t
bitDefault n = shiftL (constant 1) n
-- | Default 'testBit' for integral types: mask the value with @'bit' n@ and
-- check that the result differs from zero.
testBitDefault :: (IsIntegral t, Bits t) => Exp t -> Exp Int -> Exp Bool
testBitDefault e n = masked /= constant 0
  where
    masked = e .&. bit n
-- | Default 'shift' for integral types: a non-negative amount shifts left
-- with 'shiftLDefault', a negative amount shifts right by its negation with
-- 'shiftRDefault'.
shiftDefault :: (FiniteBits t, IsIntegral t, B.Bits t) => Exp t -> Exp Int -> Exp t
shiftDefault e n = cond (n >= 0) leftwards rightwards
  where
    leftwards  = shiftLDefault e n
    rightwards = shiftRDefault e (-n)
-- | Default 'shiftL' for integral types: yields zero once the shift amount
-- reaches the bit size of the value, and otherwise defers to the primitive
-- left shift.
shiftLDefault :: (FiniteBits t, IsIntegral t) => Exp t -> Exp Int -> Exp t
shiftLDefault e n = cond overshoot (constant 0) (mkBShiftL e n)
  where
    overshoot = n >= finiteBitSize e
-- | Default 'shiftR' for integral types. The implementation is chosen from
-- the signedness reported by 'B.isSigned' for @t@: 'shiftRADefault' for
-- signed types and 'shiftRLDefault' otherwise.
shiftRDefault :: forall t. (B.Bits t, FiniteBits t, IsIntegral t) => Exp t -> Exp Int -> Exp t
shiftRDefault =
  let signed = B.isSigned (undefined :: t)
  in  if signed then shiftRADefault else shiftRLDefault
-- | Right shift used for signed types. Once the shift amount reaches the bit
-- size of the value the result is @-1@ for a negative input and @0@
-- otherwise; smaller amounts defer to the primitive right shift.
shiftRADefault :: (FiniteBits t, IsIntegral t) => Exp t -> Exp Int -> Exp t
shiftRADefault e n = cond (n >= finiteBitSize e) saturated (mkBShiftR e n)
  where
    negative  = mkLt e (constant 0)
    saturated = cond negative (constant (-1)) (constant 0)
-- | Right shift used for unsigned types: yields zero once the shift amount
-- reaches the bit size of the value, and otherwise defers to the primitive
-- right shift.
shiftRLDefault :: (FiniteBits t, IsIntegral t) => Exp t -> Exp Int -> Exp t
shiftRLDefault e n = cond overshoot (constant 0) (mkBShiftR e n)
  where
    overshoot = n >= finiteBitSize e
-- | Default 'rotate' for integral types. Inspects the 'IntegralType' witness
-- of @t@ and hands over to @rotateDefault'@ together with a proxy naming the
-- unsigned type to carry out the rotation in; each signed type is paired
-- with the unsigned type of the same family.
rotateDefault :: forall t. (FiniteBits t, IsIntegral t) => Exp t -> Exp Int -> Exp t
rotateDefault =
  case (integralType :: IntegralType t) of
    -- Haskell fixed- and machine-width integers
    TypeInt{}     -> rotateDefault' (undefined :: Word)
    TypeWord{}    -> rotateDefault' (undefined :: Word)
    TypeInt8{}    -> rotateDefault' (undefined :: Word8)
    TypeWord8{}   -> rotateDefault' (undefined :: Word8)
    TypeInt16{}   -> rotateDefault' (undefined :: Word16)
    TypeWord16{}  -> rotateDefault' (undefined :: Word16)
    TypeInt32{}   -> rotateDefault' (undefined :: Word32)
    TypeWord32{}  -> rotateDefault' (undefined :: Word32)
    TypeInt64{}   -> rotateDefault' (undefined :: Word64)
    TypeWord64{}  -> rotateDefault' (undefined :: Word64)
    -- C integer types
    TypeCShort{}  -> rotateDefault' (undefined :: CUShort)
    TypeCUShort{} -> rotateDefault' (undefined :: CUShort)
    TypeCInt{}    -> rotateDefault' (undefined :: CUInt)
    TypeCUInt{}   -> rotateDefault' (undefined :: CUInt)
    TypeCLong{}   -> rotateDefault' (undefined :: CULong)
    TypeCULong{}  -> rotateDefault' (undefined :: CULong)
    TypeCLLong{}  -> rotateDefault' (undefined :: CULLong)
    TypeCULLong{} -> rotateDefault' (undefined :: CULLong)
-- | Rotation worker. The first argument is only a type proxy selecting the
-- word type @w@. The value is bit-cast to @w@, shifted left by the masked
-- amount and right by the bit size minus that amount, the two halves are
-- or-ed together, and the result is bit-cast back to @i@. The rotation
-- amount is masked with the bit size minus one; a masked amount of zero
-- returns the input unchanged.
--
-- NOTE(review): the masking presumably relies on the bit size being a power
-- of two, and the detour through @w@ presumably avoids sign extension in the
-- right shift -- confirm.
rotateDefault'
    :: forall i w.
       ( Elt w
       , FiniteBits i
       , IsIntegral i
       , IsIntegral w
       , IsIntegral (EltRepr i)
       , IsIntegral (EltRepr w)
       , BitSizeEq (EltRepr i) (EltRepr w)
       , BitSizeEq (EltRepr w) (EltRepr i)
       )
    => w
    -> Exp i
    -> Exp Int
    -> Exp i
rotateDefault' _ e n = cond (amount == 0) e (fromWord rotated)
  where
    fromWord = mkBitcast :: Exp w -> Exp i
    toWord   = mkBitcast :: Exp i -> Exp w

    width    = finiteBitSize e
    amount   = n `mkBAnd` (width - 1)

    bits     = toWord e
    high     = mkBShiftL bits amount
    low      = mkBShiftR bits (width - amount)
    rotated  = mkBOr high low
-- | Default 'rotateL' for integral types: a zero amount returns the input
-- unchanged, anything else defers to the primitive left rotation.
rotateLDefault :: (Elt t, IsIntegral t) => Exp t -> Exp Int -> Exp t
rotateLDefault e n = cond (n == 0) e (mkBRotateL e n)
-- | Default 'rotateR' for integral types: a zero amount returns the input
-- unchanged, anything else defers to the primitive right rotation.
rotateRDefault :: (Elt t, IsIntegral t) => Exp t -> Exp Int -> Exp t
rotateRDefault e n = cond (n == 0) e (mkBRotateR e n)
-- | Default 'isSigned': the argument is ignored; the answer is the result of
-- 'B.isSigned' for the type @b@, embedded as a constant.
isSignedDefault :: forall b. B.Bits b => Exp b -> Exp Bool
isSignedDefault _ = constant $ B.isSigned (undefined :: b)