{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE Unsafe #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise       #-}
{-# OPTIONS_HADDOCK show-extensions not-home #-}
module Clash.Sized.Internal.BitVector
  ( 
    Bit (..)
    
  , high
  , low
    
    
  , eq##
  , neq##
    
  , lt##
  , ge##
  , gt##
  , le##
    
  , fromInteger##
    
  , and##
  , or##
  , xor##
  , complement##
    
  , pack#
  , unpack#
    
  , BitVector (..)
    
  , size#
  , maxIndex#
    
  , bLit
  , undefined#
    
  , (++#)
    
  , reduceAnd#
  , reduceOr#
  , reduceXor#
    
  , index#
  , replaceBit#
  , setSlice#
  , slice#
  , split#
  , msb#
  , lsb#
    
    
  , eq#
  , neq#
  , isLike
    
  , lt#
  , ge#
  , gt#
  , le#
    
  , enumFrom#
  , enumFromThen#
  , enumFromTo#
  , enumFromThenTo#
    
  , minBound#
  , maxBound#
    
  , (+#)
  , (-#)
  , (*#)
  , negate#
  , fromInteger#
    
  , plus#
  , minus#
  , times#
    
  , quot#
  , rem#
  , toInteger#
    
  , and#
  , or#
  , xor#
  , complement#
  , shiftL#
  , shiftR#
  , rotateL#
  , rotateR#
  , popCountBV
    
  , countLeadingZerosBV
  , countTrailingZerosBV
    
  , truncateB#
    
  , shrinkSizedUnsigned
  
  , undefError
  , checkUnpackUndef
  , bitPattern
  )
where
import Control.DeepSeq            (NFData (..))
import Control.Lens               (Index, Ixed (..), IxValue)
import Data.Bits                  (Bits (..), FiniteBits (..))
import Data.Data                  (Data)
import Data.Default.Class         (Default (..))
import Data.Either                (isLeft)
import Data.Proxy                 (Proxy (..))
import Data.Typeable              (Typeable, typeOf)
import GHC.Generics               (Generic)
import Data.Maybe                 (fromMaybe)
import GHC.Exts
  (Word#, Word (W#), eqWord#, int2Word#, uncheckedShiftRL#)
import qualified GHC.Exts
import GHC.Integer.GMP.Internals  (Integer (..), bigNatToWord, shiftRBigNat)
import GHC.Natural
  (Natural (..), naturalFromInteger, wordToNatural)
#if MIN_VERSION_base(4,12,0)
import GHC.Natural                (naturalToInteger)
#endif
import GHC.Prim                   (dataToTag#)
import GHC.Stack                  (HasCallStack, withFrozenCallStack)
import GHC.TypeLits               (KnownNat, Nat, type (+), type (-), natVal)
import GHC.TypeLits.Extra         (Max)
import Language.Haskell.TH        (Q, TExp, TypeQ, appT, conT, litT, numTyLit, sigE, Lit(..), litE, Pat, litP)
import Language.Haskell.TH.Syntax (Lift(..))
import Test.QuickCheck.Arbitrary  (Arbitrary (..), CoArbitrary (..),
                                   arbitraryBoundedIntegral,
                                   coarbitraryIntegral, shrinkIntegral)
import Clash.Class.Num            (ExtendingNum (..), SaturatingNum (..),
                                   SaturationMode (..))
import Clash.Class.Resize         (Resize (..))
import Clash.Promoted.Nat
  (SNat (..), SNatLE (..), compareSNat, snatToInteger, snatToNum, natToNum)
import Clash.XException
  (ShowX (..), NFDataX (..), errorX, isX, showsPrecXWith, rwhnfX)
import Clash.Sized.Internal.Mod
import {-# SOURCE #-} qualified Clash.Sized.Vector         as V
import {-# SOURCE #-} qualified Clash.Sized.Internal.Index as I
import                qualified Data.List                  as L
-- | A vector of @n@ bits, any of which may be undefined.
data BitVector (n :: Nat) =
    -- unsafeMask: a set bit marks the corresponding position as undefined.
    -- unsafeToNatural: the bit values; a value bit carries no meaning where
    -- the mask bit is set (for example 'show' prints a dot there).
    BV { unsafeMask      :: !Natural
       , unsafeToNatural :: !Natural
       }
  deriving (Data, Generic)
-- | A single bit that may be undefined. It uses the same mask/value encoding
-- as 'BitVector', stored as two unpacked 'Word's; only bit 0 of each field
-- is used (see e.g. 'fromInteger##', which keeps each field modulo 2).
data Bit =
  -- unsafeMask#: nonzero when the bit is undefined.
  -- unsafeToInteger#: the bit value, meaningful only when the mask is 0.
  Bit { unsafeMask#      :: {-# unpack #-} !Word
      , unsafeToInteger# :: {-# unpack #-} !Word
      }
  deriving (Data, Generic)
-- | The defined logic-one 'Bit'.
high :: Bit
high = Bit { unsafeMask# = 0, unsafeToInteger# = 1 }
{-# NOINLINE high #-}

-- | The defined logic-zero 'Bit'.
low :: Bit
low = Bit { unsafeMask# = 0, unsafeToInteger# = 0 }
{-# NOINLINE low #-}
-- | Forces both fields of a 'Bit'.
instance NFData Bit where
  rnf (Bit msk val) = rnf msk `seq` rnf val
  {-# NOINLINE rnf #-}
-- | A defined bit shows as @1@ or @0@; an undefined bit shows as a dot.
instance Show Bit where
  show (Bit msk b)
    | msk /= 0    = "."
    | testBit b 0 = "1"
    | otherwise   = "0"
-- | 'showsPrecX' reuses the ordinary 'showsPrec', wrapped with
-- 'showsPrecXWith' (presumably so that X exceptions are rendered).
instance ShowX Bit where
  showsPrecX prec = showsPrecXWith showsPrec prec
-- | X-aware evaluation for 'Bit'.
instance NFDataX Bit where
  deepErrorX = errorX
  -- Weak-head evaluation suffices: both fields are strict and unpacked.
  rnfX = rwhnfX
  -- Undefined either when 'isX' reports it as an X (a Left result) or when
  -- the mask marks the bit undefined. The 'isX' check runs first.
  hasUndefined bv = isLeft (isX bv) || unsafeMask# bv /= 0
-- | Lifts a 'Bit' into a Template Haskell expression that rebuilds it with
-- 'fromInteger##'. The mask is spliced in as an unboxed word literal; the
-- value word is lifted through its own 'Lift' instance (presumably as a
-- numeric literal usable at type 'Integer').
instance Lift Bit where
  lift (Bit m i) = [| fromInteger## $(litE (WordPrimL (toInteger m))) i |]
  {-# NOINLINE lift #-}
-- | Equality on 'Bit' goes through the one-bit 'BitVector' comparison, so
-- comparing an undefined bit raises an X error.
instance Eq Bit where
  (/=) = neq##
  (==) = eq##

-- | Equality of two bits, via 'eq#' on their packed forms.
eq## :: Bit -> Bit -> Bool
eq## x y = pack# x `eq#` pack# y
{-# NOINLINE eq## #-}

-- | Inequality of two bits, via 'neq#' on their packed forms.
neq## :: Bit -> Bit -> Bool
neq## x y = pack# x `neq#` pack# y
{-# NOINLINE neq## #-}
-- | Ordering on 'Bit' ('low' before 'high'), delegated to the one-bit
-- 'BitVector' ordering.
instance Ord Bit where
  (>=) = ge##
  (>)  = gt##
  (<=) = le##
  (<)  = lt##

-- | Comparisons on 'Bit'. Each packs both operands and defers to the
-- matching 'BitVector' comparison, which errors on undefined bits.
lt##,ge##,gt##,le## :: Bit -> Bit -> Bool
lt## x y = pack# x `lt#` pack# y
{-# NOINLINE lt## #-}
ge## x y = pack# x `ge#` pack# y
{-# NOINLINE ge## #-}
gt## x y = pack# x `gt#` pack# y
{-# NOINLINE gt## #-}
le## x y = pack# x `le#` pack# y
{-# NOINLINE le## #-}
-- | 'toEnum' keeps only the lowest bit of its argument. 'fromEnum' maps
-- 'low' to 0 and any other defined bit to 1; undefined bits raise an error
-- through 'eq##'.
instance Enum Bit where
  toEnum i = fromInteger## 0## (toInteger i)
  fromEnum b
    | eq## b low = 0
    | otherwise  = 1
-- | 'low' is the least and 'high' the greatest 'Bit'.
instance Bounded Bit where
  maxBound = high
  minBound = low

-- | The default 'Bit' is 'low'.
instance Default Bit where
  def = low
-- | Single-bit arithmetic: addition and subtraction are XOR, multiplication
-- is AND, and 'negate' is implemented as 'complement##'. 'abs' and 'signum'
-- are the identity.
instance Num Bit where
  fromInteger = fromInteger## 0##
  negate      = complement##
  (*)         = and##
  (-)         = xor##
  (+)         = xor##
  signum      = id
  abs         = id
-- | Build a 'Bit' from a mask word and an 'Integer' value, keeping only the
-- lowest bit of each.
fromInteger## :: Word# -> Integer -> Bit
fromInteger## m# i = Bit lowMask lowVal
  where
    lowMask = W# m# `mod` 2
    lowVal  = (fromInteger i :: Word) `mod` 2
{-# NOINLINE fromInteger## #-}
-- | A 'Bit' converts to the rational 0 or 1.
instance Real Bit where
  toRational b
    | eq## b low = 0
    | otherwise  = 1
-- | Division on single bits: the quotient is the dividend and the remainder
-- is always 'low'. The divisor is never inspected.
instance Integral Bit where
  toInteger b
    | eq## b low = 0
    | otherwise  = 1
  quot        = const
  div         = const
  rem         = \_ _ -> low
  mod         = \_ _ -> low
  quotRem n _ = (n, low)
  divMod  n _ = (n, low)
-- | A 'Bit' acts as a one-bit wide unsigned bit field: index 0 is the only
-- position that is affected or reported.
instance Bits Bit where
  (.&.)      = and##
  (.|.)      = or##
  xor        = xor##
  complement = complement##
  zeroBits   = low
  bit i
    | i == 0    = high
    | otherwise = low
  setBit b i
    | i == 0    = high
    | otherwise = b
  clearBit b i
    | i == 0    = low
    | otherwise = b
  complementBit b i
    | i == 0    = complement## b
    | otherwise = b
  testBit b i
    | i == 0    = eq## b high
    | otherwise = False
  bitSizeMaybe _ = Just 1
  bitSize _      = 1
  isSigned _     = False
  shiftL b i
    | i == 0    = b
    | otherwise = low
  shiftR b i
    | i == 0    = b
    | otherwise = low
  rotateL b _ = b
  rotateR b _ = b
  popCount b
    | eq## b low = 0
    | otherwise  = 1
-- | A 'Bit' is one bit wide; a 'low' bit counts as one leading and one
-- trailing zero.
instance FiniteBits Bit where
  finiteBitSize _ = 1
  countLeadingZeros b
    | eq## b low = 1
    | otherwise  = 0
  countTrailingZeros b
    | eq## b low = 1
    | otherwise  = 0
-- | Bitwise AND, OR and XOR on single 'Bit's that track undefinedness. Each
-- result mask marks the output undefined exactly when the defined inputs do
-- not determine it, and the value is cleared where the result is undefined.
and##, or##, xor## :: Bit -> Bit -> Bit
-- AND: undefined when one input is undefined and the other is not a
-- defined 0 (a defined 0 forces the result to 0).
and## (Bit m1 v1) (Bit m2 v2) = Bit mask (v1 .&. v2 .&. complement mask)
  where mask = (m1.&.v2 .|. m1.&.m2 .|. m2.&.v1)
{-# NOINLINE and## #-}
-- OR: undefined when one input is undefined and the other is not a
-- defined 1 (a defined 1 forces the result to 1).
or## (Bit m1 v1) (Bit m2 v2) = Bit mask ((v1 .|. v2) .&. complement mask)
  where mask = m1 .&. complement v2 .|.  m1.&.m2  .|.  m2 .&. complement v1
{-# NOINLINE or## #-}
-- XOR: no defined input can force the result, so any undefined input makes
-- the output undefined.
xor## (Bit m1 v1) (Bit m2 v2) = Bit mask ((v1 `xor` v2) .&. complement mask)
  where mask = m1 .|. m2
{-# NOINLINE xor## #-}
-- | Complement of a single 'Bit'. The mask is kept. The value is inverted
-- where defined and forced to 0 where undefined.
complement## :: Bit -> Bit
complement## (Bit m v) = Bit m (complementB v .&. complementB m)
  -- Maps a zero word to 1 and any nonzero word to 0, which is logical
  -- negation for words that are 0 or 1.
  where complementB (W# b#) = W# (int2Word# (eqWord# b# 0##))
{-# NOINLINE complement## #-}
-- | Reinterpret a 'Bit' as a one-bit 'BitVector', keeping mask and value.
pack# :: Bit -> BitVector 1
pack# (Bit (W# msk) (W# val)) = BV (NatS# msk) (NatS# val)
{-# NOINLINE pack# #-}

-- | Reinterpret a one-bit 'BitVector' as a 'Bit'. A big-natural field is
-- narrowed with 'bigNatToWord'; that case should not occur for a valid
-- one-bit value.
unpack# :: BitVector 1 -> Bit
unpack# (BV msk val) = Bit (toWord msk) (toWord val)
 where
  toWord nat = case nat of
    NatS# w  -> W# w
    NatJ# bn -> W# (bigNatToWord bn)
{-# NOINLINE unpack# #-}
-- | Forces both the mask and the value of a 'BitVector'.
instance NFData (BitVector n) where
  rnf (BV msk val) = rnf msk `seq` rnf val
  {-# NOINLINE rnf #-}
-- | Renders the bits most significant first as @0@, @1@, or a dot for an
-- undefined bit. An underscore separates every group of four bits counted
-- from the least significant end, e.g. @1_0010@ for a five-bit value.
instance KnownNat n => Show (BitVector n) where
  show bv@(BV msk i) = reverse . underScore . reverse $ showBV (natVal bv) msk i []
    where
      -- Emit n characters, peeling bits off the least significant end and
      -- prepending them, so the accumulator ends up MSB first. A set mask
      -- bit prints as a dot regardless of the value bit.
      showBV 0 _ _ s = s
      showBV n m v s = let (v',vBit) = divMod v 2
                           (m',mBit) = divMod m 2
                       in  case (mBit,vBit) of
                           (0,0) -> showBV (n - 1) m' v' ('0':s)
                           (0,_) -> showBV (n - 1) m' v' ('1':s)
                           _     -> showBV (n - 1) m' v' ('.':s)
      -- Applied to the reversed (LSB first) string: insert an underscore
      -- after every four characters, but only when more characters follow.
      underScore xs = case splitAt 5 xs of
                        ([a,b,c,d,e],rest) -> [a,b,c,d,'_'] ++ underScore (e:rest)
                        (rest,_)               -> rest
  {-# NOINLINE show #-}
-- | 'showsPrecX' reuses the 'Show' instance, wrapped with 'showsPrecXWith'.
instance KnownNat n => ShowX (BitVector n) where
  showsPrecX = showsPrecXWith showsPrec
-- | X-aware evaluation for 'BitVector'.
instance NFDataX (BitVector n) where
  deepErrorX = errorX
  -- Weak-head evaluation suffices: both fields are strict naturals.
  rnfX = rwhnfX
  -- Undefined either when 'isX' reports an X (a Left result) or when some
  -- bit is marked undefined in the mask.
  hasUndefined bv = isLeft (isX bv) || unsafeMask bv /= 0
-- | Template Haskell helper for bit vector literals, using the syntax of
-- 'read#': @0@, @1@, a dot for an undefined bit, and @_@ as an ignored
-- separator. The string is parsed at compile time, and the splice rebuilds
-- the value with 'fromInteger#', which reduces mask and value modulo @2^n@.
--
-- NOTE(review): typical use is presumably @$$(bLit "10_01") :: BitVector 4@.
bLit :: forall n. KnownNat n => String -> Q (TExp (BitVector n))
bLit s = [|| fromInteger# m i1 ||]
  where
    bv :: BitVector n
    bv = read# s
    m,i :: Natural
    BV m i = bv
    -- fromInteger# takes its value argument as an Integer.
    i1 :: Integer
    i1 = toInteger i
-- | Parse a bit string into a 'BitVector', most significant bit first:
-- @0@ and @1@ are defined bits, a dot is an undefined bit, and @_@ is an
-- ignored separator. Any other character is an error.
--
-- The length of the string is not checked against @n@. 'bLit' passes the
-- result through 'fromInteger#', which reduces it modulo @2^n@.
read# :: KnownNat n => String -> BitVector n
read# cs = BV m v
  where
    (vs,ms) = unzip . map readBit . filter (/= '_') $ cs
    -- Strict left fold: accumulate bits most significant first without
    -- building a chain of thunks.
    combineBits = L.foldl' (\b a -> b*2+a) 0
    v = combineBits vs
    m = combineBits ms
    readBit c = case c of
      '0' -> (0,0)
      '1' -> (1,0)
      '.' -> (0,1)
      _   -> error $ "Clash.Sized.Internal.bLit: unknown character: " ++ show c ++ " in input: " ++ cs
-- | Equality via 'eq#' and 'neq#'. Comparing vectors that have undefined
-- bits raises an X error.
instance KnownNat n => Eq (BitVector n) where
  (==) = eq#
  (/=) = neq#
{-# NOINLINE eq# #-}
-- | Equality of two fully defined bit vectors. If either argument has a
-- nonzero mask, an error is reported through 'undefErrorI'.
eq# :: KnownNat n => BitVector n -> BitVector n -> Bool
eq# (BV 0 v1) (BV 0 v2 ) = v1 == v2
eq# bv1 bv2 = undefErrorI "==" bv1 bv2
{-# NOINLINE neq# #-}
-- | Inequality of two fully defined bit vectors. Errors like 'eq#' when
-- either argument has undefined bits.
neq# :: KnownNat n => BitVector n -> BitVector n -> Bool
neq# (BV 0 v1) (BV 0 v2) = v1 /= v2
neq# bv1 bv2 = undefErrorI "/=" bv1 bv2
-- | Unsigned ordering on fully defined bit vectors.
instance KnownNat n => Ord (BitVector n) where
  (<)  = lt#
  (>=) = ge#
  (>)  = gt#
  (<=) = le#
-- | Unsigned comparisons of the value fields. Each reports an error
-- through 'undefErrorI' when either operand has undefined bits.
lt#,ge#,gt#,le# :: KnownNat n => BitVector n -> BitVector n -> Bool
{-# NOINLINE lt# #-}
lt# (BV 0 n) (BV 0 m) = n < m
lt# bv1 bv2 = undefErrorI "<" bv1 bv2
{-# NOINLINE ge# #-}
ge# (BV 0 n) (BV 0 m) = n >= m
ge# bv1 bv2 = undefErrorI ">=" bv1 bv2
{-# NOINLINE gt# #-}
gt# (BV 0 n) (BV 0 m) = n > m
gt# bv1 bv2 = undefErrorI ">" bv1 bv2
{-# NOINLINE le# #-}
le# (BV 0 n) (BV 0 m) = n <= m
le#  bv1 bv2 = undefErrorI "<=" bv1 bv2
-- | Enumeration of bit vectors. 'succ' and 'pred' use '+#' and '-#' (built
-- on addMod and subMod), so they presumably wrap at the bounds instead of
-- failing. NOTE(review): confirm against Clash.Sized.Internal.Mod.
instance KnownNat n => Enum (BitVector n) where
  succ bv        = bv +# fromInteger# 0 1
  pred bv        = bv -# fromInteger# 0 1
  toEnum i       = fromInteger# 0 (toInteger i)
  fromEnum bv    = fromEnum (toInteger# bv)
  enumFrom       = enumFrom#
  enumFromThen   = enumFromThen#
  enumFromTo     = enumFromTo#
  enumFromThenTo = enumFromThenTo#
-- | @[x ..]@ for bit vectors: every value from @x@ up to 'maxBound'.
-- Errors if @x@ has undefined bits.
enumFrom# :: forall n. KnownNat n => BitVector n -> [BitVector n]
enumFrom# bv = case bv of
  BV 0 x -> [ BV 0 (k `mod` modulus)
            | k <- [x .. unsafeToNatural (maxBound :: BitVector n)] ]
  _      -> undefErrorU "enumFrom" bv
 where
  modulus = 1 `shiftL` fromInteger (natVal (Proxy @n))
{-# NOINLINE enumFrom# #-}
-- | @[x, y ..]@ for bit vectors. Steps by @y - x@ towards 'maxBound' when
-- @x <= y@ and towards 'minBound' otherwise. Errors if either argument has
-- undefined bits.
enumFromThen#
  :: forall n
   . KnownNat n
  => BitVector n
  -> BitVector n
  -> [BitVector n]
enumFromThen# (BV 0 x) (BV 0 y) =
  toBvs [x, y .. unsafeToNatural bound]
 where
  -- The step direction decides which end of the range to stop at.
  bound = if x <= y then maxBound else minBound :: BitVector n
  -- Wrap each natural back into a fully defined bit vector.
  toBvs = map (BV 0 . (`mod` m))
  m = 1 `shiftL` fromInteger (natVal (Proxy @n))
enumFromThen# bv1 bv2 = undefErrorP "enumFromThen" bv1 bv2
{-# NOINLINE enumFromThen# #-}
-- | @[x .. y]@ for bit vectors. Errors if either bound has undefined bits.
enumFromTo#
  :: forall n
   . KnownNat n
  => BitVector n
  -> BitVector n
  -> [BitVector n]
enumFromTo# lo hi = case (lo, hi) of
  (BV 0 x, BV 0 y) -> [ BV 0 (k `mod` modulus) | k <- [x .. y] ]
  _                -> undefErrorP "enumFromTo" lo hi
 where
  modulus = 1 `shiftL` fromInteger (natVal (Proxy @n))
{-# NOINLINE enumFromTo# #-}
-- | @[x1, x2 .. y]@ for bit vectors: steps by @x2 - x1@ until passing @y@.
-- Errors, naming @enumFromThenTo@, if any argument has undefined bits.
enumFromThenTo#
  :: forall n
   . KnownNat n
  => BitVector n
  -> BitVector n
  -> BitVector n
  -> [BitVector n]
enumFromThenTo# (BV 0 x1) (BV 0 x2) (BV 0 y) = map (BV 0 . (`mod` m)) [x1, x2 .. y]
  where m = 1 `shiftL` fromInteger (natVal (Proxy @n))
enumFromThenTo# bv1 bv2 bv3 = undefErrorP3 "enumFromThenTo" bv1 bv2 bv3
{-# NOINLINE enumFromThenTo# #-}
-- | The bounds are all zeros and all ones, both fully defined.
instance KnownNat n => Bounded (BitVector n) where
  maxBound = maxBound#
  minBound = minBound#

-- | The fully defined all-zeros bit vector.
minBound# :: BitVector n
minBound# = BV { unsafeMask = 0, unsafeToNatural = 0 }
{-# NOINLINE minBound# #-}

-- | The fully defined all-ones bit vector, @2^n - 1@.
maxBound# :: forall n. KnownNat n => BitVector n
maxBound# = BV 0 (limit - 1)
  where
    limit = 1 `shiftL` natToNum @n
{-# NOINLINE maxBound# #-}
-- | Wrapping arithmetic on bit vectors. 'abs' is the identity, and 'signum'
-- is 1 for any nonzero vector and 0 otherwise.
instance KnownNat n => Num (BitVector n) where
  fromInteger = fromInteger# 0
  signum      = resizeBV . pack# . reduceOr#
  abs         = id
  negate      = negate#
  (*)         = (*#)
  (-)         = (-#)
  (+)         = (+#)
-- | Wrapping addition, subtraction and multiplication on @n@-bit vectors.
-- Each reports an error through 'undefErrorI' when an operand has undefined
-- bits. The modulus @m@ is bound outside the worker @go@, presumably so it
-- is shared rather than recomputed on every call.
(+#),(-#),(*#) :: forall n . KnownNat n => BitVector n -> BitVector n -> BitVector n
{-# NOINLINE (+#) #-}
(+#) = go
  where
    go (BV 0 i) (BV 0 j) = BV 0 (addMod m i j)
    go bv1 bv2 = undefErrorI "+" bv1 bv2
    m = 1 `shiftL` fromInteger (natVal (Proxy @n))
{-# NOINLINE (-#) #-}
(-#) = go
  where
    go (BV 0 i) (BV 0 j) = BV 0 (subMod m i j)
    go bv1 bv2 = undefErrorI "-" bv1 bv2
    m = 1 `shiftL` fromInteger (natVal (Proxy @n))
{-# NOINLINE (*#) #-}
(*#) = go
 where
  go (BV 0 i) (BV 0 j) = BV 0 (mulMod2 m i j)
  go bv1 bv2 = undefErrorI "*" bv1 bv2
  -- Unlike addMod and subMod above, mulMod2 is given 2^n - 1 rather than
  -- 2^n (presumably it masks the product with this all-ones value).
  m = (1 `shiftL` fromInteger (natVal (Proxy @n))) - 1
{-# NOINLINE negate# #-}
-- | Negation modulo @2^n@ (via negateMod). Errors if any bit is undefined.
negate# :: forall n . KnownNat n => BitVector n -> BitVector n
negate# = \bv -> case bv of
  BV 0 i -> BV 0 (negateMod modulus i)
  _      -> undefErrorU "negate" bv
 where
  modulus = 1 `shiftL` fromInteger (natVal (Proxy @n))
{-# NOINLINE fromInteger# #-}
-- | Build a bit vector from a mask and an 'Integer' value, both reduced
-- modulo @2^n@. Negative values wrap, because 'mod' by a positive divisor is
-- non-negative.
fromInteger# :: KnownNat n => Natural -> Integer -> BitVector n
fromInteger# m i = sz `seq` mx
  where
    mx = BV (m `mod` naturalFromInteger sz)
            (naturalFromInteger (i `mod` sz))
    -- natVal is applied to the result itself to recover n without a Proxy.
    sz  = 1 `shiftL` fromInteger (natVal mx) :: Integer
-- | Widening arithmetic. 'add' and 'mul' return results wide enough that
-- they cannot overflow. 'sub' is computed modulo @2^(max m n + 1)@.
instance (KnownNat m, KnownNat n) => ExtendingNum (BitVector m) (BitVector n) where
  type AResult (BitVector m) (BitVector n) = BitVector (Max m n + 1)
  type MResult (BitVector m) (BitVector n) = BitVector (m + n)
  add = plus#
  sub = minus#
  mul = times#
{-# NOINLINE plus# #-}
-- | Addition into a result one bit wider than the wider operand, so it
-- cannot overflow. Errors on undefined operand bits.
plus# :: (KnownNat m, KnownNat n) => BitVector m -> BitVector n -> BitVector (Max m n + 1)
plus# (BV 0 a) (BV 0 b) = BV 0 (a + b)
plus# bv1 bv2 = undefErrorP "add" bv1 bv2
{-# NOINLINE minus# #-}
-- | Subtraction into a result one bit wider than the wider operand,
-- computed modulo @2^(max m n + 1)@. Assuming subMod is subtraction modulo
-- its first argument, @a < b@ wraps to a value with the top bit set.
-- Errors on undefined operand bits.
minus# :: forall m n . (KnownNat m, KnownNat n) => BitVector m -> BitVector n
                                                -> BitVector (Max m n + 1)
minus# = go
 where
  go (BV 0 a) (BV 0 b) = BV 0 (subMod m a b)
  go bv1 bv2 = undefErrorP "sub" bv1 bv2
  m = 1 `shiftL` fromInteger (natVal (Proxy @(Max m n + 1)))
{-# NOINLINE times# #-}
-- | Multiplication into an @m + n@ bit result, which holds every product of
-- the operands. Errors on undefined operand bits.
times# :: (KnownNat m, KnownNat n) => BitVector m -> BitVector n -> BitVector (m + n)
times# (BV 0 a) (BV 0 b) = BV 0 (a * b)
times# bv1 bv2 = undefErrorP "mul" bv1 bv2
-- | Conversion to 'Rational' through the 'Integer' value.
instance KnownNat n => Real (BitVector n) where
  toRational bv = toRational (toInteger# bv)

-- | Unsigned integral operations. Values are never negative, so 'div' and
-- 'mod' coincide with 'quot' and 'rem'.
instance KnownNat n => Integral (BitVector n) where
  toInteger   = toInteger#
  quot        = quot#
  div         = quot#
  rem         = rem#
  mod         = rem#
  quotRem n d = (quot# n d, rem# n d)
  divMod  n d = (quot# n d, rem# n d)
-- | Unsigned quotient and remainder of fully defined vectors, delegating to
-- 'quot' and 'rem' on 'Natural' (including for a zero divisor). Errors on
-- undefined operand bits.
quot#,rem# :: KnownNat n => BitVector n -> BitVector n -> BitVector n
{-# NOINLINE quot# #-}
quot# (BV 0 i) (BV 0 j) = BV 0 (i `quot` j)
quot# bv1 bv2 = undefErrorP "quot" bv1 bv2
{-# NOINLINE rem# #-}
rem# (BV 0 i) (BV 0 j) = BV 0 (i `rem` j)
rem# bv1 bv2 = undefErrorP "rem" bv1 bv2
{-# NOINLINE toInteger# #-}
-- | The value of a fully defined bit vector as an 'Integer'. Errors if any
-- bit is undefined.
toInteger# :: KnownNat n => BitVector n -> Integer
toInteger# (BV 0 i) = naturalToInteger i
toInteger# bv = undefErrorU "toInteger" bv
-- | Bit operations on bit vectors. The logic operations propagate undefined
-- bits. The single-bit accessors go through 'index#' and 'replaceBit#', so
-- an out-of-range index raises an error.
instance KnownNat n => Bits (BitVector n) where
  (.&.)             = and#
  (.|.)             = or#
  xor               = xor#
  complement        = complement#
  zeroBits          = 0
  bit i             = replaceBit# 0 i high
  setBit v i        = replaceBit# v i high
  clearBit v i      = replaceBit# v i low
  complementBit v i = replaceBit# v i (complement## (index# v i))
  testBit v i       = index# v i `eq##` high
  bitSizeMaybe      = Just . size#
  bitSize           = size#
  isSigned _        = False
  shiftL            = shiftL#
  shiftR            = shiftR#
  rotateL           = rotateL#
  rotateR           = rotateR#
  -- Append one zero bit so the argument has the (n+1)-bit shape that
  -- popCountBV expects; the extra zero does not change the count.
  popCount bv       =
    let widened = bv ++# (0 :: BitVector 1)
    in  fromInteger (I.toInteger# (popCountBV widened))
-- | Width and leading/trailing zero counts, computed through the
-- 'I.Index'-valued helpers below.
instance KnownNat n => FiniteBits (BitVector n) where
  finiteBitSize         = size#
  countLeadingZeros bv  = fromInteger (I.toInteger# (countLeadingZerosBV bv))
  countTrailingZeros bv = fromInteger (I.toInteger# (countTrailingZerosBV bv))

-- | Number of consecutive zero bits at the front of the 'V.bv2v' vector
-- (the leading end, assuming 'V.bv2v' puts the most significant bit first).
countLeadingZerosBV :: KnownNat n => BitVector n -> I.Index (n+1)
countLeadingZerosBV = V.foldr step 0 . V.bv2v
  where
    -- A zero adds one to the count of the bits after it; a one resets it.
    step b acc = if eq## b low then 1 + acc else 0
{-# INLINE countLeadingZerosBV #-}

-- | Number of consecutive zero bits at the back of the 'V.bv2v' vector
-- (the trailing end).
countTrailingZerosBV :: KnownNat n => BitVector n -> I.Index (n+1)
countTrailingZerosBV = V.foldl step 0 . V.bv2v
  where
    -- A zero adds one to the count of the bits before it; a one resets it.
    step acc b = if eq## b low then 1 + acc else 0
{-# INLINE countTrailingZerosBV #-}
{-# NOINLINE reduceAnd# #-}
-- | AND of all bits. A fully defined vector is compared against the
-- all-ones value. Otherwise '.&.' is folded over the individual bits, so a
-- defined 0 still forces a defined result.
reduceAnd# :: KnownNat n => BitVector n -> Bit
reduceAnd# bv@(BV 0 i) = Bit 0 (W# (int2Word# (dataToTag# check)))
  -- dataToTag# turns the Bool into its constructor tag: 0 or 1.
  where
    check = i == maxI
    sz    = natVal bv
    maxI  = (2 ^ sz) - 1
reduceAnd# bv = V.foldl (.&.) 1 (V.bv2v bv)
{-# NOINLINE reduceOr# #-}
-- | OR of all bits. A fully defined vector is compared against 0. Otherwise
-- '.|.' is folded over the individual bits, so a defined 1 still forces a
-- defined result.
reduceOr# :: KnownNat n => BitVector n -> Bit
reduceOr# (BV 0 i) = Bit 0 (W# (int2Word# (dataToTag# check)))
  where
    check = i /= 0
reduceOr# bv = V.foldl (.|.) 0 (V.bv2v bv)
{-# NOINLINE reduceXor# #-}
-- | XOR of all bits: the parity of the population count. Unlike
-- 'reduceAnd#' and 'reduceOr#', a vector with undefined bits raises an error
-- here instead of producing an undefined 'Bit'.
reduceXor# :: KnownNat n => BitVector n -> Bit
reduceXor# (BV 0 i) = Bit 0 (fromIntegral (popCount i `mod` 2))
reduceXor# bv = undefErrorU "reduceXor" bv
-- | The default bit vector is 'minBound#': all bits defined and zero.
instance Default (BitVector n) where
  def = minBound#

{-# NOINLINE size# #-}
-- | The width @n@ of a bit vector as an 'Int'.
size# :: KnownNat n => BitVector n -> Int
size# bv = width
  where
    width = fromInteger (natVal bv)

{-# NOINLINE maxIndex# #-}
-- | Index of the most significant bit, @n - 1@.
maxIndex# :: KnownNat n => BitVector n -> Int
maxIndex# bv = width - 1
  where
    width = fromInteger (natVal bv)
{-# NOINLINE index# #-}
-- | Bit @i@ of a bit vector (0 is the least significant bit), including its
-- mask bit. Errors when @i@ is outside @[0, n-1]@.
index# :: KnownNat n => BitVector n -> Int -> Bit
index# bv@(BV m v) i
    -- dataToTag# turns each testBit result into 0 or 1.
    | i >= 0 && i < sz = Bit (W# (int2Word# (dataToTag# (testBit m i))))
                             (W# (int2Word# (dataToTag# (testBit v i))))
    | otherwise        = err
  where
    sz  = fromInteger (natVal bv)
    err = error $ concat [ "(!): "
                         , show i
                         , " is out of range ["
                         , show (sz - 1)
                         , "..0]"
                         ]
{-# NOINLINE msb# #-}
-- | Most significant bit (bit @n-1@) of a bit vector, for both the value and
-- the undefined-mask.
--
-- A small natural (@NatS#@) fits in a single machine word. When @n@
-- exceeds the word size, bit @n-1@ of such a natural is necessarily 0.
-- Shifting it by @n-1@ would also exceed the word width, which is undefined
-- for @uncheckedShiftRL#@ and in practice can wrap the shift amount and
-- report a spurious 1. For @n == 0@ there is no bit at all and the shift
-- amount would be negative. Both cases therefore yield 0 without shifting.
msb# :: forall n . KnownNat n => BitVector n -> Bit
msb# (BV m v)
  = Bit (msbN m)
        (msbN v)
 where
  !(S# i#) = natVal (Proxy @n)
  sz       = natVal (Proxy @n)
  wordBits = toInteger (finiteBitSize (0 :: Word))
  msbN (NatS# w)
    | sz == 0 || sz > wordBits = 0
    | otherwise                = W# (w `uncheckedShiftRL#` (i# GHC.Exts.-# 1#))
  -- A big natural implies n is wider than a word, so the shift is in range.
  msbN (NatJ# bn) = W# (bigNatToWord (shiftRBigNat bn (i# GHC.Exts.-# 1#)))
{-# NOINLINE lsb# #-}
-- | Least significant bit (bit 0) of a bit vector, including its mask bit.
lsb# :: BitVector n -> Bit
lsb# (BV m v) = Bit (bit0 m) (bit0 v)
  where
    -- dataToTag# turns the Bool from testBit into 0 or 1.
    bit0 nat = W# (int2Word# (dataToTag# (testBit nat 0)))
{-# NOINLINE slice# #-}
-- | Bits @m@ down to @n@ (inclusive) of a bit vector, for both mask and
-- value. Everything above bit @m@ is masked off, then the result is shifted
-- right by @n@.
slice# :: BitVector (m + 1 + i) -> SNat m -> SNat n -> BitVector (m + 1 - n)
slice# (BV msk i) m n = BV (shiftR (msk .&. mask) n')
                           (shiftR (i   .&. mask) n')
  where
    m' = snatToInteger m
    n' = snatToNum n
    -- Ones in bit positions 0..m.
    mask = 2 ^ (m' + 1) - 1
{-# NOINLINE (++#) #-}
-- | Concatenate two bit vectors; the left operand becomes the high bits.
(++#) :: KnownNat m => BitVector n -> BitVector m -> BitVector (n + m)
(BV hiMask hiVal) ++# bv2@(BV loMask loVal) =
  BV (shiftL hiMask loWidth .|. loMask) (shiftL hiVal loWidth .|. loVal)
  where
    loWidth = fromInteger (natVal bv2)
{-# NOINLINE replaceBit# #-}
-- | Replace bit @i@ with the given 'Bit', including its mask bit. The value
-- bit is set only when the new bit is a defined 1; otherwise it is cleared.
-- Errors when @i@ is outside @[0, n-1]@.
replaceBit# :: KnownNat n => BitVector n -> Int -> Bit -> BitVector n
replaceBit# bv@(BV m v) i (Bit mb b)
    -- Assumes the mask word mb is 0 or 1, as produced elsewhere in this
    -- module; wider words would set mask bits above position i.
    | i >= 0 && i < sz = BV (clearBit m i  .|. (wordToNatural mb `shiftL` i))
                            (if testBit b 0 && mb == 0 then setBit v i else clearBit v i)
    | otherwise        = err
  where
    sz   = fromInteger (natVal bv)
    err  = error $ concat [ "replaceBit: "
                          , show i
                          , " is out of range ["
                          , show (sz - 1)
                          , "..0]"
                          ]
{-# NOINLINE setSlice# #-}
-- | Overwrite bits @m@ down to @n@ of a bit vector with the given slice, for
-- both mask and value. Matching the leading 'SNat' brings the width
-- @m + 1 + i@ into scope for 'natVal' in the where clause.
setSlice#
  :: forall m i n
   . SNat (m + 1 + i)
  -> BitVector (m + 1 + i)
  -> SNat m
  -> SNat n
  -> BitVector (m + 1 - n)
  -> BitVector (m + 1 + i)
setSlice# SNat =
  \(BV iMask i) m@SNat n (BV jMask j) ->
    let m' = snatToInteger m
        n' = snatToInteger n
        -- Move the new slice into positions n..m.
        j'     = shiftL j     (fromInteger n')
        jMask' = shiftL jMask (fromInteger n')
        -- Zeros in positions n..m and ones elsewhere within the width, so
        -- the old contents of the slice are cleared before OR-ing.
        mask   = complementN ((2 ^ (m' + 1) - 1) `xor` (2 ^ n' - 1))
    in  BV ((iMask .&. mask) .|. jMask') ((i .&. mask) .|. j')
 where
  complementN = complementMod (natVal (Proxy @(m + 1 + i)))
{-# NOINLINE split# #-}
-- | Split a bit vector into its upper @m@ and lower @n@ bits, for both mask
-- and value.
split#
  :: forall n m
   . KnownNat n
  => BitVector (m + n)
  -> (BitVector m, BitVector n)
split# (BV m i) =
  let n     = fromInteger (natVal (Proxy @n))
      -- maskMod presumably keeps only the low n bits of its argument.
      mask  = maskMod (natVal (Proxy @n))
      r     = mask i
      rMask = mask m
      l     = i `shiftR` n
      lMask = m `shiftR` n
  in  (BV lMask l, BV rMask r)
-- | Bitwise AND, OR and XOR on bit vectors with the same undefinedness rules
-- as 'and##', 'or##' and 'xor##', applied per bit position. complementN
-- complements within the @n@-bit width (via complementMod), since
-- 'Natural' has no finite 'complement'.
and#, or#, xor# :: forall n . KnownNat n => BitVector n -> BitVector n -> BitVector n
{-# NOINLINE and# #-}
and# =
  \(BV m1 v1) (BV m2 v2) ->
    -- Undefined where one input is undefined and the other is not a
    -- defined 0.
    let mask = (m1.&.v2 .|. m1.&.m2 .|. m2.&.v1)
    in  BV mask (v1 .&. v2  .&. complementN mask)
  where
    complementN = complementMod (natVal (Proxy @n))
{-# NOINLINE or# #-}
or# =
  \(BV m1 v1) (BV m2 v2) ->
    -- Undefined where one input is undefined and the other is not a
    -- defined 1.
    let mask = m1 .&. complementN v2  .|.  m1.&.m2  .|.  m2 .&. complementN v1
    in  BV mask ((v1.|.v2) .&. complementN mask)
  where
    complementN = complementMod (natVal (Proxy @n))
{-# NOINLINE xor# #-}
xor# =
  \(BV m1 v1) (BV m2 v2) ->
    -- Undefined wherever either input is undefined.
    let mask  = m1 .|. m2
    in  BV mask ((v1 `xor` v2) .&. complementN mask)
  where
    complementN = complementMod (natVal (Proxy @n))
{-# NOINLINE complement# #-}
-- | Bitwise complement within @n@ bits. The mask is kept; defined bits are
-- inverted and undefined bits get value 0.
complement# :: forall n . KnownNat n => BitVector n -> BitVector n
complement# = go
  where
    go (BV msk val) = BV msk (flipAll val .&. flipAll msk)
    flipAll = complementMod (natVal (Proxy @n))
-- | Shifts and rotations by a non-negative amount, applied to the value and
-- the undefined-mask alike. A negative amount raises an error.
shiftL#, shiftR#, rotateL#, rotateR#
  :: forall n . KnownNat n => BitVector n -> Int -> BitVector n
{-# NOINLINE shiftL# #-}
-- Shift left, discarding bits moved past the top by reducing modulo 2^n.
shiftL# =
  \(BV msk v) i ->
    if i >= 0 then
      BV ((shiftL msk i) `mod` m) ((shiftL v i) `mod` m)
    else
      error ("'shiftL' undefined for negative number: " ++ show i)
 where
  m = 1 `shiftL` fromInteger (natVal (Proxy @n))
{-# NOINLINE shiftR# #-}
-- Logical shift right: bits shifted out at the bottom are discarded and
-- defined zeros enter at the top, for both value and mask. A negative shift
-- amount is an error, reported in the same format as shiftL#.
shiftR# (BV m v) i
  | i < 0     = error
              $ "'shiftR' undefined for negative number: " ++ show i
  | otherwise = BV (shiftR m i) (shiftR v i)
{-# NOINLINE rotateL# #-}
-- Rotate left by b positions, taken modulo the width sz: bits leaving the
-- top re-enter at the bottom, for both value and mask. A negative amount is
-- an error. A zero-width vector is returned unchanged, since there is
-- nothing to rotate and reducing the amount modulo 0 would otherwise throw
-- a divide-by-zero exception.
rotateL# =
  \bv@(BV msk v) b ->
    if b < 0 then
      error "'rotateL' undefined for negative numbers"
    else if sz == 0 then
      bv
    else
      let vl    = shiftL v b'
          vr    = shiftR v b''
          ml    = shiftL msk b'
          mr    = shiftR msk b''
          b'   = b `mod` sz
          b''  = sz - b'
      in  BV ((ml .|. mr) `mod` m) ((vl .|. vr) `mod` m)
 where
  sz = fromInteger (natVal (Proxy @n)) :: Int
  m  = 1 `shiftL` sz
{-# NOINLINE rotateR# #-}
-- Rotate right by b positions, taken modulo the width sz: bits leaving the
-- bottom re-enter at the top, for both value and mask. A negative amount is
-- an error. A zero-width vector is returned unchanged, since there is
-- nothing to rotate and reducing the amount modulo 0 would otherwise throw
-- a divide-by-zero exception.
rotateR# =
  \bv@(BV msk v) b ->
    if b < 0 then
      error "'rotateR' undefined for negative numbers"
    else if sz == 0 then
      bv
    else
      let vl   = shiftR v b'
          vr   = shiftL v b''
          ml   = shiftR msk b'
          mr   = shiftL msk b''
          b'   = b `mod` sz
          b''  = sz - b'
      in  BV ((ml .|. mr) `mod` m) ((vl .|. vr) `mod` m)
 where
  sz = fromInteger (natVal (Proxy @n)) :: Int
  m  = 1 `shiftL` sz
-- | Number of set bits. The argument has width @n+1@, so the result type
-- @Index (n+2)@ can hold every count from 0 to @n+1@.
popCountBV :: forall n . KnownNat n => BitVector (n+1) -> I.Index (n+2)
popCountBV = sum . V.map (fromIntegral . pack#) . V.bv2v
{-# INLINE popCountBV #-}
-- | Resizing. 'zeroExtend' prepends zero bits. 'signExtend' prepends copies
-- of the most significant bit, which must be defined (comparing an
-- undefined MSB with 'low' errors). 'truncateB' keeps the low bits.
instance Resize BitVector where
  resize     = resizeBV
  zeroExtend = (++#) 0
  signExtend bv
    | msb# bv == low = 0 ++# bv
    | otherwise      = complement 0 ++# bv
  truncateB  = truncateB#
-- | Resize to width @m@: prepend @m - n@ zero bits when growing (or keeping
-- the size), otherwise truncate to the low @m@ bits.
resizeBV :: forall n m . (KnownNat n, KnownNat m) => BitVector n -> BitVector m
resizeBV = case compareSNat @n @m (SNat @n) (SNat @m) of
  SNatLE -> (++#) @n @(m-n) 0
  SNatGT -> truncateB# @m @(n - m)
{-# INLINE resizeBV #-}
-- | Keep the low @a@ bits of mask and value, dropping the upper @b@ bits.
truncateB# :: forall a b . KnownNat a => BitVector (a + b) -> BitVector a
truncateB# = \(BV msk i) -> BV (msk `mod` m) (i `mod` m)
  where m = 1 `shiftL` fromInteger (natVal (Proxy @a))
{-# NOINLINE truncateB# #-}
-- | Lifts a bit vector into a Template Haskell expression that rebuilds it
-- with 'fromInteger#', annotated with its concrete @BitVector n@ type.
instance KnownNat n => Lift (BitVector n) where
  lift bv@(BV m i) = sigE [| fromInteger# m $(litE (IntegerL (toInteger i))) |] (decBitVector (natVal bv))
  {-# NOINLINE lift #-}
-- | The Template Haskell type @BitVector n@ for a literal width @n@.
decBitVector :: Integer -> TypeQ
decBitVector n = appT (conT ''BitVector) (litT $ numTyLit n)
-- | Saturating arithmetic. 'SatWrap' wraps modulo @2^n@. The other modes
-- compute a wider intermediate result and clamp it on overflow.
instance KnownNat n => SaturatingNum (BitVector n) where
  satAdd SatWrap a b = a +# b
  -- SatZero: a set top bit in the one-bit-wider sum means overflow, which
  -- yields 0.
  satAdd SatZero a b =
    let r = plus# a b
    in  if msb# r == low
           then truncateB# r
           else minBound#
  -- Remaining modes: overflow clamps to the largest value.
  satAdd _ a b =
    let r  = plus# a b
    in  if msb# r == low
           then truncateB# r
           else maxBound#
  satSub SatWrap a b = a -# b
  -- All non-wrapping modes: a set top bit in the one-bit-wider difference
  -- means the result went below 0, which clamps to 0.
  -- NOTE(review): relies on minus# wrapping modulo 2^(n+1).
  satSub _ a b =
    let r = minus# a b
    in  if msb# r == low
           then truncateB# r
           else minBound#
  satMul SatWrap a b = a *# b
  -- SatZero: the full 2n-bit product overflows when its upper n bits are
  -- nonzero, which yields 0.
  satMul SatZero a b =
    let r       = times# a b
        (rL,rR) = split# r
    in  case rL of
          0 -> rR
          _ -> minBound#
  -- Remaining modes: overflow clamps to the largest value.
  satMul _ a b =
    let r       = times# a b
        (rL,rR) = split# r
    in  case rL of
          0 -> rR
          _ -> maxBound#
-- | Generation via 'arbitraryBoundedIntegral'; shrinking via
-- 'shrinkSizedUnsigned'.
instance KnownNat n => Arbitrary (BitVector n) where
  shrink x  = shrinkSizedUnsigned x
  arbitrary = arbitraryBoundedIntegral
-- | Shrinking for unsigned sized types. Types of width 2 or more use
-- 'shrinkIntegral'. For narrower types the value 1 shrinks to 0 and any
-- other value has no shrinks.
shrinkSizedUnsigned :: (KnownNat n, Integral (p n)) => p n -> [p n]
shrinkSizedUnsigned x
  | natVal x >= 2    = shrinkIntegral x
  | toInteger x == 1 = [0]
  | otherwise        = []
{-# INLINE shrinkSizedUnsigned #-}
-- | Coarbitrary via the integral value.
instance KnownNat n => CoArbitrary (BitVector n) where
  coarbitrary = coarbitraryIntegral
-- Lens indexing: positions are Int bit indices and elements are Bits.
type instance Index   (BitVector n) = Int
type instance IxValue (BitVector n) = Bit
-- | Focus on a single bit. An out-of-range index raises an error (through
-- 'index#') rather than targeting nothing.
instance KnownNat n => Ixed (BitVector n) where
  ix i f bv = fmap (replaceBit# bv i) (f (index# bv i))
-- | Error for an infix operator applied to (partially) undefined operands.
-- The message shows the operands around the operator.
undefErrorI :: (HasCallStack, KnownNat m, KnownNat n) => String -> BitVector m -> BitVector n -> a
undefErrorI op bv1 bv2 = withFrozenCallStack $
  errorX $ "Clash.Sized.BitVector." ++ op
  ++ " called with (partially) undefined arguments: "
  ++ show bv1 ++ " " ++ op ++" " ++ show bv2
-- | Error for a two-argument prefix function. The message lists both
-- operands after the function name.
undefErrorP :: (HasCallStack, KnownNat m, KnownNat n) => String -> BitVector m -> BitVector n -> a
undefErrorP op bv1 bv2 = withFrozenCallStack $
  errorX $ "Clash.Sized.BitVector." ++ op
  ++ " called with (partially) undefined arguments: "
  ++ show bv1 ++ " " ++ show bv2
-- | Three-argument variant of 'undefErrorP'.
undefErrorP3 :: (HasCallStack, KnownNat m, KnownNat n, KnownNat o) => String -> BitVector m -> BitVector n -> BitVector o -> a
undefErrorP3 op bv1 bv2 bv3 = withFrozenCallStack $
  errorX $ "Clash.Sized.BitVector." ++ op
  ++ " called with (partially) undefined arguments: "
  ++ show bv1 ++ " " ++ show bv2 ++ " " ++ show bv3
-- | Error for a unary function applied to a (partially) undefined operand.
undefErrorU :: (HasCallStack, KnownNat n) => String -> BitVector n -> a
undefErrorU op bv1 = withFrozenCallStack $
  errorX $ "Clash.Sized.BitVector." ++ op
  ++ " called with (partially) undefined argument: "
  ++ show bv1
-- | Error for an arbitrary named operation over a list of same-width
-- vectors. Unlike the helpers above, the message is not prefixed with the
-- module name.
undefError :: (HasCallStack, KnownNat n) => String -> [BitVector n] -> a
undefError op bvs = withFrozenCallStack $
  errorX $ op
  ++ " called with (partially) undefined arguments: "
  ++ unwords (L.map show bvs)
-- | Apply an unpacking function to a fully defined bit vector. If any bit is
-- undefined, raise an error that names the target type instead.
checkUnpackUndef :: (KnownNat n, Typeable a)
                 => (BitVector n -> a) 
                 -> BitVector n -> a
checkUnpackUndef f bv@(BV 0 _) = f bv
checkUnpackUndef _ bv = res
  where
    -- res serves both as the error value and, via typeOf, as the source of
    -- the type name. This relies on typeOf not forcing its argument.
    ty = typeOf res
    res = undefError (show ty ++ ".unpack") [bv]
{-# NOINLINE checkUnpackUndef #-}
-- | A bit vector whose every bit is undefined: the mask is all ones and the
-- value is zero.
undefined# :: forall n . KnownNat n => BitVector n
undefined# = BV allOnes 0
  where
    allOnes = (1 `shiftL` fromInteger (natVal (Proxy @n))) - 1
{-# NOINLINE undefined# #-}
-- | Wildcard comparison. The names suggest the first argument is the
-- candidate and the second the expected pattern, whose undefined bits act as
-- don't-cares. The result is True only if the candidate agrees with the
-- pattern on every cared-about position, whatever value the candidate's own
-- undefined bits take.
isLike :: forall n . KnownNat n => BitVector n -> BitVector n -> Bool
isLike =
  \(BV cMask c) (BV eMask e) ->
    -- e': expected value with its don't-care positions cleared.
    let e' = e .&. complementN eMask
        -- c': candidate with undefined bits read as 0, don't-cares cleared.
        c' = (c .&. complementN cMask) .&. complementN eMask
        -- c'': candidate with undefined bits read as 1, don't-cares cleared.
        c'' = (c .|. cMask) .&. complementN eMask
    in  e' == c' && e' == c''
 where
  complementN = complementMod (natVal (Proxy @n))
{-# NOINLINE isLike #-}
-- | Interpret a list of bits, most significant first, as a non-negative
-- 'Integer'. Uses a strict left fold so the accumulator does not build up a
-- chain of thunks.
fromBits :: [Bit] -> Integer
fromBits = L.foldl' (\v b -> v `shiftL` 1 .|. fromIntegral b) 0
-- | Template Haskell helper that turns a bit pattern string into a view
-- pattern. Characters are read most significant first: @0@ and @1@ must
-- match exactly, a dot matches either value, and @_@ is ignored as a
-- separator (as in 'bLit'). The generated pattern masks the scrutinee with
-- the positions that must match and compares the result with the demanded
-- bits, so the scrutinee type needs 'Bits', 'Eq' and 'Num' instances.
--
-- NOTE(review): presumably spliced as a pattern, e.g.
-- @$(bitPattern "1.0_1") -> ...@; the use site needs ViewPatterns.
bitPattern :: String -> Q Pat
bitPattern s = [p| (($mask .&.) -> $target) |]
  where
    -- Separators carry no bit; drop them before parsing.
    bs = parse <$> filter (/= '_') s
    -- 1 where the pattern demands a specific bit, 0 for don't-cares.
    mask = litE . IntegerL . fromBits $ maybe 0 (const 1) <$> bs
    -- The demanded bits, with don't-cares as 0.
    target = litP . IntegerL . fromBits $ fromMaybe 0 <$> bs
    parse '.' = Nothing
    parse '0' = Just 0
    parse '1' = Just 1
    parse c = error $ "Invalid bit pattern: " ++ show c