{-|
Copyright  :  (C) 2013-2016, University of Twente,
                  2016-2017, Myrtle Software Ltd
License    :  BSD2 (see the file LICENSE)
Maintainer :  Christiaan Baaij <christiaan.baaij@gmail.com>
-}

{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

{-# LANGUAGE Trustworthy #-}

{-# OPTIONS_GHC -fplugin=GHC.TypeLits.Extra.Solver #-}
{-# OPTIONS_GHC -fplugin=GHC.TypeLits.Normalise #-}
{-# OPTIONS_GHC -fplugin=GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_HADDOCK show-extensions #-}

#include "MachDeps.h"

module Clash.Class.BitPack
  ( BitPack (..)
  , bitCoerce
  , bitCoerceMap
  , boolToBV
  , boolToBit
  , bitToBool
  , packXWith

  -- * Internals
  , GBitPack(..)
  )
where

import Control.Exception              (catch, evaluate)
import Data.Binary.IEEE754            (doubleToWord, floatToWord, wordToDouble,
                                       wordToFloat)

#if MIN_VERSION_base(4,12,0)
import Data.Complex                   (Complex)
import Data.Ord                       (Down)
#endif

import Data.Int
import Data.Word
import Foreign.C.Types                (CUShort)
import GHC.TypeLits                   (KnownNat, Nat, type (+), type (-))
import Numeric.Half                   (Half (..))
import GHC.Generics
import GHC.TypeLits.Extra             (CLog, Max)
import Prelude                        hiding (map)
import System.IO.Unsafe               (unsafeDupablePerformIO)

import Clash.Promoted.Nat             (SNat(..), snatToNum)
import Clash.Class.BitPack.Internal   (deriveBitPackTuples)
import Clash.Class.Resize             (zeroExtend, resize)
import Clash.Sized.BitVector          (Bit, BitVector, (++#))
import Clash.Sized.Internal.BitVector
  (pack#, split#, checkUnpackUndef, undefined#, unpack#, unsafeToNatural)
import Clash.XException

{- $setup
>>> :set -XDataKinds
>>> import Clash.Prelude
-}

-- | Convert to and from a 'BitVector'
class KnownNat (BitSize a) => BitPack a where
  -- | Number of 'Clash.Sized.BitVector.Bit's needed to represent elements
  -- of type @a@
  --
  -- Can be derived using `GHC.Generics`:
  --
  -- > import Clash.Prelude
  -- > import GHC.Generics
  -- >
  -- > data MyProductType = MyProductType { a :: Int, b :: Bool }
  -- >   deriving (Generic, BitPack)
  --
  -- The default is the number of bits needed to number the constructors
  -- (@CLog 2@ of the constructor count) plus the generic field size.
  type BitSize a :: Nat
  type BitSize a = (CLog 2 (GConstructorCount (Rep a))) + (GFieldSize (Rep a))
  -- | Convert element of type @a@ to a 'BitVector'
  --
  -- >>> pack (-5 :: Signed 6)
  -- 11_1011
  pack   :: a -> BitVector (BitSize a)
  default pack
    :: ( Generic a
       , GBitPack (Rep a)
       , KnownNat (BitSize a)
       , KnownNat constrSize
       , KnownNat fieldSize
       , constrSize ~ CLog 2 (GConstructorCount (Rep a))
       , fieldSize ~ GFieldSize (Rep a)
       , (constrSize + fieldSize) ~ BitSize a
       )
    => a -> BitVector (BitSize a)
  pack = packXWith go
   where
    -- The constructor number is packed as an 'Int', resized to the width of
    -- the constructor bits, and placed in front of the packed fields.
    go a = resize (pack sc) ++# packedFields
     where
      -- Constructor numbering starts at 0.
      (sc, packedFields) = gPackFields 0 (from a)

  -- | Convert a 'BitVector' to an element of type @a@
  --
  -- >>> pack (-5 :: Signed 6)
  -- 11_1011
  -- >>> let x = pack (-5 :: Signed 6)
  -- >>> unpack x :: Unsigned 6
  -- 59
  -- >>> pack (59 :: Unsigned 6)
  -- 11_1011
  unpack :: BitVector (BitSize a) -> a
  default unpack
    :: ( Generic a
       , GBitPack (Rep a)
       , KnownNat constrSize
       , KnownNat fieldSize
       , constrSize ~ CLog 2 (GConstructorCount (Rep a))
       , fieldSize ~ GFieldSize (Rep a)
       , (constrSize + fieldSize) ~ BitSize a
       )
    => BitVector (BitSize a) -> a
  unpack b =
    to (gUnpack sc 0 bFields)
   where
    -- Split the input into constructor bits and field bits. The constructor
    -- bits are resized to the width of an 'Int' and unpacked (through
    -- 'checkUnpackUndef') to obtain the constructor number @sc@.
    (checkUnpackUndef unpack . resize -> sc, bFields) = split# b

-- | Apply a packing function, turning an undefined argument into 'undefined#'.
--
-- The argument @x@ is first forced with 'evaluate' (weak head normal form).
-- If that raises an 'XException', the result is 'undefined#'; otherwise the
-- result is @f x@. Exceptions of any other type are not intercepted here.
packXWith
  :: KnownNat n
  => (a -> BitVector n)
  -> a
  -> BitVector n
packXWith f x =
  -- The IO action only forces a pure value and intercepts 'XException'; it
  -- performs no other effects, which is why running it through
  -- 'unsafeDupablePerformIO' is acceptable here.
  unsafeDupablePerformIO (catch (f <$> evaluate x)
                                (\(XException _) -> return undefined#))
{-# NOINLINE packXWith #-}

{-# INLINE[1] bitCoerce #-}
-- | Coerce a value from one type to another through its bit representation.
--
-- The source is 'pack'ed and the resulting bits are 'unpack'ed at the target
-- type; both types must have the same 'BitSize'.
--
-- >>> pack (-5 :: Signed 6)
-- 11_1011
-- >>> bitCoerce (-5 :: Signed 6) :: Unsigned 6
-- 59
-- >>> pack (59 :: Unsigned 6)
-- 11_1011
bitCoerce
  :: (BitPack a, BitPack b, BitSize a ~ BitSize b)
  => a
  -> b
bitCoerce = unpack . pack

-- | Map a value by first coercing to another type through its bit representation.
--
-- The @b@ value is coerced to @a@ with 'bitCoerce', transformed by the given
-- function, and coerced back to @b@.
--
-- >>> pack (-5 :: Signed 32)
-- 1111_1111_1111_1111_1111_1111_1111_1011
-- >>> bitCoerceMap @(Vec 4 (BitVector 8)) (replace 1 0) (-5 :: Signed 32)
-- -16711685
-- >>> pack (-16711685 :: Signed 32)
-- 1111_1111_0000_0000_1111_1111_1111_1011
bitCoerceMap
  :: forall a b . (BitPack a, BitPack b, BitSize a ~ BitSize b)
  => (a -> a)
  -> b
  -> b
bitCoerceMap f = bitCoerce . f . bitCoerce

-- | A 'Bool' occupies one bit: 'True' packs to @1@ and 'False' to @0@.
instance BitPack Bool where
  type BitSize Bool = 1
  pack   = let go b = if b then 1 else 0 in packXWith go
  -- Any bit vector equal to 1 is 'True', everything else is 'False'.
  unpack = checkUnpackUndef (== 1)

-- | A 'BitVector' is its own bit representation. 'pack' still goes through
-- 'packXWith', whereas 'unpack' is the identity.
instance KnownNat n => BitPack (BitVector n) where
  type BitSize (BitVector n) = n
  pack     = packXWith id
  unpack v = v

-- | A 'Bit' occupies one bit; conversion is delegated to 'pack#' and
-- 'unpack#'.
instance BitPack Bit where
  type BitSize Bit = 1
  pack   = packXWith pack#
  unpack = unpack#

-- | An 'Int' is as wide as @WORD_SIZE_IN_BITS@ (taken from @MachDeps.h@), so
-- its size follows the platform the code is compiled for. Conversion in both
-- directions is 'fromIntegral'.
instance BitPack Int where
  type BitSize Int = WORD_SIZE_IN_BITS
  pack   = packXWith fromIntegral
  unpack = checkUnpackUndef fromIntegral

-- | An 'Int8' occupies 8 bits; conversion in both directions is
-- 'fromIntegral'.
instance BitPack Int8 where
  type BitSize Int8 = 8
  pack   = packXWith fromIntegral
  unpack = checkUnpackUndef fromIntegral

-- | An 'Int16' occupies 16 bits; conversion in both directions is
-- 'fromIntegral'.
instance BitPack Int16 where
  type BitSize Int16 = 16
  pack   = packXWith fromIntegral
  unpack = checkUnpackUndef fromIntegral

-- | An 'Int32' occupies 32 bits; conversion in both directions is
-- 'fromIntegral'.
instance BitPack Int32 where
  type BitSize Int32 = 32
  pack   = packXWith fromIntegral
  unpack = checkUnpackUndef fromIntegral

-- | An 'Int64' occupies 64 bits; conversion in both directions is
-- 'fromIntegral'.
instance BitPack Int64 where
  type BitSize Int64 = 64
  pack   = packXWith fromIntegral
  unpack = checkUnpackUndef fromIntegral

-- | A 'Word' is as wide as @WORD_SIZE_IN_BITS@ (taken from @MachDeps.h@), so
-- its size follows the platform the code is compiled for. Conversion in both
-- directions is 'fromIntegral'.
instance BitPack Word where
  type BitSize Word = WORD_SIZE_IN_BITS
  pack   = packXWith fromIntegral
  unpack = checkUnpackUndef fromIntegral

-- | A 'Word8' occupies 8 bits; conversion in both directions is
-- 'fromIntegral'.
instance BitPack Word8 where
  type BitSize Word8 = 8
  pack   = packXWith fromIntegral
  unpack = checkUnpackUndef fromIntegral

-- | A 'Word16' occupies 16 bits; conversion in both directions is
-- 'fromIntegral'.
instance BitPack Word16 where
  type BitSize Word16 = 16
  pack   = packXWith fromIntegral
  unpack = checkUnpackUndef fromIntegral

-- | A 'Word32' occupies 32 bits; conversion in both directions is
-- 'fromIntegral'.
instance BitPack Word32 where
  type BitSize Word32 = 32
  pack   = packXWith fromIntegral
  unpack = checkUnpackUndef fromIntegral

-- | A 'Word64' occupies 64 bits; conversion in both directions is
-- 'fromIntegral'.
instance BitPack Word64 where
  type BitSize Word64 = 64
  pack   = packXWith fromIntegral
  unpack = checkUnpackUndef fromIntegral

-- | A 'Float' occupies 32 bits; the conversions are delegated to
-- 'packFloat#' and 'unpackFloat#'.
instance BitPack Float where
  type BitSize Float = 32
  pack   = packXWith packFloat#
  unpack = checkUnpackUndef unpackFloat#

-- | Turn a 'Float' into a 32-bit 'BitVector' by converting it to a 'Word32'
-- with 'floatToWord' and then to a bit vector with 'fromIntegral'.
--
-- NOTE(review): presumably kept NOINLINE so the Clash compiler can treat it
-- as a primitive - confirm.
packFloat# :: Float -> BitVector 32
packFloat# = fromIntegral . floatToWord
{-# NOINLINE packFloat# #-}

-- | Turn a 32-bit 'BitVector' into a 'Float': the vector is viewed as a
-- natural number through 'unsafeToNatural', narrowed to a 'Word32' with
-- 'fromIntegral', and handed to 'wordToFloat'.
--
-- NOTE(review): presumably kept NOINLINE so the Clash compiler can treat it
-- as a primitive - confirm.
unpackFloat# :: BitVector 32 -> Float
unpackFloat# (unsafeToNatural -> w) = wordToFloat (fromIntegral w)
{-# NOINLINE unpackFloat# #-}

-- | A 'Double' occupies 64 bits; the conversions are delegated to
-- 'packDouble#' and 'unpackDouble#'.
instance BitPack Double where
  type BitSize Double = 64
  pack   = packXWith packDouble#
  unpack = checkUnpackUndef unpackDouble#

-- | Turn a 'Double' into a 64-bit 'BitVector' by converting it to a 'Word64'
-- with 'doubleToWord' and then to a bit vector with 'fromIntegral'.
--
-- NOTE(review): presumably kept NOINLINE so the Clash compiler can treat it
-- as a primitive - confirm.
packDouble# :: Double -> BitVector 64
packDouble# = fromIntegral . doubleToWord
{-# NOINLINE packDouble# #-}

-- | Turn a 64-bit 'BitVector' into a 'Double': the vector is viewed as a
-- natural number through 'unsafeToNatural', narrowed to a 'Word64' with
-- 'fromIntegral', and handed to 'wordToDouble'.
--
-- NOTE(review): presumably kept NOINLINE so the Clash compiler can treat it
-- as a primitive - confirm.
unpackDouble# :: BitVector 64 -> Double
unpackDouble# (unsafeToNatural -> w) = wordToDouble (fromIntegral w)
{-# NOINLINE unpackDouble# #-}

-- | A 'CUShort' is packed as 16 bits; conversion in both directions is
-- 'fromIntegral'.
instance BitPack CUShort where
  type BitSize CUShort = 16
  pack   = packXWith fromIntegral
  unpack = checkUnpackUndef fromIntegral

-- | A 'Half' is packed as the 16 bits of the 'CUShort' it wraps.
instance BitPack Half where
  type BitSize Half = 16
  -- Unwrap the constructor and reuse the 'CUShort' instance.
  pack (Half x) = pack x
  -- Unpack as a 'CUShort' and rewrap.
  unpack        = checkUnpackUndef $ \x -> Half (unpack x)

-- | The unit type carries no information and therefore occupies zero bits.
-- Neither method inspects its argument.
instance BitPack () where
  type BitSize () = 0
  pack   _ = minBound
  unpack _ = ()

-- | __N.B.__: The documentation only shows instances up to /3/-tuples. By
-- default, instances up to and including /12/-tuples will exist. If the flag
-- @large-tuples@ is set instances up to the GHC imposed limit will exist. The
-- GHC imposed limit is either 62 or 64 depending on the GHC version.
instance (BitPack a, BitPack b) =>
    BitPack (a,b) where
  type BitSize (a,b) = BitSize a + BitSize b
  -- The first component's bits come first, followed by the second's.
  pack = let go (a,b) = pack a ++# pack b in packXWith go
  -- Split at the component boundary and unpack each half.
  unpack ab  = let (a,b) = split# ab in (unpack a, unpack b)

-- | Generic worker class used by the default 'pack' and 'unpack' methods of
-- 'BitPack'. It operates on the "GHC.Generics" representation of a type and
-- tracks constructor numbers as plain 'Int's.
class GBitPack f where
  -- | Size of fields. If multiple constructors exist, this is the maximum of
  -- the sum of each of the constructor's fields.
  type GFieldSize f :: Nat

  -- | Number of constructors this type has. Indirectly indicates how many bits
  -- are needed to represent the constructor.
  type GConstructorCount f :: Nat

  -- | Pack fields of a type. Caller should pack and prepend the constructor bits.
  gPackFields
    :: Int
    -- ^ Current constructor
    -> f a
    -- ^ Data to pack
    -> (Int, BitVector (GFieldSize f))
    -- ^ (Constructor number, Packed fields)

  -- | Unpack whole type.
  gUnpack
    :: Int
    -- ^ Construct with constructor /n/
    -> Int
    -- ^ Current constructor
    -> BitVector (GFieldSize f)
    -- ^ BitVector containing fields
    -> f a
    -- ^ Unpacked result

-- | Metadata wrappers carry no bits of their own: sizes, packing and
-- unpacking are all forwarded to the wrapped representation.
instance GBitPack a => GBitPack (M1 m d a) where
  type GFieldSize (M1 m d a) = GFieldSize a
  type GConstructorCount (M1 m d a) = GConstructorCount a

  gPackFields cc (M1 m1) = gPackFields cc m1
  gUnpack c cc b = M1 (gUnpack c cc b)

-- | Sum of two representations. The field size is the larger of the two
-- branches and the constructor count is the sum of both branches.
instance ( KnownNat (GFieldSize g)
         , KnownNat (GFieldSize f)
         , KnownNat (GConstructorCount f)
         , GBitPack f
         , GBitPack g
         ) => GBitPack (f :+: g) where
  type GFieldSize (f :+: g) = Max (GFieldSize f) (GFieldSize g)
  type GConstructorCount (f :+: g) = GConstructorCount f + GConstructorCount g

  -- Left branch: constructor numbering continues from @cc@. The packed fields
  -- are followed by 'undefined#' padding up to the common field size.
  gPackFields cc (L1 l) =
    let (sc, packed) = gPackFields cc l in
    let padding = undefined# :: BitVector (Max (GFieldSize f) (GFieldSize g) - GFieldSize f) in
    (sc, packed ++# padding)
  -- Right branch: constructor numbers are offset by the number of constructors
  -- in the left branch; padding is added in the same way.
  gPackFields cc (R1 r) =
    let cLeft = snatToNum (SNat @(GConstructorCount f)) in
    let (sc, packed) = gPackFields (cc + cLeft) r in
    let padding = undefined# :: BitVector (Max (GFieldSize f) (GFieldSize g) - GFieldSize g) in
    (sc, packed ++# padding)

  -- The wanted constructor @c@ lies in the left branch when it is smaller than
  -- the current constructor number plus the left branch's constructor count.
  gUnpack c cc b =
    let cLeft = snatToNum (SNat @(GConstructorCount f)) in
    if c < cc + cLeft then
      L1 (gUnpack c cc f)
    else
      R1 (gUnpack c (cc + cLeft) g)

   where
    -- It's a thing of beauty, if I may say so myself!
    -- Both bindings split the same input: once at the left branch's field
    -- size and once at the right branch's, discarding the padding each time.
    (f, _ :: BitVector (Max (GFieldSize f) (GFieldSize g) - GFieldSize f)) = split# b
    (g, _ :: BitVector (Max (GFieldSize f) (GFieldSize g) - GFieldSize g)) = split# b


-- | Product of two representations. The field size is the sum of both sides
-- and a product always counts as a single constructor.
instance (KnownNat (GFieldSize g), KnownNat (GFieldSize f), GBitPack f, GBitPack g) => GBitPack (f :*: g) where
  type GFieldSize (f :*: g) = GFieldSize f + GFieldSize g
  type GConstructorCount (f :*: g) = 1

  -- The constructor number is passed through unchanged; the fields of both
  -- sides are packed and concatenated, left side first.
  gPackFields cc fg =
    (cc, packXWith go fg)
   where
    go (l0 :*: r0) =
      let (_, l1) = gPackFields cc l0 in
      let (_, r1) = gPackFields cc r0 in
      l1 ++# r1

  -- Split the field bits at the boundary between both sides and unpack each
  -- side with the same constructor arguments.
  gUnpack c cc b =
    gUnpack c cc front :*: gUnpack c cc back
   where
    (front, back) = split# b

-- | A single field: its size is the 'BitSize' of the field type and the
-- conversions are that type's own 'pack' and 'unpack'.
instance BitPack c => GBitPack (K1 i c) where
  type GFieldSize (K1 i c) = BitSize c
  type GConstructorCount (K1 i c)  = 1

  gPackFields cc (K1 i) = (cc, pack i)
  -- The constructor arguments are irrelevant for a single field.
  gUnpack _c _cc b      = K1 (unpack b)

-- | A constructor without fields: zero field bits, one constructor.
instance GBitPack U1 where
  type GFieldSize U1 = 0
  type GConstructorCount U1 = 1

  gPackFields cc U1 = (cc, 0)
  -- Nothing to decode; all arguments are ignored.
  gUnpack _c _cc _b = U1

-- Instances derived using Generic
--
-- None of the instances below has a body, so each one uses the class
-- defaults of 'BitPack' ('BitSize', 'pack' and 'unpack') that are defined in
-- terms of 'GBitPack' on the type's generic representation.
instance ( BitPack a
         , BitPack b
         ) => BitPack (Either a b)

instance BitPack a => BitPack (Maybe a)

-- Only provided for base >= 4.12, matching the guarded imports of 'Complex'
-- and 'Down' at the top of this module.
#if MIN_VERSION_base(4,12,0)
instance BitPack a => BitPack (Complex a)
instance BitPack a => BitPack (Down a)
#endif

-- | Zero-extend a 'Bool'ean value to a 'BitVector' of the appropriate size.
--
-- The 'Bool' is first 'pack'ed to a single bit and then widened with
-- 'zeroExtend'.
--
-- >>> boolToBV True :: BitVector 6
-- 00_0001
-- >>> boolToBV False :: BitVector 6
-- 00_0000
boolToBV :: KnownNat n => Bool -> BitVector (n + 1)
boolToBV = zeroExtend . pack

-- | Convert a 'Bool' to a 'Bit' through their shared one-bit representation
-- (see 'bitCoerce').
boolToBit :: Bool -> Bit
boolToBit = bitCoerce

-- | Convert a 'Bit' to a 'Bool' through their shared one-bit representation
-- (see 'bitCoerce').
bitToBool :: Bit -> Bool
bitToBool = bitCoerce

-- Derive the BitPack instances for tuples of size 3 and up with Template
-- Haskell; the splice is handed the names of the class, its associated type
-- and both methods.
-- NOTE(review): the upper bound is decided inside 'deriveBitPackTuples', which
-- is not visible here. This comment used to say "3 to 62", while the note on
-- the 2-tuple instance mentions 12 by default - confirm which is current.
deriveBitPackTuples ''BitPack ''BitSize 'pack 'unpack