-- | Module for creating Tensors from Bytes and Tensors
--
--   Functions ending in "R" are for sorting Records when used in a recursive
--   Tensort variant
--
--   TODO: See if we can clean up the type conversion here
module Data.Tensort.Utils.Compose
  ( createInitialTensors,
    createTensor,
  )
where

import Data.Tensort.Utils.SimplifyRegister
  ( applySortingFromSimplifiedRegister,
    simplifyRegister,
  )
import Data.Tensort.Utils.Split (splitEvery)
import Data.Tensort.Utils.Types
  ( Byte,
    ByteR,
    Memory (..),
    MemoryR (..),
    Record,
    RecordR,
    SBit (..),
    SBytes (..),
    SMemory (..),
    SRecord (..),
    STensor (..),
    STensors (..),
    SortAlg,
    Sortable (..),
    Tensor,
    TensorR,
    TensortProps (..),
    fromSBitBit,
    fromSBitRec,
    fromSRecordArrayBit,
    fromSRecordArrayRec,
    fromSTensorBit,
    fromSTensorRec,
    fromSortRec,
  )

-- | Convert a list of Bytes into a list of TensorStacks.
--
--   The Bytes are split into chunks of @bytesize tsProps@ Bytes, and one
--   Tensor is built from each chunk (Tensor and TensorStack are equivalent
--   terms - see the type definitions for more info). The Bit and Record
--   variants are delegated to 'createInitialTensorsBits' and
--   'createInitialTensorsRecs' respectively.
--
-- ==== __Examples__
-- >>> import Data.Tensort.Subalgorithms.Bubblesort (bubblesort)
-- >>> import Data.Tensort.Utils.MkTsProps (mkTsProps)
-- >>> createInitialTensors (mkTsProps 2 bubblesort) (SBytesBit [[2,4],[6,8],[1,3],[5,7]])
-- STensorsBit [([(0,3),(1,7)],ByteMem [[1,3],[5,7]]),([(0,4),(1,8)],ByteMem [[2,4],[6,8]])]
createInitialTensors :: TensortProps -> SBytes -> STensors
createInitialTensors tsProps sBytes = case sBytes of
  SBytesBit bytes -> STensorsBit (createInitialTensorsBits tsProps bytes)
  SBytesRec recs -> STensorsRec (createInitialTensorsRecs tsProps recs)

-- | Split a list of Bytes into chunks of @bytesize tsProps@ Bytes and build
--   one Tensor from each chunk with the configured sub-algorithm.
--
--   The Tensors come out in reverse chunk order: the Tensor built from the
--   last chunk is the first element of the result.
createInitialTensorsBits :: TensortProps -> [Byte] -> [Tensor]
createInitialTensorsBits tsProps bytes =
  -- A single 'map' followed by a single 'reverse' keeps this linear in the
  -- number of chunks (appending each Tensor to the end of an accumulator
  -- would be quadratic).
  reverse (map mkTensor (splitEvery (bytesize tsProps) bytes))
  where
    -- Build one Tensor from a chunk of Bytes.
    mkTensor :: [Byte] -> Tensor
    mkTensor chunk =
      fromSTensorBit
        (getTensorFromBytes (subAlgorithm tsProps) (SBytesBit chunk))

-- | Split a list of Record Bytes into chunks of @bytesize tsProps@ Bytes and
--   build one Tensor from each chunk with the configured sub-algorithm.
--
--   The Tensors come out in reverse chunk order: the Tensor built from the
--   last chunk is the first element of the result.
createInitialTensorsRecs :: TensortProps -> [ByteR] -> [TensorR]
createInitialTensorsRecs tsProps bytesR =
  -- A single 'map' followed by a single 'reverse' keeps this linear in the
  -- number of chunks (appending each Tensor to the end of an accumulator
  -- would be quadratic).
  reverse (map mkTensorR (splitEvery (bytesize tsProps) bytesR))
  where
    -- Build one Tensor from a chunk of Record Bytes.
    mkTensorR :: [ByteR] -> TensorR
    mkTensorR chunk =
      fromSTensorRec
        (getTensorFromBytes (subAlgorithm tsProps) (SBytesRec chunk))

