-- | This module provides functions to reduce a list of TensorStacks into a
--   more compact list of TensorStacks, and ultimately into a single
--   TensorStack
--
--   Functions ending in "R" are for sorting Records when used in a recursive
--   Tensort variant
--
--   TODO: See if we can clean up the type conversion here
module Data.Tensort.Utils.Reduce (reduceTensorStacks) where

import Data.Tensort.Utils.Compose (createTensor)
import Data.Tensort.Utils.Split (splitEvery)
import Data.Tensort.Utils.Types
  ( Memory (..),
    MemoryR (..),
    SMemory (..),
    STensorStack,
    STensorStacks,
    STensors (..),
    TensorStack,
    TensorStackR,
    TensortProps (..),
    fromSTensorBit,
    fromSTensorRec,
  )

-- | Take a list of TensorStacks and group them together in new
--   TensorStacks, each containing at most bytesize Tensors (former
--   TensorStacks), repeating until no more than bytesize TensorStacks
--   remain. Those remaining TensorStacks are then combined into a single
--   TensorStack, which is returned.
--
--   The Registers of the new TensorStacks are sorted using the
--   subAlgorithm from the given TensortProps.
--
--   Bit-based ('STensorsBit') and Record-based ('STensorsRec') inputs are
--   handed to their respective helpers.
--
-- ==== __Examples__
-- >>> import Data.Tensort.Subalgorithms.Bubblesort (bubblesort)
-- >>> import Data.Tensort.Utils.MkTsProps (mkTsProps)
-- >>> reduceTensorStacks (mkTsProps 2 bubblesort) (STensorsBit [([(0, 33), (1, 38)], ByteMem [[31, 33], [35, 38]]), ([(0, 34), (1, 37)], ByteMem [[32, 14], [36, 37]]), ([(0, 23), (1, 27)], ByteMem [[21, 23], [25, 27]]), ([(0, 24), (1, 28)], ByteMem [[22, 24], [26, 28]]),([(0,13),(1,18)],ByteMem [[11,13],[15,18]]),([(0,14),(1,17)],ByteMem [[12,14],[16,17]]),([(0,3),(1,7)],ByteMem [[1,3],[5,7]]),([(0,4),(1,8)],ByteMem [[2,4],[6,8]])])
-- STensorBit ([(1,18),(0,38)],TensorMem [([(0,28),(1,38)],TensorMem [([(0,27),(1,28)],TensorMem [([(0,23),(1,27)],ByteMem [[21,23],[25,27]]),([(0,24),(1,28)],ByteMem [[22,24],[26,28]])]),([(1,37),(0,38)],TensorMem [([(0,33),(1,38)],ByteMem [[31,33],[35,38]]),([(0,34),(1,37)],ByteMem [[32,14],[36,37]])])]),([(0,8),(1,18)],TensorMem [([(0,7),(1,8)],TensorMem [([(0,3),(1,7)],ByteMem [[1,3],[5,7]]),([(0,4),(1,8)],ByteMem [[2,4],[6,8]])]),([(1,17),(0,18)],TensorMem [([(0,13),(1,18)],ByteMem [[11,13],[15,18]]),([(0,14),(1,17)],ByteMem [[12,14],[16,17]])])])])
reduceTensorStacks :: TensortProps -> STensorStacks -> STensorStack
reduceTensorStacks tsProps stacks = case stacks of
  -- The two variants carry different element types (Bit vs Record
  -- TensorStacks), so each is reduced by its own helper.
  STensorsBit bitStacks -> reduceTensorStacksB tsProps bitStacks
  STensorsRec recordStacks -> reduceTensorStacksR tsProps recordStacks

-- | Reduce a list of Bit-based TensorStacks to a single TensorStack.
--
--   Each pass groups the TensorStacks into new TensorStacks of at most
--   bytesize Tensors (via 'reduceTensorStacksSinglePass'). Passes repeat
--   until no more than bytesize TensorStacks remain; those are then
--   combined with 'createTensor' into the returned TensorStack. At least
--   one pass is always performed, even if the input is already that short.
--
--   If a pass fails to shrink a list that is still longer than bytesize,
--   this calls 'error' rather than recursing forever. That can happen, for
--   example, with a bytesize of 1 and more than one TensorStack (assuming
--   'splitEvery' then yields singleton groups).
reduceTensorStacksB :: TensortProps -> [TensorStack] -> STensorStack
reduceTensorStacksB tsProps tensorStacks
  | newCount <= bytesize tsProps =
      createTensor
        (subAlgorithm tsProps)
        (SMemoryBit (TensorMem newTensorStacks))
  -- Recursing on a list that did not shrink would never terminate.
  | newCount >= length tensorStacks =
      error $
        "Data.Tensort.Utils.Reduce.reduceTensorStacksB: a reduction pass "
          ++ "made no progress; bytesize must be at least 2"
  | otherwise = reduceTensorStacksB tsProps newTensorStacks
  where
    newTensorStacks :: [TensorStack]
    newTensorStacks = reduceTensorStacksSinglePass tsProps tensorStacks

    newCount :: Int
    newCount = length newTensorStacks

