-- | Module for rendering a sorted list of Bits from a list of TensorStacks
--
--   Functions ending in "R" are for sorting Records when used in a recursive
--   Tensort variant
--
--   TODO: See if we can clean up the type conversion here
module Data.Tensort.Utils.Render (getSortedBitsFromTensor) where

import Data.Maybe (isNothing)
import Data.Tensort.Utils.Compose (createTensor)
import Data.Tensort.Utils.Types
  ( Bit,
    BitR,
    Memory (..),
    MemoryR (..),
    SBit (..),
    SMemory (..),
    STensor (..),
    STensorStack,
    SortAlg,
    Sortable (..),
    Tensor,
    TensorR,
    TensorStack,
    TensorStackR,
    fromJust,
    fromSTensorBit,
    fromSTensorRec,
    fromSortBit,
    fromSortRec,
  )

-- | Compile a sorted list of Bits from a TensorStack
--
--   Dispatches on the constructor of the input: an 'STensorBit' is handed to
--   the Bit variant and an 'STensorRec' to the Record variant (used when
--   sorting Records in a recursive Tensort variant). The supplied SortAlg is
--   passed through unchanged
--
-- ==== __Examples__
-- >>> import Data.Tensort.Subalgorithms.Bubblesort (bubblesort)
-- >>> getSortedBitsFromTensor bubblesort (STensorBit ([(0,5),(1,7)],ByteMem [[1,5],[3,7]]))
-- [SBitBit 1,SBitBit 3,SBitBit 5,SBitBit 7]
-- >>> getSortedBitsFromTensor bubblesort (STensorBit ([(0,8),(1,18)],TensorMem [([(0,7),(1,8)],TensorMem [([(0,3),(1,7)],ByteMem [[1,3],[5,7]]),([(0,4),(1,8)],ByteMem [[2,4],[6,8]])]),([(1,17),(0,18)],TensorMem [([(0,13),(1,18)],ByteMem [[11,13],[15,18]]),([(0,14),(1,17)],ByteMem [[12,14],[16,17]])])]))
-- [SBitBit 1,SBitBit 2,SBitBit 3,SBitBit 4,SBitBit 5,SBitBit 6,SBitBit 7,SBitBit 8,SBitBit 11,SBitBit 12,SBitBit 13,SBitBit 14,SBitBit 15,SBitBit 16,SBitBit 17,SBitBit 18]
getSortedBitsFromTensor :: SortAlg -> STensorStack -> [SBit]
getSortedBitsFromTensor subAlg (STensorBit tensorRaw) =
  getSortedBitsFromTensorB subAlg tensorRaw
getSortedBitsFromTensor subAlg (STensorRec tensorRaw) =
  getSortedBitsFromTensorR subAlg tensorRaw