-- | Create a Tensor from a Memory.
--
--   Dispatches on the Bit/Record variant of the Memory to 'createTensorB' or
--   'createTensorR'. Those in turn use 'getTensorFromBytes' for a Byte Memory
--   and 'getTensorFromTensors' for a Tensor Memory.
createTensor :: SortAlg -> SMemory -> STensor
createTensor subAlg sMemory = case sMemory of
  SMemoryBit memory -> createTensorB subAlg memory
  SMemoryRec memoryR -> createTensorR subAlg memoryR

-- | Create a Tensor from a Bit-variant Memory: a 'ByteMem' is converted with
--   'getTensorFromBytes' and a 'TensorMem' with 'getTensorFromTensors'.
createTensorB :: SortAlg -> Memory -> STensor
createTensorB subAlg memory = case memory of
  ByteMem bytes -> getTensorFromBytes subAlg (SBytesBit bytes)
  TensorMem tensors -> getTensorFromTensors subAlg (STensorsBit tensors)

-- | Create a Tensor from a Record-variant Memory: a 'ByteMemR' is converted
--   with 'getTensorFromBytes' and a 'TensorMemR' with 'getTensorFromTensors'.
createTensorR :: SortAlg -> MemoryR -> STensor
createTensorR subAlg memoryR = case memoryR of
  ByteMemR bytesR -> getTensorFromBytes subAlg (SBytesRec bytesR)
  TensorMemR tensorsR -> getTensorFromTensors subAlg (STensorsRec tensorsR)

-- | Convert a list of Bytes to a Tensor.
--
--   The list of Bytes becomes the new Tensor's Memory. Alongside it goes a
--   Register holding one Record for each non-empty Byte in that Memory.
--
--   Each Record pairs an Address (the index of the referenced Byte in
--   Memory) with a TopBit (the last Bit of the referenced Byte, which is
--   expected to be its highest Bit). The Register is sorted by TopBit using
--   the supplied sub-algorithm.
--
-- ==== __Examples__
-- >>> import Data.Tensort.Subalgorithms.Bubblesort (bubblesort)
-- >>> getTensorFromBytes bubblesort (SBytesBit [[2,4,6,8],[1,3,5,7]])
-- STensorBit ([(1,7),(0,8)],ByteMem [[2,4,6,8],[1,3,5,7]])
getTensorFromBytes :: SortAlg -> SBytes -> STensor
getTensorFromBytes subAlg sBytes = case sBytes of
  SBytesBit bytes -> STensorBit (getTensorFromBytesB subAlg bytes)
  SBytesRec recs -> STensorRec (getTensorFromBytesR subAlg recs)

-- | Build a Tensor from a list of Bytes (Bit variant).
--
--   Each non-empty Byte yields a Record @(address, topBit)@:
--
--   * @address@ is the Byte's index in the input list.
--   * @topBit@ is the Byte's last Bit.
--
--   Empty Bytes yield no Record but still occupy an index, so every Address
--   points at the correct Byte in the resulting 'ByteMem'. The Register is
--   then sorted with the supplied sub-algorithm.
getTensorFromBytesB :: SortAlg -> [Byte] -> Tensor
getTensorFromBytesB subAlg bytes = (sortedRegister, ByteMem bytes)
  where
    -- Built in one linear pass; 'last' is safe because of the 'null' guard.
    register :: [Record]
    register =
      [ (address, last byte)
        | (address, byte) <- zip [0 ..] bytes,
          not (null byte)
      ]

    sortedRegister :: [Record]
    sortedRegister = fromSortRec (subAlg (SortRec register))

-- | Build a Tensor from a list of Record Bytes (Record variant).
--
--   Each non-empty Byte yields a Record @(address, topBit)@:
--
--   * @address@ is the Byte's index in the input list.
--   * @topBit@ is the Byte's last element.
--
--   Empty Bytes yield no Record but still occupy an index, so every Address
--   points at the correct Byte in the resulting 'ByteMemR'.
--
--   The Records are sorted indirectly. 'simplifyRegister' turns the Register
--   into a plain Record list that the sub-algorithm can sort via 'SortRec'.
--   'applySortingFromSimplifiedRegister' then applies that order to the full
--   Register.
getTensorFromBytesR :: SortAlg -> [ByteR] -> TensorR
getTensorFromBytesR subAlg bytesR = (sortedRegisterR, ByteMemR bytesR)
  where
    -- Built in one linear pass; 'last' is safe because of the 'null' guard.
    registerR :: [RecordR]
    registerR =
      [ (address, last byteR)
        | (address, byteR) <- zip [0 ..] bytesR,
          not (null byteR)
      ]

    simplifiedRegister :: ByteR
    simplifiedRegister = simplifyRegister registerR

    sortedSimplifiedRegister :: ByteR
    sortedSimplifiedRegister = fromSortRec (subAlg (SortRec simplifiedRegister))

    sortedRegisterR :: [RecordR]
    sortedRegisterR =
      applySortingFromSimplifiedRegister sortedSimplifiedRegister registerR