-- | Reduce a list of Record-based TensorStacks to a single TensorStack.
--
--   Record counterpart of 'reduceTensorStacksB'. Each pass groups the
--   TensorStacks into new TensorStacks of at most bytesize Tensors (via
--   'reduceTensorStacksRSinglePass'). Passes repeat until no more than
--   bytesize TensorStacks remain; those are then combined with
--   'createTensor' into the returned TensorStack. At least one pass is
--   always performed, even if the input is already that short.
--
--   If a pass fails to shrink a list that is still longer than bytesize,
--   this calls 'error' rather than recursing forever. That can happen, for
--   example, with a bytesize of 1 and more than one TensorStack (assuming
--   'splitEvery' then yields singleton groups).
reduceTensorStacksR :: TensortProps -> [TensorStackR] -> STensorStack
reduceTensorStacksR tsProps tensorStacks
  | newCount <= bytesize tsProps =
      createTensor
        (subAlgorithm tsProps)
        (SMemoryRec (TensorMemR newTensorStacks))
  -- Recursing on a list that did not shrink would never terminate.
  | newCount >= length tensorStacks =
      error $
        "Data.Tensort.Utils.Reduce.reduceTensorStacksR: a reduction pass "
          ++ "made no progress; bytesize must be at least 2"
  | otherwise = reduceTensorStacksR tsProps newTensorStacks
  where
    newTensorStacks :: [TensorStackR]
    newTensorStacks = reduceTensorStacksRSinglePass tsProps tensorStacks

    newCount :: Int
    newCount = length newTensorStacks

-- | Take a list of TensorStacks and group them together in new
--   TensorStacks, each containing at most bytesize Tensors (former
--   TensorStacks). This is a single grouping pass; it does not recurse.
--
--   The Registers of the new TensorStacks are sorted using the
--   subAlgorithm from the given TensortProps (passed to 'createTensor').
--
--   The new TensorStacks come out in the reverse order of the groups
--   produced by 'splitEvery': the last group becomes the first TensorStack,
--   as the example below shows.
--
-- ==== __Examples__
-- >>> import Data.Tensort.Subalgorithms.Bubblesort (bubblesort)
-- >>> import Data.Tensort.Utils.MkTsProps (mkTsProps)
-- >>> reduceTensorStacksSinglePass (mkTsProps 2 bubblesort) [([(0,13),(1,18)],ByteMem [[11,13],[15,18]]),([(0,14),(1,17)],ByteMem [[12,14],[16,17]]),([(0,3),(1,7)],ByteMem [[1,3],[5,7]]),([(0,4),(1,8)],ByteMem [[2,4],[6,8]])]
-- [([(0,7),(1,8)],TensorMem [([(0,3),(1,7)],ByteMem [[1,3],[5,7]]),([(0,4),(1,8)],ByteMem [[2,4],[6,8]])]),([(1,17),(0,18)],TensorMem [([(0,13),(1,18)],ByteMem [[11,13],[15,18]]),([(0,14),(1,17)],ByteMem [[12,14],[16,17]])])]
reduceTensorStacksSinglePass ::
  TensortProps ->
  [TensorStack] ->
  [TensorStack]
reduceTensorStacksSinglePass tsProps tensorStacks =
  -- A single 'reverse' reproduces the order of the former
  -- "append each new stack to the end" fold in linear time. Repeatedly
  -- appending to the end of the accumulated list was quadratic.
  reverse . map groupToTensorStack $
    splitEvery (bytesize tsProps) tensorStacks
  where
    -- Wrap one group of TensorStacks into a single new TensorStack.
    groupToTensorStack :: [TensorStack] -> TensorStack
    groupToTensorStack grouped =
      fromSTensorBit
        (createTensor (subAlgorithm tsProps) (SMemoryBit (TensorMem grouped)))

-- | Record-based counterpart of 'reduceTensorStacksSinglePass'.
--
--   Groups the TensorStacks into new TensorStacks, each containing at most
--   bytesize Tensors (former TensorStacks), in a single non-recursive pass.
--   The Registers of the new TensorStacks are sorted using the
--   subAlgorithm from the given TensortProps (passed to 'createTensor').
--
--   The new TensorStacks come out in the reverse order of the groups
--   produced by 'splitEvery': the last group becomes the first TensorStack.
reduceTensorStacksRSinglePass ::
  TensortProps ->
  [TensorStackR] ->
  [TensorStackR]
reduceTensorStacksRSinglePass tsProps tensorStacks =
  -- A single 'reverse' reproduces the order of the former
  -- "append each new stack to the end" fold in linear time. Repeatedly
  -- appending to the end of the accumulated list was quadratic.
  reverse . map groupToTensorStackR $
    splitEvery (bytesize tsProps) tensorStacks
  where
    -- Wrap one group of Record TensorStacks into a single new TensorStack.
    groupToTensorStackR :: [TensorStackR] -> TensorStackR
    groupToTensorStackR grouped =
      fromSTensorRec
        (createTensor (subAlgorithm tsProps) (SMemoryRec (TensorMemR grouped)))