-- | Bit variant of 'getSortedBitsFromTensor'
--
--   Repeatedly removes the top Bit from the TensorStack with
--   'removeTopBitFromTensor', wraps it in 'SBitBit' and prepends it to the
--   accumulated list. Stops once no Tensor remains
getSortedBitsFromTensorB :: SortAlg -> TensorStack -> [SBit]
getSortedBitsFromTensorB subAlg tensorRaw = acc tensorRaw []
  where
    acc :: TensorStack -> [SBit] -> [SBit]
    acc tensor sortedBits =
      case removeTopBitFromTensor subAlg tensor of
        -- That was the last Bit: the accumulator is the final result
        (nextBit, Nothing) -> SBitBit nextBit : sortedBits
        -- Bits remain: keep going with the rebuilt Tensor. Each removed Bit
        -- is prepended, so Bits removed later end up earlier in the list
        (nextBit, Just tensor') -> acc tensor' (SBitBit nextBit : sortedBits)

-- | Record variant of 'getSortedBitsFromTensor'
--
--   Repeatedly removes the top Record from the TensorStackR with
--   'removeTopBitFromTensorR', wraps it in 'SBitRec' and prepends it to the
--   accumulated list. Stops once no Tensor remains
getSortedBitsFromTensorR :: SortAlg -> TensorStackR -> [SBit]
getSortedBitsFromTensorR subAlg tensorRaw = acc tensorRaw []
  where
    acc :: TensorStackR -> [SBit] -> [SBit]
    acc tensor sortedBits =
      case removeTopBitFromTensorR subAlg tensor of
        -- That was the last Record: the accumulator is the final result
        (nextBit, Nothing) -> SBitRec nextBit : sortedBits
        -- Records remain: keep going with the rebuilt Tensor. Each removed
        -- Record is prepended, so later removals end up earlier in the list
        (nextBit, Just tensor') -> acc tensor' (SBitRec nextBit : sortedBits)

-- | For use in compiling a list of Tensors into a sorted list of Bits
--
--   Removes the top Bit from a Tensor, rebalances the Tensor and returns
--   the removed Bit along with the rebalanced Tensor
--
--   The address of the top Bit is taken from the last Record in the
--   Register. If removing the Bit leaves nothing in Memory, 'Nothing' is
--   returned in place of a Tensor; otherwise a fresh Tensor is built from the
--   remaining Memory with 'createTensor'
--
--   Note: uses 'last' on the Register, so an empty Register is an error
--
-- ==== __Examples__
--   >>> import Data.Tensort.Subalgorithms.Bubblesort (bubblesort)
--   >>> removeTopBitFromTensor bubblesort ([(0,5),(1,7)],ByteMem [[1,5],[3,7]])
--   (7,Just ([(1,3),(0,5)],ByteMem [[1,5],[3]]))
removeTopBitFromTensor :: SortAlg -> Tensor -> (Bit, Maybe Tensor)
removeTopBitFromTensor subAlg (register, memory) =
  (topBit, rebuild <$> memory')
  where
    -- The first component of a Register entry is used as the index into
    -- Memory
    topAddress = fst (last register)
    (topBit, memory') = removeBitFromMemory subAlg memory topAddress
    -- Wrap the remaining Memory, build a new Tensor from it and unwrap the
    -- result back to a plain Tensor
    rebuild = fromSTensorBit . createTensor subAlg . SMemoryBit

-- | Record variant of 'removeTopBitFromTensor'
--
--   Removes the top Record from a TensorR and returns it along with a Tensor
--   rebuilt from the remaining Memory via 'createTensor', or 'Nothing' when
--   no Memory remains
--
--   Note: uses 'last' on the Register, so an empty Register is an error
removeTopBitFromTensorR :: SortAlg -> TensorR -> (BitR, Maybe TensorR)
removeTopBitFromTensorR subAlg (register, memory) =
  (topBit, rebuild <$> memory')
  where
    -- The first component of a Register entry is used as the index into
    -- Memory
    topAddress = fst (last register)
    (topBit, memory') = removeBitFromMemoryR subAlg memory topAddress
    -- Wrap the remaining Memory, build a new Tensor from it and unwrap the
    -- result back to a plain TensorR
    rebuild = fromSTensorRec . createTensor subAlg . SMemoryRec

-- | Removes the top Bit from the element of Memory at index @i@ and returns
--   it along with the updated Memory, or 'Nothing' if the Memory is now
--   empty
--
--   For a 'ByteMem' the top Bit is the last Bit of the Byte at index @i@.
--   What is left of that Byte is dropped from Memory if empty, kept as-is if
--   it holds a single Bit, and otherwise passed through the SortAlg before
--   being put back
--
--   For a 'TensorMem' the top Bit is removed from the Tensor at index @i@
--   with 'removeTopBitFromTensor'; that Tensor is dropped from Memory if it
--   is exhausted and otherwise replaced by its rebuilt form
--
--   Note: uses '!!' and 'last', so an out-of-range index or an empty Byte
--   is an error
removeBitFromMemory :: SortAlg -> Memory -> Int -> (Bit, Maybe Memory)
removeBitFromMemory subAlg (ByteMem bytes) i =
  case topByte' of
    -- The Byte is used up: remove it from Memory altogether
    []
      | null remaining -> (topBit, Nothing)
      | otherwise -> (topBit, Just (ByteMem remaining))
    -- A single Bit is put back without invoking the SortAlg
    [_] -> (topBit, Just (ByteMem (putBack topByte')))
    -- Two or more Bits: run them through the SortAlg before putting them back
    _ ->
      ( topBit,
        Just (ByteMem (putBack (fromSortBit (subAlg (SortBit topByte')))))
      )
  where
    topByte = bytes !! i
    topBit = last topByte
    topByte' = init topByte
    before = take i bytes
    after = drop (i + 1) bytes
    -- Memory with the element at index i removed
    remaining = before ++ after
    -- Memory with the element at index i replaced
    putBack byte = before ++ [byte] ++ after
removeBitFromMemory subAlg (TensorMem tensors) i =
  case removeTopBitFromTensor subAlg (tensors !! i) of
    -- The Tensor is used up: remove it from Memory altogether
    (topBit, Nothing)
      | null remaining -> (topBit, Nothing)
      | otherwise -> (topBit, Just (TensorMem remaining))
    -- Swap the old Tensor for the one returned after removal
    (topBit, Just topTensor') ->
      (topBit, Just (TensorMem (before ++ [topTensor'] ++ after)))
  where
    before = take i tensors
    after = drop (i + 1) tensors
    -- Memory with the element at index i removed
    remaining = before ++ after

-- | Record variant of 'removeBitFromMemory'
--
--   Removes the top Record from the element of Memory at index @i@ and
--   returns it along with the updated Memory, or 'Nothing' if the Memory is
--   now empty
--
--   For a 'ByteMemR' the top Record is the last Record of the Byte at index
--   @i@. What is left of that Byte is dropped from Memory if empty, kept
--   as-is if it holds a single Record, and otherwise passed through the
--   SortAlg before being put back
--
--   For a 'TensorMemR' the top Record is removed from the Tensor at index
--   @i@ with 'removeTopBitFromTensorR'; that Tensor is dropped from Memory
--   if it is exhausted and otherwise replaced by its rebuilt form
--
--   Note: uses '!!' and 'last', so an out-of-range index or an empty Byte
--   is an error
removeBitFromMemoryR :: SortAlg -> MemoryR -> Int -> (BitR, Maybe MemoryR)
removeBitFromMemoryR subAlg (ByteMemR bytesR) i =
  case topByteR' of
    -- The Byte is used up: remove it from Memory altogether
    []
      | null remaining -> (topBitR, Nothing)
      | otherwise -> (topBitR, Just (ByteMemR remaining))
    -- A single Record is put back without invoking the SortAlg
    [_] -> (topBitR, Just (ByteMemR (putBack topByteR')))
    -- Two or more Records: run them through the SortAlg before putting them
    -- back
    _ ->
      ( topBitR,
        Just (ByteMemR (putBack (fromSortRec (subAlg (SortRec topByteR')))))
      )
  where
    topByteR = bytesR !! i
    topBitR = last topByteR
    topByteR' = init topByteR
    before = take i bytesR
    after = drop (i + 1) bytesR
    -- Memory with the element at index i removed
    remaining = before ++ after
    -- Memory with the element at index i replaced
    putBack byteR = before ++ [byteR] ++ after
removeBitFromMemoryR subAlg (TensorMemR tensorsR) i =
  case removeTopBitFromTensorR subAlg (tensorsR !! i) of
    -- The Tensor is used up: remove it from Memory altogether
    (topBitR, Nothing)
      | null remaining -> (topBitR, Nothing)
      | otherwise -> (topBitR, Just (TensorMemR remaining))
    -- Swap the old Tensor for the one returned after removal
    (topBitR, Just topTensorR') ->
      (topBitR, Just (TensorMemR (before ++ [topTensorR'] ++ after)))
  where
    before = take i tensorsR
    after = drop (i + 1) tensorsR
    -- Memory with the element at index i removed
    remaining = before ++ after