{-|
Copyright  :  (C) 2021-2024, QBayLogic B.V.
License    :  BSD2 (see the file LICENSE)
Maintainer :  QBayLogic B.V. <devops@qbaylogic.com>
-}

{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ViewPatterns #-}

module Clash.Explicit.BlockRam.Internal where

import Data.Bits ((.&.), (.|.), shiftL, xor)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as L
import Data.ByteString.Builder (Builder, toLazyByteString, word8, word64BE)
import qualified Data.ByteString.Unsafe as B
#if !MIN_VERSION_base(4,20,0)
import Data.Foldable (foldl')
#endif
import Data.Word (Word64)
import GHC.Exts (Addr#)
import GHC.TypeLits (KnownNat, Nat)
import Numeric.Natural (Natural)
import System.IO.Unsafe (unsafePerformIO)

import Clash.Class.BitPack.Internal (BitPack, BitSize, pack)
import Clash.Promoted.Nat (natToNum)
import Clash.Sized.Internal.BitVector (Bit(..), BitVector(..))

-- | Efficient storage of memory content
--
-- A @'MemBlob' n m@ holds @n@ words of @'BitVector' m@. The contents are not
-- stored as Haskell values but as two raw byte buffers, each given by an
-- address and a length: the \"runs\" buffer holds the whole bytes of every
-- word, the \"ends\" buffer holds the leftover bits (@m `mod` 8@ per word)
-- packed into 64-bit words.
data MemBlob (n :: Nat) (m :: Nat) where
  MemBlob
    :: (KnownNat n, KnownNat m)
    => { memBlobRunsLen :: !Int
         -- ^ Number of bytes in the runs buffer
       , memBlobRuns :: Addr#
         -- ^ Start address of the runs buffer
       , memBlobEndsLen :: !Int
         -- ^ Number of bytes in the ends buffer
       , memBlobEnds :: Addr#
         -- ^ Start address of the ends buffer
       }
    -> MemBlob n m

-- | Shows a blob as the Template Haskell splice @memBlobTH@ that would
-- rebuild it from its decoded word list. The precedence argument is ignored
-- because a @$(...)@ splice is already self-delimiting.
instance Show (MemBlob n m) where
  showsPrec _ blob@MemBlob{} =
      showString "$(memBlobTH @(BitVector "
    . shows (natToNum @m @Int)
    . showString ") Nothing "
    . shows (unpackMemBlob blob)
    . showChar ')'

-- | Convert a 'MemBlob' back to a list
--
-- __NB__: Not synthesizable
--
-- This is the pure face of 'unpackMemBlob0'.
-- NOTE(review): the use of 'unsafePerformIO' assumes the bytes behind the
-- blob's addresses are never mutated, so every evaluation decodes the same
-- list -- confirm at the sites that construct 'MemBlob' values.
unpackMemBlob
  :: forall n m
   . MemBlob n m
  -> [BitVector m]
unpackMemBlob blob = unsafePerformIO (unpackMemBlob0 blob)

-- | Decode a 'MemBlob' into its @n@ words of @m@ bits, in 'IO'.
--
-- Both raw buffers are wrapped as strict byte strings with
-- 'B.unsafePackAddressLen' and then decoded by 'unpackNats'. Every resulting
-- 'BitVector' has an all-zero mask, i.e. all of its bits are defined.
unpackMemBlob0
  :: forall n m
   . MemBlob n m
  -> IO [BitVector m]
unpackMemBlob0 MemBlob { memBlobRunsLen = runsLen
                       , memBlobRuns    = runsAddr
                       , memBlobEndsLen = endsLen
                       , memBlobEnds    = endsAddr
                       } = do
  runsBytes <- B.unsafePackAddressLen runsLen runsAddr
  endsBytes <- B.unsafePackAddressLen endsLen endsAddr
  let nats = unpackNats (natToNum @n) (natToNum @m) runsBytes endsBytes
  pure (fmap (BV 0) nats)

-- | Pack a collection of 'BitPack'able values into the run/end byte streams
-- produced by 'packAsNats', together with the number of elements.
--
-- When @care@ is a defined bit (@Just (Bit 0 b)@), every undefined bit in the
-- packed values is replaced by @b@. Otherwise every value must be fully
-- defined. If any value contains an undefined bit, the result is 'Left' with
-- an explanatory message.
packBVs
  :: forall a f
   . ( Foldable f
     , BitPack a
     )
  => Maybe Bit
  -> f a
  -> Either String (Int, L.ByteString, L.ByteString)
packBVs care es =
  case lenOrErr of
    Nothing  -> Left err
    Just len -> let (runs, ends) = packAsNats mI (knownBVVal . pack) es
                in Right (len, runs, ends)
 where
  -- Number of elements, or 'Nothing' when an undefined bit would have to be
  -- emitted but no replacement bit was given.
  lenOrErr :: Maybe Int
  lenOrErr = case care of
    Just (Bit 0 _) -> Just (length es)
    _              -> foldl' lenOrErr0 (Just 0) es

  -- Count one more fully defined element. foldl' only forces the 'Maybe'
  -- constructor to WHNF, so the count itself must be forced with '$!' to
  -- avoid building one @(+ 1)@ thunk per element.
  lenOrErr0 :: Maybe Int -> a -> Maybe Int
  lenOrErr0 (Just len) (pack -> BV 0 _) = Just $! len + 1
  lenOrErr0 _          _                = Nothing

  -- The value of a packed element with undefined bits resolved per @care@.
  knownBVVal :: BitVector (BitSize a) -> Natural
  knownBVVal bv@(BV _ val) = case care of
    Just (Bit 0 bm) -> maskBVVal bm bv
    _               -> val

  -- Force the bits flagged in the mask to the replacement bit: clear them
  -- when it is 0, set them otherwise.
  maskBVVal :: Word -> BitVector (BitSize a) -> Natural
  maskBVVal _ (BV 0    val) = val
  maskBVVal 0 (BV mask val) = val .&. (mask `xor` fullMask)
  maskBVVal _ (BV mask val) = val .|. mask

  -- Bit width of one element.
  mI :: Int
  mI = natToNum @(BitSize a) @Int

  -- All ones over the element width.
  fullMask :: Natural
  fullMask = (1 `shiftL` mI) - 1

  err :: String
  err = "packBVs: cannot convert don't care values. " ++
        "Please specify a mapping to a definite value."

-- | Serialise naturals of @width@ bits each into two lazy byte strings.
--
-- * The first (\"runs\") holds the @width `div` 8@ low-order whole bytes of
--   every element, most significant byte first, elements in order.
-- * The second (\"ends\") holds the remaining @width `mod` 8@ high-order bits
--   of every element, packed into big-endian 64-bit words. Within a word the
--   earliest element occupies the lowest bits. Only the first word may be
--   partially filled; it holds the first elements.
--
-- @trans@ maps each element to its natural value.
-- NOTE(review): values are assumed to fit in @width@ bits; excess high bits
-- would leak into neighbouring entries of an end word.
packAsNats
  :: forall a f
   . Foldable f
  => Int
  -> (a -> Natural)
  -> f a
  -> (L.ByteString, L.ByteString)
packAsNats width trans es =
  (toLazyByteString runBytes, toLazyByteString endBytes)
 where
  -- Whole bytes per element, and leftover bits per element.
  (runL, endL) = width `divMod` 8

  -- Right fold: elements are visited last-to-first, so prepending to the
  -- builders keeps element order.
  (runBytes, fullEndWords, pendingCount, pendingWord) =
    foldr step (mempty, mempty, 0, 0) es

  -- The still-open word holds the first elements, so it goes in front.
  endBytes
    | pendingCount > 0 = word64BE pendingWord <> fullEndWords
    | otherwise        = fullEndWords

  step :: a -> (Builder, Builder, Int, Word64)
       -> (Builder, Builder, Int, Word64)
  step x (runAcc, endAcc, count, word) =
    let (highBits, runAcc') = splitBytes runL (trans x) runAcc
        (endAcc', count', word') = pushEnd highBits endAcc count word
    in (runAcc', endAcc', count', word')

  -- Peel @k@ low-order bytes off @v@, prepending them (least significant
  -- last) to the builder; returns what is left of @v@.
  splitBytes :: Int -> Natural -> Builder -> (Natural, Builder)
  splitBytes k v acc
    | k == 0    = (v, acc)
    | otherwise =
        let (q, r) = v `divMod` 256
        in splitBytes (k - 1) q (word8 (fromIntegral r) <> acc)

  -- Shift one element's leftover bits into the open word, emitting the word
  -- first when it has no room for another @endL@ bits.
  pushEnd :: Natural -> Builder -> Int -> Word64 -> (Builder, Int, Word64)
  pushEnd highBits endAcc count word
    | endL == 0             = (endAcc, count, word)
    | count + endL <= 64    = (endAcc, count + endL, word * 2 ^ endL + bits)
    | otherwise             = (word64BE word <> endAcc, endL, bits)
   where
    bits :: Word64
    bits = fromIntegral highBits

-- | Inverse of 'packAsNats': decode @len@ naturals of @width@ bits each from
-- the run bytes and the end bytes.
--
-- For @width < 8@ all bits live in the end bytes and the result is truncated
-- to @len@ elements. This matters for @width == 0@, where the ends are the
-- infinite @repeat 0@ and the result is @len@ zeros.
-- For wider elements decoding stops when the run bytes are exhausted.
unpackNats
  :: Int
  -> Int
  -> B.ByteString
  -> B.ByteString
  -> [Natural]
unpackNats 0 _ _ _ = []
unpackNats len width runBs endBs
  | width < 8 = take len ends
  | otherwise = case ends of
      (e0:es) -> go e0 runL runBs es
      _ -> error ("unpackNats failed for:" <> show (len,width,runBs,endBs))
 where
  (runL, endL) = width `divMod` 8

  -- High-order leftover bits of every element, or zeros when the width is a
  -- whole number of bytes.
  ends :: [Natural]
  ends = if endL == 0 then
           repeat 0
         else
           unpackEnds endL len $ unpackW64s endBs

  -- go val runC bs ends: shift @runC@ more run bytes into @val@ (which starts
  -- as the element's end bits); emit it once all of its bytes are in.
  -- The next element's end bits are taken lazily so the final element can be
  -- emitted without requiring a further end value.
  go :: Natural -> Int -> B.ByteString -> [Natural] -> [Natural]
  go val 0 runBs0 ends0
    = let (end0, end0rest) = case ends0 of
            [] -> error "unpackNats: unexpected end of bytestring"
            (x:xs) -> (x, xs)
       in val : go end0 runL runBs0 end0rest
  go _ _ runBs0 _ | B.null runBs0 = []
  go val runC runBs0 ends0
    = let (runB, runBs1) = case B.uncons runBs0 of
            Nothing -> error "unpackNats: unexpected end of bytestring"
            Just xs -> xs
          val0 = val * 256 + fromIntegral runB
      in go val0 (runC - 1) runBs1 ends0

-- | Split a strict byte string into big-endian 64-bit words.
--
-- Fails with 'error' when the input ends in the middle of a word.
unpackW64s
  :: B.ByteString
  -> [Word64]
unpackW64s = go 8 0
 where
  -- go n acc bs: read @n@ more bytes into @acc@ before emitting it.
  go :: Int -> Word64 -> B.ByteString -> [Word64]
  go 8 _   endBs | B.null endBs = []
  go 0 val endBs = val : go 8 0 endBs
  go n val endBs = case B.uncons endBs of
    Nothing             -> error "unpackW64s: unexpected end of bytestring"
    Just (endB, endBs0) -> go (n - 1) (val * 256 + fromIntegral endB) endBs0

-- | Unpack @len@ values of @endL@ bits each from 64-bit words.
--
-- Every word carries @64 `div` endL@ values, lowest bits first. The first
-- word may be partially filled: it carries @len `mod` perWord@ values, or a
-- full word's worth when that remainder is 0.
unpackEnds
  :: Int
  -> Int
  -> [Word64]
  -> [Natural]
unpackEnds endL len words0 = case words0 of
  []                       -> []
  (firstWord : laterWords) -> drain firstCount firstWord laterWords
 where
  perWord :: Int
  perWord = 64 `div` endL

  firstCount :: Int
  firstCount = case len `mod` perWord of
    0 -> perWord
    k -> k

  radix :: Word64
  radix = 2 ^ endL

  -- drain k w ws: emit @k@ values from the low end of @w@, then move on to
  -- the next (full) word.
  drain :: Int -> Word64 -> [Word64] -> [Natural]
  drain 0 _ []       = []
  drain 0 _ (w : ws) = drain perWord w ws
  drain k w ws =
    let (hi, lo) = w `divMod` radix
    in fromIntegral lo : drain (k - 1) hi ws