{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
module Data.BitVector.Sized.BitLayout
(
Chunk(..)
, chunk
, BitLayout
, empty, singleChunk, (<:)
, inject
, extract
, layoutLens, layoutsLens
) where
import Data.BitVector.Sized
import Data.Foldable
import qualified Data.Functor.Product as P
import Control.Lens (lens, Simple, Lens)
import Data.Parameterized
import Data.Parameterized.List
import qualified Data.Sequence as S
import Data.Sequence (Seq)
import GHC.TypeLits
import Text.PrettyPrint.HughesPJClass (Pretty(..), text)
-- | A contiguous range of bits within a larger bit vector.
--
-- The width @w@ is tracked both at the type level and at runtime (via the
-- 'NatRepr'). The 'Int' is the index of the chunk's least significant bit,
-- so the chunk covers bits @[start + w - 1 .. start]@. The constructor does
-- not validate the index; 'chunkFits' rejects negative starts when the chunk
-- is added to a layout.
data Chunk (w :: Nat) :: * where
  Chunk :: NatRepr w -- width of the chunk, in bits
        -> Int       -- index of the chunk's least significant bit
        -> Chunk w
-- | Make a chunk whose width comes from the type-level 'KnownNat' @w@,
-- starting at the given bit index of the larger vector.
chunk :: KnownNat w => Int -> Chunk w
chunk start = Chunk knownNat start
-- Structural 'Show' (constructor, width repr, start index), available at
-- every width @w@.
deriving instance Show (Chunk w)
-- 'ShowF' simply reuses the derived 'Show' instance, which exists for every
-- chunk width.
instance ShowF Chunk where
  showF chk = show chk
-- | Render a chunk as the range of bits it covers, highest index first,
-- e.g. @[7...4]@. A chunk whose width is not positive is shown as just its
-- start index, e.g. @[4]@.
instance Pretty (Chunk w) where
  pPrint (Chunk wRepr start) = text (bracketed bitRange)
    where
      width = fromIntegral (natValue wRepr)
      bitRange
        | width > 0 = show (start + width - 1) ++ "..." ++ show start
        | otherwise = show start
      bracketed body = "[" ++ body ++ "]"
-- | Pretty-print an existentially wrapped chunk exactly like the underlying
-- 'Chunk'. Delegating to the @Pretty (Chunk w)@ instance keeps the two
-- renderings from drifting apart instead of duplicating the formatting code.
instance Pretty (Some Chunk) where
  pPrint (Some chk) = pPrint chk
-- | Describes how the bits of a small vector (a field of width @s@) are
-- scattered across a larger vector of width @t@.
--
-- The chunks are kept in the order they were added with '<:'. When
-- extracting or injecting, the first chunk in the sequence corresponds to the
-- least significant bits of the field. The constructor is not exported, so
-- layouts are only built via 'empty' and '<:', which keep @s@ equal to the sum
-- of the chunk widths.
data BitLayout (t :: Nat) (s :: Nat) :: * where
  BitLayout :: NatRepr t          -- width of the larger (target) vector
            -> NatRepr s          -- width of the field
            -> Seq (Some Chunk)   -- chunks, field's low bits first
            -> BitLayout t s
-- | Render a layout as the list of its chunks, starting with the chunk that
-- holds the most significant bits of the field (the reverse of the stored
-- order). The list is formatted by HughesPJClass's list 'Pretty' instance,
-- then rendered to a 'String' and re-wrapped with 'text'.
instance Pretty (BitLayout t s) where
  pPrint (BitLayout _ _ chks) = text . show . pPrint . reverse . toList $ chks
-- Derived 'Show'. Showing the chunk sequence requires a @Show (Some Chunk)@
-- instance, presumably the one parameterized-utils provides from the
-- 'ShowF Chunk' instance in this module.
deriving instance Show (BitLayout t s)
-- | The layout containing no chunks, so its field width is 0. Build larger
-- layouts from it with '<:'.
empty :: KnownNat t => BitLayout t 0
empty = BitLayout knownNat knownNat mempty
-- | A layout consisting of one chunk, starting at the given bit index of the
-- larger vector, whose width is the whole field width @w'@.
-- Like '<:', this calls 'error' if the chunk does not fit.
singleChunk :: (KnownNat w, KnownNat w') => Int -> BitLayout w w'
singleChunk = (<: empty) . chunk
-- | Extend a layout with one more chunk.
--
-- The new chunk is appended after the existing ones, so it receives the next
-- @r@ bits of the field, above those already placed. Because '<:' is
-- right-associative, in @a <: b <: empty@ the chunk @b@ holds the field's low
-- bits and @a@ its high bits.
--
-- Calls 'error' when the chunk does not fit (see 'chunkFits').
(<:) :: Chunk r
     -> BitLayout t s
     -> BitLayout t (r + s)
newChunk@(Chunk rRepr _) <: layout@(BitLayout tRepr sRepr chunks)
  | newChunk `chunkFits` layout =
      BitLayout tRepr (rRepr `addNat` sRepr) (chunks S.|> Some newChunk)
  | otherwise = error $
      "chunk " ++ show newChunk ++ " does not fit in layout of size " ++
      show (natValue tRepr) ++ ": " ++ show layout
infixr 6 <:
-- | Decide whether a chunk may be added to a layout. All of these must hold:
--
--   * the chunk's start index is non-negative;
--   * the layout's field width plus the chunk's width does not exceed @t@;
--   * the chunk lies entirely inside the target, i.e. @start + r <= t@;
--   * the chunk overlaps no chunk already in the layout.
chunkFits :: Chunk r -> BitLayout t s -> Bool
chunkFits chk@(Chunk rRepr start) (BitLayout tRepr sRepr chunks) =
  -- The sign is checked first so a negative start is rejected before any
  -- width arithmetic. Widths are compared as 'Integer' ('toInteger') so this
  -- works whether 'natValue' yields 'Integer' or 'Natural'. With 'Natural',
  -- converting a negative start via 'fromIntegral' would throw.
  (0 <= start) &&
  (rWidth + sWidth <= tWidth) &&
  (toInteger start + rWidth <= tWidth) &&
  noOverlaps chk (toList chunks)
  where
    rWidth = toInteger (natValue rRepr)
    sWidth = toInteger (natValue sRepr)
    tWidth = toInteger (natValue tRepr)
-- | 'True' when the given chunk overlaps none of the listed chunks.
noOverlaps :: Chunk r -> [Some Chunk] -> Bool
noOverlaps newChunk existing =
  and [ chunksDontOverlap (Some newChunk) other | other <- existing ]
-- | 'True' when the bit ranges @[start, start + width)@ of two chunks are
-- disjoint: whichever chunk starts lower must end at or before the other's
-- start.
--
-- NOTE(review): when both chunks share a start index and exactly one of them
-- has width 0, the answer depends on argument order. It is 'True' if the
-- zero-width chunk comes first and 'False' otherwise.
chunksDontOverlap :: Some Chunk -> Some Chunk -> Bool
chunksDontOverlap (Some (Chunk repr1 lo1)) (Some (Chunk repr2 lo2))
  | lo1 <= lo2 = lo1 + w1 <= lo2
  | otherwise  = lo2 + w2 <= lo1
  where
    w1 = fromIntegral (natValue repr1)
    w2 = fromIntegral (natValue repr2)
-- | Zero-extend (or truncate) @sVec@ to the width of @tVec@, shift it left by
-- @start@ bits, and OR the result into @tVec@. Existing bits of @tVec@ are not
-- cleared.
bvOrAt :: Int
       -> BitVector s
       -> BitVector t
       -> BitVector t
bvOrAt start sVec tVec@(BV tRepr _) = bvOr positioned tVec
  where
    positioned = bvShift (bvZextWithRepr tRepr sVec) start
-- | Scatter the bits of @sVec@ into an all-zero vector of width @t@. The low
-- bits of @sVec@ go to the first chunk, the next bits to the second chunk, and
-- so on. Each chunk takes as many bits as its width.
bvOrAtAll :: NatRepr t
          -> [Some Chunk]
          -> BitVector s
          -> BitVector t
bvOrAtAll tRepr [] _ = BV tRepr 0
bvOrAtAll tRepr (Some (Chunk chunkRepr chunkStart) : rest) sVec =
  bvOrAt chunkStart lowBits (bvOrAtAll tRepr rest remaining)
  where
    width     = fromIntegral (natValue chunkRepr)
    -- bits destined for this chunk
    lowBits   = bvTruncBits sVec width
    -- drop them so the next chunk sees the following bits at position 0
    remaining = bvShift sVec (negate width)
-- | Write a field into a larger vector according to a layout. The bits the
-- layout covers are replaced, and every other bit of the larger vector is left
-- untouched.
--
-- The covered positions are cleared before the new bits are ORed in. Without
-- this, stale 1-bits would survive. Clearing makes this a proper setter for
-- 'layoutLens': @extract l (inject l t s) == s@.
inject :: BitLayout t s -- ^ the layout describing where each field bit goes
       -> BitVector t   -- ^ the larger vector to write into
       -> BitVector s   -- ^ the field value to write
       -> BitVector t
inject (BitLayout tRepr sRepr chunks) tVec sVec =
  bvOr placed (bvAnd tVec (bvComplement fieldMask))
  where
    chunkList = toList chunks
    -- The field bits moved into their target positions, zeros elsewhere.
    placed = bvOrAtAll tRepr chunkList sVec
    -- Scattering an all-ones field marks exactly the positions the layout
    -- owns, because the chunk widths sum to @s@ (maintained by 'empty'/'<:').
    fieldMask = bvOrAtAll tRepr chunkList (bvComplement (BV sRepr 0))
-- | Read one chunk's bits out of @tVec@, zero-extend them to width @s@, and
-- place them starting at bit @sStart@ of the result.
extractChunk :: NatRepr s
             -> Int
             -> Some Chunk
             -> BitVector t
             -> BitVector s
extractChunk sRepr sStart (Some (Chunk chunkRepr chunkStart)) tVec =
  let fieldBits = bvExtractWithRepr chunkRepr chunkStart tVec
      widened   = bvZextWithRepr sRepr fieldBits
  in bvShift widened sStart
-- | Gather the listed chunks of @tVec@ into a width-@s@ vector. The first
-- chunk's bits land at @outStart@, and each later chunk sits immediately above
-- the previous one.
extractAll :: NatRepr s
           -> Int
           -> [Some Chunk]
           -> BitVector t
           -> BitVector s
extractAll sRepr _ [] _ = BV sRepr 0
extractAll sRepr outStart (chk@(Some (Chunk chunkRepr _)) : chunks) tVec =
  extractChunk sRepr outStart chk tVec `bvOr`
  extractAll sRepr (outStart + chunkWidth) chunks tVec
  -- 'fromIntegral' (not 'fromInteger') matches the other width conversions
  -- in this module and works whether 'natValue' yields Integer or Natural.
  where chunkWidth = fromIntegral (natValue chunkRepr)
-- | Read a field out of a larger vector according to a layout. This is the
-- reading counterpart of 'inject'. The first chunk supplies the field's
-- least significant bits.
extract :: BitLayout t s -- ^ the layout describing where each field bit lives
        -> BitVector t   -- ^ the larger vector to read from
        -> BitVector s
extract (BitLayout _ sRepr chunks) tVec =
  extractAll sRepr 0 (toList chunks) tVec
-- | A lens onto the field described by a layout. Viewing reads the field with
-- 'extract', and setting writes it back with 'inject'.
layoutLens :: BitLayout t s -> Simple Lens (BitVector t) (BitVector s)
layoutLens layout = lens getField setField
  where
    getField = extract layout
    setField = inject layout
-- | A lens onto several fields of the same target vector at once, one field
-- per layout in the list.
--
-- Viewing extracts every field with 'extract'. Setting starts from the
-- original vector and writes each field back through its own layout with
-- 'inject', folding over the paired fields and layouts with 'ifoldr'.
--
-- The target width @t@ is arbitrary (it was previously fixed at 32). @ws@ stays
-- the first type variable so existing explicit type applications still work.
layoutsLens :: forall ws t . List (BitLayout t) ws -> Simple Lens (BitVector t) (List BitVector ws)
layoutsLens layouts = lens getFields setFields
  where
    getFields :: BitVector t -> List BitVector ws
    getFields bv = imap (\_ layout -> extract layout bv) layouts

    setFields :: BitVector t -> List BitVector ws -> BitVector t
    setFields bv flds =
      ifoldr (\_ (P.Pair fld layout) acc -> inject layout acc fld)
             bv
             (izipWith (\_ fld layout -> P.Pair fld layout) flds layouts)