module Data.Tensort.Utils.Reduce (reduceTensorStacks) where
import Data.Tensort.Utils.Compose (createTensor)
import Data.Tensort.Utils.Split (splitEvery)
import Data.Tensort.Utils.Types
( Memory (..),
MemoryR (..),
SMemory (..),
STensorStack,
STensorStacks,
STensors (..),
TensorStack,
TensorStackR,
TensortProps (..),
fromSTensorBit,
fromSTensorRec,
)
-- | Collapse a collection of tensor stacks into a single tensor stack.
--
-- Dispatches on the variant of the input: bit stacks ('STensorsBit') are
-- handled by 'reduceTensorStacksB', recursive stacks ('STensorsRec') by
-- 'reduceTensorStacksR'.
reduceTensorStacks :: TensortProps -> STensorStacks -> STensorStack
reduceTensorStacks tsProps stacks = case stacks of
  STensorsBit bitStacks -> reduceTensorStacksB tsProps bitStacks
  STensorsRec recStacks -> reduceTensorStacksR tsProps recStacks
-- | Reduce bit-variant tensor stacks to a single tensor.
--
-- Applies 'reduceTensorStacksSinglePass' once, then:
--
-- * if at most @bytesize tsProps@ stacks remain, builds the final tensor
--   over them with 'createTensor' using the configured sub-algorithm;
-- * otherwise recurses on the reduced list.
--
-- At least one pass is always performed, even when the input is already
-- no longer than @bytesize tsProps@.
--
-- NOTE(review): termination relies on each pass shrinking the list. That
-- presumably requires @bytesize tsProps >= 2@; with a bytesize of 1, a pass
-- presumably leaves the count unchanged, and the recursion would not finish
-- for more than one stack. Confirm that callers guarantee this.
reduceTensorStacksB :: TensortProps -> [TensorStack] -> STensorStack
reduceTensorStacksB tsProps tensorStacks
  | length reduced <= bytesize tsProps =
      createTensor (subAlgorithm tsProps) (SMemoryBit (TensorMem reduced))
  | otherwise = reduceTensorStacksB tsProps reduced
  where
    reduced = reduceTensorStacksSinglePass tsProps tensorStacks
-- | Reduce recursive-variant tensor stacks to a single tensor.
--
-- Applies 'reduceTensorStacksRSinglePass' once, then:
--
-- * if at most @bytesize tsProps@ stacks remain, builds the final tensor
--   over them with 'createTensor' using the configured sub-algorithm;
-- * otherwise recurses on the reduced list.
--
-- At least one pass is always performed, even when the input is already
-- no longer than @bytesize tsProps@.
--
-- NOTE(review): termination relies on each pass shrinking the list. That
-- presumably requires @bytesize tsProps >= 2@. Confirm that callers
-- guarantee this.
reduceTensorStacksR :: TensortProps -> [TensorStackR] -> STensorStack
reduceTensorStacksR tsProps tensorStacks
  | length reduced <= bytesize tsProps =
      createTensor (subAlgorithm tsProps) (SMemoryRec (TensorMemR reduced))
  | otherwise = reduceTensorStacksR tsProps reduced
  where
    reduced = reduceTensorStacksRSinglePass tsProps tensorStacks
-- | Perform one reduction pass over bit-variant tensor stacks.
--
-- The input is split with @'splitEvery' (bytesize tsProps)@ (presumably
-- into chunks of at most @bytesize tsProps@ stacks). Each chunk is then
-- wrapped in 'SMemoryBit' / 'TensorMem' and turned into one new tensor via
-- 'createTensor' with the configured sub-algorithm, so the output has one
-- element per chunk.
--
-- The new tensors appear in /reverse/ chunk order: the last chunk's tensor
-- comes first. This ordering is preserved deliberately, because whether
-- 'createTensor' depends on the order of its input is not visible from this
-- module.
reduceTensorStacksSinglePass ::
  TensortProps ->
  [TensorStack] ->
  [TensorStack]
reduceTensorStacksSinglePass tsProps tensorStacks =
  -- A right fold that appends each chunk's tensor with @acc ++ [x]@ yields
  -- this same reversed list, but through left-nested appends costing
  -- O(k^2) for k chunks. @reverse . map@ produces it in O(k).
  reverse (map mergeChunk (splitEvery (bytesize tsProps) tensorStacks))
  where
    -- Build a single tensor over one chunk of stacks.
    mergeChunk :: [TensorStack] -> TensorStack
    mergeChunk chunk =
      fromSTensorBit
        (createTensor (subAlgorithm tsProps) (SMemoryBit (TensorMem chunk)))
-- | Perform one reduction pass over recursive-variant tensor stacks.
--
-- The input is split with @'splitEvery' (bytesize tsProps)@ (presumably
-- into chunks of at most @bytesize tsProps@ stacks). Each chunk is then
-- wrapped in 'SMemoryRec' / 'TensorMemR' and turned into one new tensor via
-- 'createTensor' with the configured sub-algorithm, so the output has one
-- element per chunk.
--
-- The new tensors appear in /reverse/ chunk order: the last chunk's tensor
-- comes first. This ordering is preserved deliberately, because whether
-- 'createTensor' depends on the order of its input is not visible from this
-- module.
reduceTensorStacksRSinglePass ::
  TensortProps ->
  [TensorStackR] ->
  [TensorStackR]
reduceTensorStacksRSinglePass tsProps tensorStacks =
  -- A right fold that appends each chunk's tensor with @acc ++ [x]@ yields
  -- this same reversed list, but through left-nested appends costing
  -- O(k^2) for k chunks. @reverse . map@ produces it in O(k).
  reverse (map mergeChunk (splitEvery (bytesize tsProps) tensorStacks))
  where
    -- Build a single tensor over one chunk of stacks.
    mergeChunk :: [TensorStackR] -> TensorStackR
    mergeChunk chunk =
      fromSTensorRec
        (createTensor (subAlgorithm tsProps) (SMemoryRec (TensorMemR chunk)))