{-# OPTIONS -Wall -fno-warn-orphans -fno-warn-missing-signatures #-} {-# LANGUAGE CPP #-} #include "fusion-phases.h" -- | Distribution of Segment Descriptors module Data.Array.Parallel.Unlifted.Distributed.Types.USegd ( mkDUSegd , lengthD , takeLengthsD , takeIndicesD , takeElementsD) where import Data.Array.Parallel.Unlifted.Distributed.Types.Base import Data.Array.Parallel.Unlifted.Sequential.USegd (USegd) import Data.Array.Parallel.Unlifted.Sequential.Vector (Vector) import Data.Array.Parallel.Pretty import Control.Monad import qualified Data.Array.Parallel.Unlifted.Distributed.Types.Vector as DV import qualified Data.Array.Parallel.Unlifted.Sequential.USegd as USegd import Prelude as P instance DT USegd where data Dist USegd = DUSegd !(Dist (Vector Int)) -- segment lengths !(Dist (Vector Int)) -- segment indices !(Dist Int) -- number of elements in this chunk data MDist USegd s = MDUSegd !(MDist (Vector Int) s) -- segment lengths !(MDist (Vector Int) s) -- segment indices !(MDist Int s) -- number of elements in this chunk indexD str (DUSegd lens idxs eles) i = USegd.mkUSegd (indexD (str ++ "/indexD[USegd]") lens i) (indexD (str ++ "/indexD[USegd]") idxs i) (indexD (str ++ "/indexD[USegd]") eles i) newMD g = liftM3 MDUSegd (newMD g) (newMD g) (newMD g) readMD (MDUSegd lens idxs eles) i = liftM3 USegd.mkUSegd (readMD lens i) (readMD idxs i) (readMD eles i) writeMD (MDUSegd lens idxs eles) i segd = do writeMD lens i (USegd.takeLengths segd) writeMD idxs i (USegd.takeIndices segd) writeMD eles i (USegd.takeElements segd) unsafeFreezeMD (MDUSegd lens idxs eles) = liftM3 DUSegd (unsafeFreezeMD lens) (unsafeFreezeMD idxs) (unsafeFreezeMD eles) deepSeqD segd z = deepSeqD (USegd.takeLengths segd) $ deepSeqD (USegd.takeIndices segd) $ deepSeqD (USegd.takeElements segd) z sizeD (DUSegd _ _ eles) = sizeD eles sizeMD (MDUSegd _ _ eles) = sizeMD eles measureD segd = "Segd " P.++ show (USegd.length segd) P.++ " " P.++ show (USegd.takeElements segd) instance PprPhysical (Dist USegd) where pprp (DUSegd lens indices elements) = text "DUSegd" $$ (nest 7 $ vcat [ text "lengths: " <+> pprp lens , text "indices: " <+> pprp indices , text "elements:" <+> pprp elements]) -- | O(1). Construct a distributed segment descriptor mkDUSegd :: Dist (Vector Int) -- ^ segment lengths -> Dist (Vector Int) -- ^ segment indices -> Dist Int -- ^ number of elements in each chunk -> Dist USegd mkDUSegd = DUSegd -- | O(1). Yield the overall number of segments. lengthD :: Dist USegd -> Dist Int lengthD (DUSegd lens _ _) = DV.lengthD lens {-# INLINE_DIST lengthD #-} -- | O(1). Yield the lengths of the individual segments. takeLengthsD :: Dist USegd -> Dist (Vector Int) takeLengthsD (DUSegd lens _ _ ) = lens {-# INLINE_DIST takeLengthsD #-} -- | O(1). Yield the segment indices of a segment descriptor. takeIndicesD :: Dist USegd -> Dist (Vector Int) takeIndicesD (DUSegd _ idxs _) = idxs {-# INLINE_DIST takeIndicesD #-} -- | O(1). Yield the number of data elements. takeElementsD :: Dist USegd -> Dist Int takeElementsD (DUSegd _ _ dns) = dns {-# INLINE_DIST takeElementsD #-}