-- | Create a TensorStack from a list of Tensors.
--
--   The original Tensors become the new TensorStack's Memory. Its Register is
--   built by 'getRegisterFromTensors' (one Record per Tensor with a non-empty
--   Register, carrying that Tensor's top Bit) and then sorted with the
--   supplied sub-algorithm.
--
-- ==== __Examples__
-- >>> import Data.Tensort.Subalgorithms.Bubblesort (bubblesort)
-- >>> getTensorFromTensors bubblesort (STensorsBit [([(0,13),(1,18)],ByteMem [[11,13],[15,18]]),([(1,14),(0,17)],ByteMem [[16,17],[12,14]])])
-- STensorBit ([(1,17),(0,18)],TensorMem [([(0,13),(1,18)],ByteMem [[11,13],[15,18]]),([(1,14),(0,17)],ByteMem [[16,17],[12,14]])])
getTensorFromTensors :: SortAlg -> STensors -> STensor
getTensorFromTensors subAlg sTensors = case sTensors of
  STensorsBit tensors -> STensorBit (getTensorFromTensorsB subAlg tensors)
  STensorsRec tensorsR -> STensorRec (getTensorFromTensorsR subAlg tensorsR)

-- | Build a TensorStack (Bit variant) from a list of Tensors.
--
--   The Tensors become the Memory, unchanged and in their original order.
--   The Register is the list of top-Bit Records from
--   'getRegisterFromTensors', sorted with the supplied sub-algorithm.
getTensorFromTensorsB :: SortAlg -> [Tensor] -> Tensor
getTensorFromTensorsB subAlg tensors = (sortedRegister, TensorMem tensors)
  where
    unsortedRegister :: [Record]
    unsortedRegister =
      fromSRecordArrayBit (getRegisterFromTensors (STensorsBit tensors))

    sortedRegister :: [Record]
    sortedRegister = fromSortRec (subAlg (SortRec unsortedRegister))

-- | Build a TensorStack (Record variant) from a list of Tensors.
--
--   The Tensors become the Memory, unchanged and in their original order.
--   The Register comes from 'getRegisterFromTensors' and is sorted
--   indirectly. 'simplifyRegister' produces a plain Record list that the
--   sub-algorithm can sort via 'SortRec'. 'applySortingFromSimplifiedRegister'
--   then applies that order to the full Register.
getTensorFromTensorsR :: SortAlg -> [TensorR] -> TensorR
getTensorFromTensorsR subAlg tensorsR = (sortedRegisterR, TensorMemR tensorsR)
  where
    registerR :: [RecordR]
    registerR =
      fromSRecordArrayRec (getRegisterFromTensors (STensorsRec tensorsR))

    sortedSimplifiedRegister :: ByteR
    sortedSimplifiedRegister =
      fromSortRec (subAlg (SortRec (simplifyRegister registerR)))

    sortedRegisterR :: [RecordR]
    sortedRegisterR =
      applySortingFromSimplifiedRegister sortedSimplifiedRegister registerR

-- | Produce one Record per Tensor, combining an Address with the Tensor's
--   top Bit (see 'getTopBitFromTensorStack').
--
--   Tensors whose Register is empty contribute no Record. The output is not
--   sorted; sorting is done in 'getTensorFromTensors'.
--
-- ==== __Examples__
-- >>> getRegisterFromTensors (STensorsBit [([(0,13),(1,18)],ByteMem [[11,13],[15,18]]),([(0,14),(1,17)],ByteMem [[12,14],[16,17]]),([(0,3),(1,7)],ByteMem [[1,3],[5,7]]),([(0,4),(1,8)],ByteMem [[2,4],[6,8]])])
-- [SRecordBit (0,18),SRecordBit (1,17),SRecordBit (2,7),SRecordBit (3,8)]
getRegisterFromTensors :: STensors -> [SRecord]
getRegisterFromTensors sTensors = case sTensors of
  STensorsBit tensors -> getRegisterFromTensorsB tensors
  STensorsRec tensorsR -> getRegisterFromTensorsR tensorsR

-- | Build an unsorted Register (Bit variant) from a list of Tensors.
--
--   Every Tensor with a non-empty Register yields one 'SRecordBit':
--
--   * its Address is that Tensor's index in the input list;
--   * its TopBit is the Tensor's top Bit.
--
--   Tensors with an empty Register are skipped but still consume an index.
--   This keeps Addresses aligned with positions in the input list, which
--   'getTensorFromTensors' stores unchanged as the Memory. It matches how
--   'getTensorFromBytesB' treats empty Bytes. Numbering by the count of
--   Records emitted so far would misalign every Address after an empty
--   Tensor.
getRegisterFromTensorsB :: [Tensor] -> [SRecord]
getRegisterFromTensorsB tensors =
  [ SRecordBit
      ( address,
        fromSBitBit (getTopBitFromTensorStack (STensorBit tensor))
      )
    | (address, tensor@(register, _)) <- zip [0 ..] tensors,
      -- the guard also keeps 'getTopBitFromTensorStack' away from an empty
      -- Register
      not (null register)
  ]

-- | Build an unsorted Register (Record variant) from a list of Tensors.
--
--   Every Tensor with a non-empty Register yields one 'SRecordRec':
--
--   * its Address is that Tensor's index in the input list;
--   * its TopBit is the Tensor's top Bit.
--
--   Tensors with an empty Register are skipped but still consume an index.
--   This keeps Addresses aligned with positions in the input list, which
--   'getTensorFromTensors' stores unchanged as the Memory. It matches how
--   'getTensorFromBytesR' treats empty Bytes. Numbering by the count of
--   Records emitted so far would misalign every Address after an empty
--   Tensor.
getRegisterFromTensorsR :: [TensorR] -> [SRecord]
getRegisterFromTensorsR tensorsR =
  [ SRecordRec
      ( address,
        fromSBitRec (getTopBitFromTensorStack (STensorRec tensorR))
      )
    | (address, tensorR@(registerR, _)) <- zip [0 ..] tensorsR,
      -- the guard also keeps 'getTopBitFromTensorStack' away from an empty
      -- Register
      not (null registerR)
  ]

-- | Get the top Bit of a TensorStack.
--
--   Only the top-level Register is read: the result is the TopBit of its
--   last Record. Each Record's TopBit was taken from the Byte or Tensor it
--   references. The value therefore equals the last Bit reached by following
--   the last Record down through every level of the TensorStack. When the
--   Register is sorted, this is expected to be the highest value in the
--   TensorStack.
--
--   Uses 'last', so an empty top-level Register is an error.
--
-- ==== __Examples__
-- >>> getTopBitFromTensorStack (STensorBit ([(0,28),(1,38)],TensorMem [([(0,27),(1,28)],TensorMem [([(0,23),(1,27)],ByteMem [[21,23],[25,27]]),([(0,24),(1,28)],ByteMem [[22,24],[26,28]])]),([(1,37),(0,38)],TensorMem [([(0,33),(1,38)],ByteMem [[31,33],[35,38]]),([(0,34),(1,37)],ByteMem [[32,14],[36,37]])])]))
-- SBitBit 38
getTopBitFromTensorStack :: STensor -> SBit
getTopBitFromTensorStack sTensor = case sTensor of
  STensorBit tensor -> getTopBitFromTensorStackB tensor
  STensorRec tensorR -> getTopBitFromTensorStackR tensorR

-- | Top Bit of a Bit-variant TensorStack: the TopBit of the last Record in
--   its Register. Uses 'last', so an empty Register is an error.
getTopBitFromTensorStackB :: Tensor -> SBit
getTopBitFromTensorStackB (register, _) = SBitBit topBit
  where
    (_, topBit) = last register

-- | Top Bit of a Record-variant TensorStack: the TopBit of the last Record in
--   its Register. Uses 'last', so an empty Register is an error.
getTopBitFromTensorStackR :: TensorR -> SBit
getTopBitFromTensorStackR (registerR, _) = SBitRec topBitR
  where
    (_, topBitR) = last registerR