{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE Unsafe #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
{-# OPTIONS_HADDOCK show-extensions not-home #-}
module Clash.Sized.Internal.BitVector
(
Bit (..)
, high
, low
, eq##
, neq##
, lt##
, ge##
, gt##
, le##
, fromInteger##
, and##
, or##
, xor##
, complement##
, pack#
, unpack#
, BitVector (..)
, size#
, maxIndex#
, bLit
, undefined#
, (++#)
, reduceAnd#
, reduceOr#
, reduceXor#
, index#
, replaceBit#
, setSlice#
, slice#
, split#
, msb#
, lsb#
, eq#
, neq#
, isLike
, lt#
, ge#
, gt#
, le#
, enumFrom#
, enumFromThen#
, enumFromTo#
, enumFromThenTo#
, minBound#
, maxBound#
, (+#)
, (-#)
, (*#)
, negate#
, fromInteger#
, plus#
, minus#
, times#
, quot#
, rem#
, toInteger#
, and#
, or#
, xor#
, complement#
, shiftL#
, shiftR#
, rotateL#
, rotateR#
, popCountBV
, countLeadingZerosBV
, countTrailingZerosBV
, truncateB#
, shrinkSizedUnsigned
, undefError
, checkUnpackUndef
, bitPattern
)
where
import Control.DeepSeq (NFData (..))
import Control.Lens (Index, Ixed (..), IxValue)
import Data.Bits (Bits (..), FiniteBits (..))
import Data.Data (Data)
import Data.Default.Class (Default (..))
import Data.Either (isLeft)
import Data.Proxy (Proxy (..))
import Data.Typeable (Typeable, typeOf)
import GHC.Generics (Generic)
import Data.Maybe (fromMaybe)
import GHC.Integer (smallInteger)
import GHC.Prim (dataToTag#)
import GHC.Stack (HasCallStack, withFrozenCallStack)
import GHC.TypeLits (KnownNat, Nat, type (+), type (-), natVal)
import GHC.TypeLits.Extra (Max)
import Language.Haskell.TH (Q, TExp, TypeQ, appT, conT, litT, numTyLit, sigE, Lit(..), litE, Pat, litP)
import Language.Haskell.TH.Syntax (Lift(..))
import Test.QuickCheck.Arbitrary (Arbitrary (..), CoArbitrary (..),
arbitraryBoundedIntegral,
coarbitraryIntegral, shrinkIntegral)
import Clash.Class.Num (ExtendingNum (..), SaturatingNum (..),
SaturationMode (..))
import Clash.Class.Resize (Resize (..))
import Clash.Promoted.Nat
(SNat (..), SNatLE (..), compareSNat, snatToInteger, snatToNum)
import Clash.XException
(ShowX (..), NFDataX (..), errorX, isX, showsPrecXWith, rwhnfX)
import {-# SOURCE #-} qualified Clash.Sized.Vector as V
import {-# SOURCE #-} qualified Clash.Sized.Internal.Index as I
import qualified Data.List as L
-- | A vector of @n@ bits in which every bit is either a defined @0@/@1@ or
-- undefined.
--
-- A set bit in 'unsafeMask' marks the bit at that position as undefined;
-- the bits of 'unsafeToInteger' hold the values of the defined positions.
data BitVector (n :: Nat) = BV
  { unsafeMask      :: !Integer
    -- ^ Mask of undefined bit positions
  , unsafeToInteger :: !Integer
    -- ^ Values of the bits at the defined positions
  }
  deriving (Data, Generic)
-- | A single bit that is either a defined @0@/@1@ or undefined.
--
-- The bit is undefined when 'unsafeMask#' is non-zero; otherwise its value
-- is 'unsafeToInteger#'.
data Bit = Bit
  { unsafeMask#      :: !Integer
    -- ^ Non-zero when the bit is undefined
  , unsafeToInteger# :: !Integer
    -- ^ Value of the bit when it is defined
  }
  deriving (Data, Generic)
-- | The defined bit with value 1.
high :: Bit
high = Bit { unsafeMask# = 0, unsafeToInteger# = 1 }
{-# NOINLINE high #-}

-- | The defined bit with value 0.
low :: Bit
low = Bit { unsafeMask# = 0, unsafeToInteger# = 0 }
{-# NOINLINE low #-}
-- Fully evaluates both the mask and the value of the bit.
instance NFData Bit where
  rnf (Bit msk val) = rnf msk `seq` rnf val
  {-# NOINLINE rnf #-}

-- Defined bits render as "0" or "1"; an undefined bit renders as ".".
instance Show Bit where
  show (Bit msk val)
    | msk /= 0      = "."
    | testBit val 0 = "1"
    | otherwise     = "0"

instance ShowX Bit where
  showsPrecX d = showsPrecXWith showsPrec d

-- A bit "has an undefined" when it is an XException thunk or its mask is set.
instance NFDataX Bit where
  hasUndefined b = isLeft (isX b) || unsafeMask# b /= 0
  rnfX           = rwhnfX
  deepErrorX     = errorX

-- Lifts to a 'fromInteger##' call carrying the same mask and value.
instance Lift Bit where
  lift (Bit msk val) = [| fromInteger## msk val |]
  {-# NOINLINE lift #-}
-- Equality on 'Bit' delegates to the 'BitVector' 1 primitives.
instance Eq Bit where
  x == y = eq## x y
  x /= y = neq## x y

-- | Equality of two bits, computed by comparing their 'pack#'ed
-- 'BitVector' 1 forms with 'eq#' (which errors on undefined bits).
eq## :: Bit -> Bit -> Bool
eq## x y = pack# x `eq#` pack# y
{-# NOINLINE eq## #-}

-- | Inequality of two bits, computed with 'neq#' on their 'pack#'ed forms.
neq## :: Bit -> Bit -> Bool
neq## x y = pack# x `neq#` pack# y
{-# NOINLINE neq## #-}

-- Ordering on 'Bit' delegates to the 'BitVector' 1 primitives.
instance Ord Bit where
  x <  y = lt## x y
  x <= y = le## x y
  x >  y = gt## x y
  x >= y = ge## x y

-- | Orderings on bits, each computed by the corresponding 'BitVector' 1
-- comparison on the 'pack#'ed operands.
lt##, ge##, gt##, le## :: Bit -> Bit -> Bool
lt## x y = pack# x `lt#` pack# y
{-# NOINLINE lt## #-}
ge## x y = pack# x `ge#` pack# y
{-# NOINLINE ge## #-}
gt## x y = pack# x `gt#` pack# y
{-# NOINLINE gt## #-}
le## x y = pack# x `le#` pack# y
{-# NOINLINE le## #-}
-- Enumeration of bits: 'low' is 0, every other (defined) bit is 1.
instance Enum Bit where
  toEnum i = fromInteger## 0 (toInteger i)
  fromEnum b
    | eq## b low = 0
    | otherwise  = 1

instance Bounded Bit where
  maxBound = high
  minBound = low

instance Default Bit where
  def = low
-- Arithmetic on 'Bit' modulo 2: addition and subtraction are exclusive-or,
-- multiplication is conjunction.
--
-- NOTE(review): 'negate' is bit complement here, whereas arithmetic
-- negation modulo 2 is the identity ('negate#' on 'BitVector' 1 maps 1 to 1).
-- Left unchanged; confirm whether this is intended.
instance Num Bit where
  a + b         = xor## a b
  a - b         = xor## a b
  a * b         = and## a b
  negate        = complement##
  abs b         = b
  signum        = id
  fromInteger i = fromInteger## 0 i

-- | Build a 'Bit' from a mask and a value; only the least significant bit of
-- each argument is kept.
fromInteger## :: Integer -> Integer -> Bit
fromInteger## msk val = Bit (mod msk 2) (mod val 2)
{-# NOINLINE fromInteger## #-}
instance Real Bit where
  toRational b
    | eq## b low = 0
    | otherwise  = 1

-- Division on a single bit: the quotient is always the dividend and the
-- remainder is always 'low'; the divisor is never inspected.
instance Integral Bit where
  quot n _    = n
  div n _     = n
  rem _ _     = low
  mod _ _     = low
  quotRem n _ = (n, low)
  divMod n _  = (n, low)
  toInteger b
    | eq## b low = 0
    | otherwise  = 1
-- A 'Bit' behaves as a one-bit-wide 'Bits' value: only index 0 exists,
-- shifting by any non-zero amount yields 'low', and rotation is the identity.
instance Bits Bit where
  (.&.)      = and##
  (.|.)      = or##
  xor        = xor##
  complement = complement##
  zeroBits   = low
  bit i
    | i == 0    = high
    | otherwise = low
  setBit b i
    | i == 0    = high
    | otherwise = b
  clearBit b i
    | i == 0    = low
    | otherwise = b
  complementBit b i
    | i == 0    = complement## b
    | otherwise = b
  testBit b i    = i == 0 && eq## b high
  bitSizeMaybe _ = Just 1
  bitSize _      = 1
  isSigned _     = False
  shiftL b i
    | i == 0    = b
    | otherwise = low
  shiftR b i
    | i == 0    = b
    | otherwise = low
  rotateL b _ = b
  rotateR b _ = b
  popCount b
    | eq## b low = 0
    | otherwise  = 1

instance FiniteBits Bit where
  finiteBitSize _ = 1
  countLeadingZeros b
    | eq## b low = 1
    | otherwise  = 0
  countTrailingZeros b
    | eq## b low = 1
    | otherwise  = 0
-- | Bitwise conjunction, disjunction and exclusive-or of two bits, each
-- computed by the 'BitVector' 1 primitive on the 'pack#'ed operands.
and##, or##, xor## :: Bit -> Bit -> Bit
and## x y = unpack# (pack# x `and#` pack# y)
{-# NOINLINE and## #-}
or## x y = unpack# (pack# x `or#` pack# y)
{-# NOINLINE or## #-}
xor## x y = unpack# (pack# x `xor#` pack# y)
{-# NOINLINE xor## #-}

-- | Bitwise complement of a bit, via 'complement#' on its 'BitVector' 1 form.
complement## :: Bit -> Bit
complement## b = unpack# (complement# (pack# b))
{-# NOINLINE complement## #-}

-- | Reinterpret a 'Bit' as a 'BitVector' of width 1, keeping mask and value.
pack# :: Bit -> BitVector 1
pack# (Bit msk val) = BV { unsafeMask = msk, unsafeToInteger = val }
{-# NOINLINE pack# #-}

-- | Reinterpret a 'BitVector' of width 1 as a 'Bit', keeping mask and value.
unpack# :: BitVector 1 -> Bit
unpack# (BV msk val) = Bit { unsafeMask# = msk, unsafeToInteger# = val }
{-# NOINLINE unpack# #-}
-- Fully evaluates the mask and then the value of the bit vector.
instance NFData (BitVector n) where
  rnf (BV msk val) = rnf msk `seq` rnf val
  {-# NOINLINE rnf #-}
-- Renders all @n@ bits most significant first: '0'/'1' for defined bits and
-- '.' for undefined ones. Counting from the least significant end, an '_' is
-- inserted after every group of four characters when more characters follow.
instance KnownNat n => Show (BitVector n) where
  show bv@(BV msk val) = reverse (groupsOf4 (reverse digits))
    where
      digits = go (natVal bv) msk val ""

      -- Peel bits off the least significant end, prepending each rendered
      -- character to the accumulator.
      go 0 _ _ acc = acc
      go k m v acc =
        let (m', mb) = m `divMod` 2
            (v', vb) = v `divMod` 2
            ch | mb /= 0   = '.'
               | vb /= 0   = '1'
               | otherwise = '0'
        in  go (k - 1) m' v' (ch : acc)

      groupsOf4 cs = case splitAt 5 cs of
        ([a, b, c, d, e], rest) -> a : b : c : d : '_' : groupsOf4 (e : rest)
        (rest, _)               -> rest
  {-# NOINLINE show #-}
instance KnownNat n => ShowX (BitVector n) where
  showsPrecX d = showsPrecXWith showsPrec d

-- A bit vector "has an undefined" when it is an XException thunk or any
-- bit of its mask is set.
instance NFDataX (BitVector n) where
  hasUndefined bv = isLeft (isX bv) || unsafeMask bv /= 0
  rnfX            = rwhnfX
  deepErrorX      = errorX
-- | Typed Template Haskell splice for a 'BitVector' literal. The string is
-- parsed by 'read#' at compile time, and the splice rebuilds the value with
-- 'fromInteger#' from the parsed mask and value.
bLit :: forall n. KnownNat n => String -> Q (TExp (BitVector n))
bLit s = [|| fromInteger# msk val ||]
  where
    parsed :: BitVector n
    parsed = read# s

    msk, val :: Integer
    msk = unsafeMask parsed
    val = unsafeToInteger parsed
-- | Parse a bit-literal string into a 'BitVector'. Characters are read most
-- significant first: @0@ and @1@ are defined bits, @.@ is an undefined bit,
-- and @_@ separators are skipped. Any other character raises an error.
--
-- NOTE(review): 'read#' itself does not reduce the result modulo @2^n@;
-- surplus high bits are kept as parsed.
read# :: KnownNat n => String -> BitVector n
read# cs = BV m v
  where
    (vs, ms) = unzip . map readBit . filter (/= '_') $ cs
    -- Interpret a list of 0/1 digits as a binary number. A strict left fold
    -- avoids building a chain of thunks for long literals.
    combineBits = L.foldl' (\acc b -> acc * 2 + b) 0
    v = combineBits vs
    m = combineBits ms
    -- Each character yields (value bit, mask bit).
    readBit c = case c of
      '0' -> (0, 0)
      '1' -> (1, 0)
      '.' -> (0, 1)
      _   -> error $ "Clash.Sized.Internal.bLit: unknown character: " ++ show c ++ " in input: " ++ cs
instance KnownNat n => Eq (BitVector n) where
  a == b = eq# a b
  a /= b = neq# a b

{-# NOINLINE eq# #-}
-- | Equality of two bit vectors; raises an error via 'undefErrorI' when
-- either operand has an undefined bit.
eq# :: KnownNat n => BitVector n -> BitVector n -> Bool
eq# (BV 0 x) (BV 0 y) = x == y
eq# a b = undefErrorI "==" a b

{-# NOINLINE neq# #-}
-- | Inequality of two bit vectors; raises an error via 'undefErrorI' when
-- either operand has an undefined bit.
neq# :: KnownNat n => BitVector n -> BitVector n -> Bool
neq# (BV 0 x) (BV 0 y) = x /= y
neq# a b = undefErrorI "/=" a b
instance KnownNat n => Ord (BitVector n) where
  a <  b = lt# a b
  a <= b = le# a b
  a >  b = gt# a b
  a >= b = ge# a b

-- | Unsigned comparisons of two bit vectors; each raises an error via
-- 'undefErrorI' when either operand has an undefined bit.
lt#, ge#, gt#, le# :: KnownNat n => BitVector n -> BitVector n -> Bool
{-# NOINLINE lt# #-}
lt# (BV 0 x) (BV 0 y) = x < y
lt# a b = undefErrorI "<" a b
{-# NOINLINE ge# #-}
ge# (BV 0 x) (BV 0 y) = x >= y
ge# a b = undefErrorI ">=" a b
{-# NOINLINE gt# #-}
gt# (BV 0 x) (BV 0 y) = x > y
gt# a b = undefErrorI ">" a b
{-# NOINLINE le# #-}
le# (BV 0 x) (BV 0 y) = x <= y
le# a b = undefErrorI "<=" a b
-- 'succ' and 'pred' wrap around modulo @2^n@ (they use '+#' / '-#').
instance KnownNat n => Enum (BitVector n) where
  succ bv        = bv +# fromInteger# 0 1
  pred bv        = bv -# fromInteger# 0 1
  toEnum i       = fromInteger# 0 (toInteger i)
  fromEnum bv    = fromEnum (toInteger# bv)
  enumFrom       = enumFrom#
  enumFromThen   = enumFromThen#
  enumFromTo     = enumFromTo#
  enumFromThenTo = enumFromThenTo#
{-# NOINLINE enumFrom# #-}
{-# NOINLINE enumFromThen# #-}
{-# NOINLINE enumFromTo# #-}
{-# NOINLINE enumFromThenTo# #-}
-- | @[x ..]@: every value from @x@ up to 'maxBound'.
enumFrom# :: forall n. KnownNat n => BitVector n -> [BitVector n]
-- | @[x, y ..]@: steps of @y - x@. An ascending (or constant) sequence runs
-- up to 'maxBound'; a descending one runs down to 'minBound', as the Haskell
-- Report specifies for 'Bounded' 'Enum' types.
enumFromThen# :: forall n. KnownNat n => BitVector n -> BitVector n -> [BitVector n]
-- | @[x .. y]@: every value from @x@ up to @y@.
enumFromTo# :: KnownNat n => BitVector n -> BitVector n -> [BitVector n]
-- | @[x1, x2 .. y]@: steps of @x2 - x1@, bounded by @y@.
enumFromThenTo# :: KnownNat n => BitVector n -> BitVector n -> BitVector n -> [BitVector n]

-- Each function raises an error via the undefError* helpers when an
-- argument has an undefined bit.
enumFrom# (BV 0 x) = map (fromInteger_INLINE 0) [x .. unsafeToInteger (maxBound :: BitVector n)]
enumFrom# bv
  = undefErrorU "enumFrom" bv

enumFromThen# (BV 0 x) (BV 0 y)
  -- Previously this always enumerated toward maxBound, so a descending
  -- sequence (y < x) produced [] instead of stopping at minBound.
  | x <= y    = map (fromInteger_INLINE 0) [x, y .. unsafeToInteger (maxBound :: BitVector n)]
  | otherwise = map (fromInteger_INLINE 0) [x, y .. unsafeToInteger (minBound :: BitVector n)]
enumFromThen# bv1 bv2
  = undefErrorP "enumFromThen" bv1 bv2

enumFromTo# (BV 0 x) (BV 0 y) = map (BV 0) [x .. y]
enumFromTo# bv1 bv2
  = undefErrorP "enumFromTo" bv1 bv2

enumFromThenTo# (BV 0 x1) (BV 0 x2) (BV 0 y) = map (BV 0) [x1, x2 .. y]
enumFromThenTo# bv1 bv2 bv3
  = undefErrorP3 "enumFromThenTo" bv1 bv2 bv3
instance KnownNat n => Bounded (BitVector n) where
  maxBound = maxBound#
  minBound = minBound#

{-# NOINLINE minBound# #-}
-- | The all-zeros bit vector, with every bit defined.
minBound# :: BitVector n
minBound# = BV { unsafeMask = 0, unsafeToInteger = 0 }

{-# NOINLINE maxBound# #-}
-- | The all-ones bit vector of width @n@, with every bit defined.
maxBound# :: forall n . KnownNat n => BitVector n
maxBound# = BV 0 (modulus - 1)
  where
    modulus = 1 `shiftL` fromInteger (natVal (Proxy @n))
-- Arithmetic on 'BitVector' wraps around modulo @2^n@; 'signum' is 1 when
-- any bit is set and 0 otherwise.
instance KnownNat n => Num (BitVector n) where
  a + b         = a +# b
  a - b         = a -# b
  a * b         = a *# b
  negate        = negate#
  abs bv        = bv
  signum bv     = resizeBV (pack# (reduceOr# bv))
  fromInteger i = fromInteger# 0 i
-- | Wrapping addition, subtraction and multiplication modulo @2^n@. Each
-- raises an error via 'undefErrorI' when an operand has an undefined bit.
(+#), (-#), (*#) :: forall n . KnownNat n => BitVector n -> BitVector n -> BitVector n

{-# NOINLINE (+#) #-}
(BV 0 x) +# (BV 0 y)
  | total >= modulus = BV 0 (total - modulus)
  | otherwise        = BV 0 total
  where
    modulus = 1 `shiftL` fromInteger (natVal (Proxy @n))
    total   = x + y
bv1 +# bv2 = undefErrorI "+" bv1 bv2

{-# NOINLINE (-#) #-}
(BV 0 x) -# (BV 0 y)
  | diff < 0  = BV 0 (modulus + diff)
  | otherwise = BV 0 diff
  where
    modulus = 1 `shiftL` fromInteger (natVal (Proxy @n))
    diff    = x - y
bv1 -# bv2 = undefErrorI "-" bv1 bv2

{-# NOINLINE (*#) #-}
(BV 0 x) *# (BV 0 y) = fromInteger_INLINE 0 (x * y)
bv1 *# bv2 = undefErrorI "*" bv1 bv2
{-# NOINLINE negate# #-}
-- | Two's-complement negation modulo @2^n@; raises an error via
-- 'undefErrorU' when the operand has an undefined bit.
negate# :: forall n . KnownNat n => BitVector n -> BitVector n
negate# (BV 0 x)
  | x == 0    = BV 0 0
  | otherwise = BV 0 (modulus - x)
  where
    modulus = 1 `shiftL` fromInteger (natVal (Proxy @n))
negate# bv = undefErrorU "negate" bv
{-# NOINLINE fromInteger# #-}
-- | Build a 'BitVector' from a mask and a value, both reduced modulo @2^n@.
fromInteger# :: KnownNat n => Integer -> Integer -> BitVector n
fromInteger# msk val = fromInteger_INLINE msk val

{-# INLINE fromInteger_INLINE #-}
-- | Inlinable worker behind 'fromInteger#': reduces the mask and the value
-- modulo @2^n@ so that only the low @n@ bits remain.
fromInteger_INLINE :: forall n . KnownNat n => Integer -> Integer -> BitVector n
fromInteger_INLINE msk val =
  modulus `seq` BV (msk `mod` modulus) (val `mod` modulus)
  where
    modulus = 1 `shiftL` fromInteger (natVal (Proxy @n))
-- Width-extending arithmetic: add/sub grow to one bit more than the wider
-- operand, and mul grows to the sum of both widths.
instance (KnownNat m, KnownNat n) => ExtendingNum (BitVector m) (BitVector n) where
  type AResult (BitVector m) (BitVector n) = BitVector (Max m n + 1)
  type MResult (BitVector m) (BitVector n) = BitVector (m + n)
  add = plus#
  sub = minus#
  mul = times#
{-# NOINLINE plus# #-}
-- | Addition into a result one bit wider than the wider operand; raises an
-- error via 'undefErrorP' on undefined bits.
plus# :: (KnownNat m, KnownNat n) => BitVector m -> BitVector n -> BitVector (Max m n + 1)
plus# (BV 0 x) (BV 0 y) = BV 0 (x + y)
plus# bv1 bv2 = undefErrorP "plus" bv1 bv2

{-# NOINLINE minus# #-}
-- | Subtraction into a result one bit wider than the wider operand; a
-- negative difference wraps modulo @2^(max m n + 1)@. Raises an error via
-- 'undefErrorP' on undefined bits.
minus# :: forall m n . (KnownNat m, KnownNat n) => BitVector m -> BitVector n
       -> BitVector (Max m n + 1)
minus# (BV 0 x) (BV 0 y)
  | diff < 0  = BV 0 (modulus + diff)
  | otherwise = BV 0 diff
  where
    width   = fromInteger (natVal (Proxy @(Max m n + 1)))
    modulus = 1 `shiftL` width
    diff    = x - y
minus# bv1 bv2 = undefErrorP "minus" bv1 bv2

{-# NOINLINE times# #-}
-- | Multiplication into a result of width @m + n@; raises an error via
-- 'undefErrorP' on undefined bits.
times# :: (KnownNat m, KnownNat n) => BitVector m -> BitVector n -> BitVector (m + n)
times# (BV 0 x) (BV 0 y) = BV 0 (x * y)
times# bv1 bv2 = undefErrorP "times" bv1 bv2
instance KnownNat n => Real (BitVector n) where
  toRational bv = toRational (toInteger# bv)

-- Unsigned division: 'div'/'mod' coincide with 'quot'/'rem'.
instance KnownNat n => Integral (BitVector n) where
  quot        = quot#
  div         = quot#
  rem         = rem#
  mod         = rem#
  quotRem x y = (quot# x y, rem# x y)
  divMod x y  = (quot# x y, rem# x y)
  toInteger   = toInteger#
-- | Unsigned quotient and remainder; each raises an error via 'undefErrorP'
-- when an operand has an undefined bit.
quot#, rem# :: KnownNat n => BitVector n -> BitVector n -> BitVector n
{-# NOINLINE quot# #-}
quot# (BV 0 x) (BV 0 y) = BV 0 (quot x y)
quot# bv1 bv2 = undefErrorP "quot" bv1 bv2
{-# NOINLINE rem# #-}
rem# (BV 0 x) (BV 0 y) = BV 0 (rem x y)
rem# bv1 bv2 = undefErrorP "rem" bv1 bv2

{-# NOINLINE toInteger# #-}
-- | The unsigned value of a fully defined bit vector; raises an error via
-- 'undefErrorU' otherwise.
toInteger# :: KnownNat n => BitVector n -> Integer
toInteger# (BV 0 x) = x
toInteger# bv = undefErrorU "toInteger" bv
-- Single-bit accessors go through 'index#' / 'replaceBit#', so they carry
-- the undefinedness of individual bits.
instance KnownNat n => Bits (BitVector n) where
  (.&.)              = and#
  (.|.)              = or#
  xor                = xor#
  complement         = complement#
  zeroBits           = 0
  bit i              = replaceBit# 0 i high
  setBit bv i        = replaceBit# bv i high
  clearBit bv i      = replaceBit# bv i low
  complementBit bv i = replaceBit# bv i (complement## (index# bv i))
  testBit bv i       = index# bv i `eq##` high
  bitSizeMaybe bv    = Just (size# bv)
  bitSize            = size#
  isSigned _         = False
  shiftL             = shiftL#
  shiftR             = shiftR#
  rotateL            = rotateL#
  rotateR            = rotateR#
  -- Widen by one bit so the result type of 'popCountBV' fits all counts.
  popCount bv        = fromInteger (I.toInteger# (popCountBV (bv ++# (0 :: BitVector 1))))

instance KnownNat n => FiniteBits (BitVector n) where
  finiteBitSize          = size#
  countLeadingZeros bv   = fromInteger (I.toInteger# (countLeadingZerosBV bv))
  countTrailingZeros bv  = fromInteger (I.toInteger# (countTrailingZerosBV bv))
-- | Number of consecutive 'low' bits at the head of 'V.bv2v' (presumably the
-- most significant end; confirm against 'V.bv2v'). Undefined bits make
-- 'eq##' raise an error.
countLeadingZerosBV :: KnownNat n => BitVector n -> I.Index (n+1)
countLeadingZerosBV = V.foldr step 0 . V.bv2v
  where
    step b acc
      | eq## b low = 1 + acc
      | otherwise  = 0
{-# INLINE countLeadingZerosBV #-}

-- | Number of consecutive 'low' bits at the tail of 'V.bv2v' (presumably the
-- least significant end; confirm against 'V.bv2v').
countTrailingZerosBV :: KnownNat n => BitVector n -> I.Index (n+1)
countTrailingZerosBV = V.foldl step 0 . V.bv2v
  where
    step acc b
      | eq## b low = 1 + acc
      | otherwise  = 0
{-# INLINE countTrailingZerosBV #-}
{-# NOINLINE reduceAnd# #-}
-- | AND of all bits. A fully defined vector is compared against the all-ones
-- value; otherwise the individual bits are folded with '.&.'.
reduceAnd# :: KnownNat n => BitVector n -> Bit
reduceAnd# bv@(BV 0 x) = Bit 0 (smallInteger (dataToTag# allSet))
  where
    allSet = x == 2 ^ natVal bv - 1
reduceAnd# bv = V.foldl (.&.) 1 (V.bv2v bv)

{-# NOINLINE reduceOr# #-}
-- | OR of all bits. A fully defined vector is compared against zero;
-- otherwise the individual bits are folded with '.|.'.
reduceOr# :: KnownNat n => BitVector n -> Bit
reduceOr# (BV 0 x) = Bit 0 (smallInteger (dataToTag# anySet))
  where
    anySet = x /= 0
reduceOr# bv = V.foldl (.|.) 0 (V.bv2v bv)

{-# NOINLINE reduceXor# #-}
-- | XOR of all bits (the parity of the set bits); raises an error via
-- 'undefErrorU' when any bit is undefined.
reduceXor# :: KnownNat n => BitVector n -> Bit
reduceXor# (BV 0 x) = Bit 0 (toInteger (popCount x `mod` 2))
reduceXor# bv = undefErrorU "reduceXor" bv

instance Default (BitVector n) where
  def = minBound#
{-# NOINLINE size# #-}
-- | The width @n@ of the bit vector.
size# :: KnownNat n => BitVector n -> Int
size# bv = fromInteger (natVal bv)

{-# NOINLINE maxIndex# #-}
-- | The highest valid bit index, @n - 1@.
maxIndex# :: KnownNat n => BitVector n -> Int
maxIndex# bv = subtract 1 (fromInteger (natVal bv))
{-# NOINLINE index# #-}
-- | The bit at index @i@ (index 0 is the least significant bit), keeping
-- both its mask and its value. Raises an error when @i@ is outside
-- @[0 .. n-1]@.
index# :: KnownNat n => BitVector n -> Int -> Bit
index# bv@(BV m v) i
  | i < 0 || i >= width = error outOfRange
  | otherwise           = Bit (bitAt m) (bitAt v)
  where
    width = fromInteger (natVal bv)
    bitAt x = smallInteger (dataToTag# (testBit x i))
    outOfRange = concat
      [ "(!): "
      , show i
      , " is out of range ["
      , show (width - 1)
      , "..0]"
      ]
{-# NOINLINE msb# #-}
-- | The most significant bit (index @n - 1@), keeping its mask and value.
msb# :: forall n . KnownNat n => BitVector n -> Bit
msb# (BV m v) = Bit (bitAt m) (bitAt v)
  where
    topIndex = fromInteger (natVal (Proxy @n) - 1)
    bitAt x = smallInteger (dataToTag# (testBit x topIndex))

{-# NOINLINE lsb# #-}
-- | The least significant bit (index 0), keeping its mask and value.
lsb# :: BitVector n -> Bit
lsb# (BV m v) = Bit (bitAt m) (bitAt v)
  where
    bitAt x = smallInteger (dataToTag# (testBit x 0))
{-# NOINLINE slice# #-}
-- | Bits @m@ down to @n@ (both inclusive), shifted down to index 0; the mask
-- is sliced the same way.
slice# :: BitVector (m + 1 + i) -> SNat m -> SNat n -> BitVector (m + 1 - n)
slice# (BV msk i) m n = BV (select msk) (select i)
  where
    select x = (x .&. keepMask) `shiftR` lowIndex
    lowIndex = snatToNum n
    keepMask = 2 ^ (snatToInteger m + 1) - 1

{-# NOINLINE (++#) #-}
-- | Concatenation: the first argument becomes the more significant bits.
(++#) :: KnownNat m => BitVector n -> BitVector m -> BitVector (n + m)
(BV m1 v1) ++# bv2@(BV m2 v2) =
  BV ((m1 `shiftL` width2) .|. m2) ((v1 `shiftL` width2) .|. v2)
  where
    width2 = fromInteger (natVal bv2)
{-# NOINLINE replaceBit# #-}
-- | Overwrite bit @i@ with the given 'Bit' (mask and value). Raises an error
-- when @i@ is outside @[0 .. n-1]@.
replaceBit# :: KnownNat n => BitVector n -> Int -> Bit -> BitVector n
replaceBit# bv@(BV m v) i (Bit mb b)
  | i < 0 || i >= sz = error outOfRange
  | otherwise        = BV newMask newVal
  where
    sz = fromInteger (natVal bv)
    newMask = clearBit m i .|. (mb `shiftL` i)
    -- The value bit is only set for a defined 1; an undefined bit stores 0.
    newVal
      | testBit b 0 && mb == 0 = setBit v i
      | otherwise              = clearBit v i
    outOfRange = concat
      [ "replaceBit: "
      , show i
      , " is out of range ["
      , show (sz - 1)
      , "..0]"
      ]
{-# NOINLINE setSlice# #-}
-- | Overwrite bits @m@ down to @n@ (inclusive) with the given vector,
-- updating the mask in the same way.
setSlice#
  :: BitVector (m + 1 + i)
  -> SNat m
  -> SNat n
  -> BitVector (m + 1 - n)
  -> BitVector (m + 1 + i)
setSlice# (BV iMask i) m n (BV jMask j) =
  BV (keep iMask .|. place jMask) (keep i .|. place j)
  where
    hiIdx = snatToInteger m
    loIdx = snatToInteger n
    -- Clear the target range [m..n] in the original.
    keep x = x .&. holeMask
    -- Move the replacement up to start at bit n.
    place x = x `shiftL` fromInteger loIdx
    holeMask = complement ((2 ^ (hiIdx + 1) - 1) `xor` (2 ^ loIdx - 1))
{-# NOINLINE split# #-}
-- | Split into the upper @m@ bits and the lower @n@ bits, splitting the
-- mask the same way.
split#
  :: forall n m
   . KnownNat n
  => BitVector (m + n)
  -> (BitVector m, BitVector n)
split# (BV m i) =
  ( BV (m `shiftR` lowWidth) (i `shiftR` lowWidth)
  , BV (m `mod` modulus) (i `mod` modulus)
  )
  where
    lowWidth = fromInteger (natVal (Proxy @n))
    modulus  = 1 `shiftL` lowWidth
-- | Bitwise AND, OR and XOR with undefined-bit propagation. A result bit is
-- undefined when the mask expression below is set for it; the value is
-- cleared at undefined positions.
and#, or#, xor# :: BitVector n -> BitVector n -> BitVector n
{-# NOINLINE and# #-}
-- AND: undefined if both inputs are undefined, or one is undefined and the
-- other is a 1 (a defined 0 forces the result to 0).
and# (BV m1 v1) (BV m2 v2) = BV undef ((v1 .&. v2) .&. complement undef)
  where
    undef = (m1 .&. v2) .|. (m1 .&. m2) .|. (m2 .&. v1)
{-# NOINLINE or# #-}
-- OR: undefined if both inputs are undefined, or one is undefined and the
-- other is not 1 (a defined 1 forces the result to 1).
or# (BV m1 v1) (BV m2 v2) = BV undef ((v1 .|. v2) .&. complement undef)
  where
    undef = (m1 .&. complement v2) .|. (m1 .&. m2) .|. (m2 .&. complement v1)
{-# NOINLINE xor# #-}
-- XOR: undefined whenever either input is undefined.
xor# (BV m1 v1) (BV m2 v2) = BV undef ((v1 `xor` v2) .&. complement undef)
  where
    undef = m1 .|. m2

{-# NOINLINE complement# #-}
-- | Bitwise complement; undefined bits stay undefined (their value is
-- cleared) and the result is reduced to @n@ bits.
complement# :: KnownNat n => BitVector n -> BitVector n
complement# (BV m v) = fromInteger_INLINE m (complement v .&. complement m)
-- | Shifts and rotations by a non-negative amount; the mask moves together
-- with the value. Negative amounts raise an error.
shiftL#, shiftR#, rotateL#, rotateR#
  :: KnownNat n => BitVector n -> Int -> BitVector n

{-# NOINLINE shiftL# #-}
-- Bits shifted past position n-1 are discarded by 'fromInteger_INLINE'.
shiftL# (BV m v) i
  | i >= 0    = fromInteger_INLINE (m `shiftL` i) (v `shiftL` i)
  | otherwise = error ("'shiftL undefined for negative number: " ++ show i)

{-# NOINLINE shiftR# #-}
shiftR# (BV m v) i
  | i >= 0    = BV (m `shiftR` i) (v `shiftR` i)
  | otherwise = error ("'shiftR undefined for negative number: " ++ show i)
{-# NOINLINE rotateL# #-}
-- Rotate left by b positions (taken modulo the width n); the mask rotates
-- together with the value. Negative amounts raise an error.
rotateL# _ b | b < 0 = error "'rotateL undefined for negative numbers"
rotateL# bv@(BV m v) b
  -- A 0-bit vector has a single value, so rotating it is the identity.
  -- Handling it here avoids the division by zero in b `mod` sz.
  | sz == 0   = bv
  | otherwise = fromInteger_INLINE (ml .|. mr) (vl .|. vr)
  where
    vl = shiftL v b'
    vr = shiftR v b''
    ml = shiftL m b'
    mr = shiftR m b''
    b' = b `mod` sz
    b'' = sz - b'
    sz = fromInteger (natVal bv)
{-# NOINLINE rotateR# #-}
-- Rotate right by b positions (taken modulo the width n); the mask rotates
-- together with the value. Negative amounts raise an error.
rotateR# _ b | b < 0 = error "'rotateR undefined for negative numbers"
rotateR# bv@(BV m v) b
  -- A 0-bit vector has a single value, so rotating it is the identity.
  -- Handling it here avoids the division by zero in b `mod` sz.
  | sz == 0   = bv
  | otherwise = fromInteger_INLINE (ml .|. mr) (vl .|. vr)
  where
    vl = shiftR v b'
    vr = shiftL v b''
    ml = shiftR m b'
    mr = shiftL m b''
    b' = b `mod` sz
    b'' = sz - b'
    sz = fromInteger (natVal bv)
-- | Number of set bits, computed by summing each bit (as a 'BitVector' 1)
-- into an 'I.Index' wide enough to hold @n + 1@.
popCountBV :: forall n . KnownNat n => BitVector (n+1) -> I.Index (n+2)
popCountBV bv = sum (V.map (\b -> fromIntegral (pack# b)) (V.bv2v bv))
{-# INLINE popCountBV #-}
-- Growing pads on the most significant side: 'zeroExtend' (and 'resize')
-- with zeros, 'signExtend' with copies of the most significant bit.
-- Shrinking keeps the least significant bits.
instance Resize BitVector where
  resize = resizeBV
  zeroExtend bv = 0 ++# bv
  signExtend bv = ext ++# bv
    where
      ext | msb# bv == low = 0
          | otherwise      = complement 0
  truncateB = truncateB#

-- | Resize to width @m@: zero-extend when @n <= m@, otherwise truncate to
-- the low @m@ bits.
resizeBV :: forall n m . (KnownNat n, KnownNat m) => BitVector n -> BitVector m
resizeBV bv = case compareSNat @n @m (SNat @n) (SNat @m) of
  SNatLE -> (++#) @n @(m-n) 0 bv
  SNatGT -> truncateB# @m @(n - m) bv
{-# INLINE resizeBV #-}
-- | Keep only the low @a@ bits; both the mask and the value are reduced
-- modulo @2^a@.
truncateB# :: forall a b . KnownNat a => BitVector (a + b) -> BitVector a
truncateB# (BV msk val) = fromInteger_INLINE msk val
{-# NOINLINE truncateB# #-}

-- Lifts to a 'fromInteger#' call annotated with the concrete width.
instance KnownNat n => Lift (BitVector n) where
  lift bv@(BV m i) = [| fromInteger# m i |] `sigE` decBitVector (natVal bv)
  {-# NOINLINE lift #-}

-- | The Template Haskell type @BitVector n@ for a literal width @n@.
decBitVector :: Integer -> TypeQ
decBitVector n = conT ''BitVector `appT` litT (numTyLit n)
-- Saturating arithmetic. 'SatWrap' wraps modulo @2^n@. On overflow,
-- 'SatZero' yields 0 and the remaining modes yield 'maxBound' for add/mul.
-- Subtraction that underflows yields 0 in every non-wrapping mode.
instance KnownNat n => SaturatingNum (BitVector n) where
  satAdd mode a b = case mode of
      SatWrap -> a +# b
      SatZero -> onOverflow minBound#
      _       -> onOverflow maxBound#
    where
      -- The extra carry bit of the widened sum signals overflow.
      onOverflow bound =
        let r = plus# a b
        in  if msb# r == low then truncateB# r else bound

  satSub mode a b = case mode of
      SatWrap -> a -# b
      _       ->
        -- The top bit of the widened difference signals underflow.
        let r = minus# a b
        in  if msb# r == low then truncateB# r else minBound#

  satMul mode a b = case mode of
      SatWrap -> a *# b
      SatZero -> onOverflow minBound#
      _       -> onOverflow maxBound#
    where
      -- Any set bit in the upper half of the full product is an overflow.
      onOverflow bound =
        let (upper, lower) = split# (times# a b)
        in  if upper == 0 then lower else bound
instance KnownNat n => Arbitrary (BitVector n) where
  arbitrary = arbitraryBoundedIntegral
  shrink    = shrinkSizedUnsigned

-- | Shrinking for unsigned sized types. Types narrower than two bits shrink
-- 1 to 0 and nothing else; wider types use 'shrinkIntegral'.
shrinkSizedUnsigned :: (KnownNat n, Integral (p n)) => p n -> [p n]
shrinkSizedUnsigned x
  | natVal x >= 2    = shrinkIntegral x
  | toInteger x == 1 = [0]
  | otherwise        = []
{-# INLINE shrinkSizedUnsigned #-}

instance KnownNat n => CoArbitrary (BitVector n) where
  coarbitrary = coarbitraryIntegral

-- Lens indexing: an 'Int' bit index selects a 'Bit'.
type instance Index (BitVector n)   = Int
type instance IxValue (BitVector n) = Bit
instance KnownNat n => Ixed (BitVector n) where
  ix i f bv = fmap (replaceBit# bv i) (f (index# bv i))
-- | Error for a binary infix operator applied to (partially) undefined
-- operands; the message shows the expression @bv1 op bv2@.
undefErrorI :: (HasCallStack, KnownNat m, KnownNat n) => String -> BitVector m -> BitVector n -> a
undefErrorI op bv1 bv2 = withFrozenCallStack $ errorX $ concat
  [ "Clash.Sized.BitVector.", op
  , " called with (partially) undefined arguments: "
  , show bv1, " ", op, " ", show bv2
  ]

-- | Error for a binary prefix function applied to (partially) undefined
-- arguments.
undefErrorP :: (HasCallStack, KnownNat m, KnownNat n) => String -> BitVector m -> BitVector n -> a
undefErrorP op bv1 bv2 = withFrozenCallStack $ errorX $ concat
  [ "Clash.Sized.BitVector.", op
  , " called with (partially) undefined arguments: "
  , show bv1, " ", show bv2
  ]

-- | Error for a ternary prefix function applied to (partially) undefined
-- arguments.
undefErrorP3 :: (HasCallStack, KnownNat m, KnownNat n, KnownNat o) => String -> BitVector m -> BitVector n -> BitVector o -> a
undefErrorP3 op bv1 bv2 bv3 = withFrozenCallStack $ errorX $ concat
  [ "Clash.Sized.BitVector.", op
  , " called with (partially) undefined arguments: "
  , show bv1, " ", show bv2, " ", show bv3
  ]

-- | Error for a unary function applied to a (partially) undefined argument.
undefErrorU :: (HasCallStack, KnownNat n) => String -> BitVector n -> a
undefErrorU op bv1 = withFrozenCallStack $ errorX $ concat
  [ "Clash.Sized.BitVector.", op
  , " called with (partially) undefined argument: "
  , show bv1
  ]

-- | Error for an operation (named verbatim, without a module prefix)
-- applied to a list of (partially) undefined arguments.
undefError :: (HasCallStack, KnownNat n) => String -> [BitVector n] -> a
undefError op bvs = withFrozenCallStack $ errorX $ concat
  [ op
  , " called with (partially) undefined arguments: "
  , unwords (L.map show bvs)
  ]
-- | Apply an unpacking function only to a fully defined bit vector;
-- otherwise raise an error naming the target type's @unpack@.
checkUnpackUndef :: (KnownNat n, Typeable a)
                 => (BitVector n -> a)
                 -> BitVector n -> a
checkUnpackUndef f bv@(BV 0 _) = f bv
checkUnpackUndef _ bv = res
  where
    -- 'typeOf' only inspects the type of 'res', so this self-reference is
    -- not evaluated.
    res = undefError (show (typeOf res) ++ ".unpack") [bv]
{-# NOINLINE checkUnpackUndef #-}

-- | A bit vector whose every bit is undefined (all mask bits set, value 0).
undefined# :: forall n . KnownNat n => BitVector n
undefined# = BV (modulus - 1) 0
  where
    modulus = 1 `shiftL` fromInteger (natVal (Proxy @n))
{-# NOINLINE undefined# #-}

-- | Whether the first vector matches the second as a pattern. Undefined bits
-- of the second are don't-cares; at every other position the first must
-- hold the same defined bit.
isLike :: BitVector n -> BitVector n -> Bool
isLike (BV cMask c) (BV eMask e) = expected == lowFill && expected == highFill
  where
    care     = complement eMask
    expected = e .&. care
    -- Filling c's undefined bits with 0 and with 1 must both give expected.
    lowFill  = c .&. complement cMask .&. care
    highFill = (c .|. cMask) .&. care
{-# NOINLINE isLike #-}
-- | Interpret a list of bits, most significant first, as a non-negative
-- 'Integer'. A strict left fold keeps the accumulator evaluated instead of
-- building a chain of thunks.
fromBits :: [Bit] -> Integer
fromBits = L.foldl' (\acc b -> acc `shiftL` 1 .|. fromIntegral b) 0
-- | Template Haskell pattern that matches a value against a bit-pattern
-- string, most significant bit first: @0@ and @1@ must match exactly and @.@
-- is a don't-care. Underscores are ignored, so long patterns can be grouped
-- (e.g. @"1010_..01"@) just like the literals accepted by 'bLit'. Any other
-- character raises an error at splice time.
--
-- The generated pattern masks the scrutinee with '.&.' and compares the
-- result against an integer literal, so the scrutinee's type needs 'Bits'
-- and integer literals.
bitPattern :: String -> Q Pat
bitPattern s = [p| (($mask .&.) -> $target) |]
  where
    bs = parse <$> filter (/= '_') s
    -- Ones at every position that must match.
    mask = litE . IntegerL . fromBits $ maybe 0 (const 1) <$> bs
    -- Expected values at those positions; don't-cares contribute 0.
    target = litP . IntegerL . fromBits $ fromMaybe 0 <$> bs
    parse '.' = Nothing
    parse '0' = Just 0
    parse '1' = Just 1
    parse c = error $ "Invalid bit pattern: " ++ show c