{-# OPTIONS -fno-warn-incomplete-patterns #-}
{-# LANGUAGE CPP #-}

#include "fusion-phases.h"

-- | Distributed types.
module Data.Array.Parallel.Unlifted.Distributed.Types (
  -- * Distributed types
  DT, Dist, MDist, DPrim(..),

  -- * Operations on immutable distributed types
  indexD, unitD, zipD, unzipD, fstD, sndD, lengthD,
  newD,
  -- zipSD, unzipSD, fstSD, sndSD,
  deepSeqD,

  lengthUSegdD, lengthsUSegdD, indicesUSegdD, elementsUSegdD,

  -- * Operations on mutable distributed types
  newMD, readMD, writeMD, unsafeFreezeMD,

  -- * Assertions
  checkGangD, checkGangMD,

  -- * Debugging functions
  sizeD, sizeMD, measureD, debugD
) where

import Data.Array.Parallel.Unlifted.Distributed.Gang ( Gang, gangSize )
import Data.Array.Parallel.Unlifted.Sequential.Vector ( Unbox, Vector )
import qualified Data.Array.Parallel.Unlifted.Sequential.Vector as V
import Data.Array.Parallel.Unlifted.Sequential.Segmented
import Data.Array.Parallel.Base

import qualified Data.Vector.Unboxed as V
import qualified Data.Vector.Unboxed.Mutable as MV
import qualified Data.Vector as BV
import qualified Data.Vector.Mutable as MBV

import Data.Word (Word8)
import Control.Monad (liftM, liftM2, liftM3)
import Data.List ( intercalate )

infixl 9 `indexD`

-- | Prefix a name with this module's name, for use as the location
--   argument of the assertion helpers ('check', 'checkEq').
here :: String -> String
here s = "Data.Array.Parallel.Unlifted.Distributed.Types." ++ s


-- Distributed Types ----------------------------------------------------------

-- | Class of distributable types. Instances of 'DT' can be
--   distributed across all workers of a 'Gang'.
--   All such types must be hyperstrict as we do not want to pass thunks
--   into distributed computations.
class DT a where
  data Dist  a
  data MDist a :: * -> *

  -- | Extract a single element of an immutable distributed value.
  indexD         :: Dist a -> Int -> a

  -- | Create an uninitialised distributed value for the given 'Gang'.
  --   The gang is used (only) to know how many elements are needed
  --   in the distributed value.
  newMD          :: Gang -> ST s (MDist a s)

  -- | Extract an element from a mutable distributed value.
  readMD         :: MDist a s -> Int -> ST s a

  -- | Write an element of a mutable distributed value.
  writeMD        :: MDist a s -> Int -> a -> ST s ()

  -- | Unsafely freeze a mutable distributed value.
  unsafeFreezeMD :: MDist a s -> ST s (Dist a)

  -- | Force the first argument, then return the second.
  --   Defaults to 'seq'; instances for compound types (pairs, 'Maybe')
  --   override it so that their components are forced as well.
  deepSeqD       :: a -> b -> b
  deepSeqD = seq

  -- Debugging ------------------------
  -- | Number of elements in the distributed value.
  --   For debugging only, as we shouldn't depend on the size of the gang.
  sizeD          :: Dist a -> Int

  -- | Number of elements in the mutable distributed value.
  --   For debugging only, as we shouldn't care about the actual number.
  sizeMD         :: MDist a s -> Int

  -- | Render a short description of a single member value
  --   (not of a whole distributed value). Used by 'debugD'.
  --   For debugging only.
  measureD       :: a -> String
  measureD _ = "None"


-- Show instance (for debugging only)
instance (Show a, DT a) => Show (Dist a) where
  show d = show (Prelude.map (indexD d) [0 .. sizeD d - 1])


-- Checking -------------------------------------------------------------------

-- | Check that the sizes of the 'Gang' and of the distributed value match.
checkGangD :: DT a => String -> Gang -> Dist a -> b -> b
checkGangD loc g d v = checkEq loc "Wrong gang" (gangSize g) (sizeD d) v


-- | Check that the sizes of the 'Gang' and of the mutable distributed value match.
checkGangMD :: DT a => String -> Gang -> MDist a s -> b -> b
checkGangMD loc g d v = checkEq loc "Wrong gang" (gangSize g) (sizeMD d) v


-- Operations -----------------------------------------------------------------

-- | Given a computation that can write its result to a mutable distributed value,
--   run the computation to generate an immutable distributed value.
--
--   A fresh mutable value sized for the gang is allocated with 'newMD',
--   handed to @fill@, and then frozen with 'unsafeFreezeMD'.
newD :: DT a => Gang -> (forall s . MDist a s -> ST s ()) -> Dist a
newD g fill
  = runST (do
      mdt <- newMD g
      fill mdt
      unsafeFreezeMD mdt)


-- | Show all members of a distributed value.
debugD :: DT a => Dist a -> String
debugD d
  =  "["
  ++ intercalate "," [measureD (indexD d i) | i <- [0 .. sizeD d - 1]]
  ++ "]"


-- DPrim ----------------------------------------------------------------------

-- | For distributed primitive values, we can just store all the members in
--   a vector. The vector has the same length as the number of threads in the gang.
--
class Unbox e => DPrim e where

  -- | Make an immutable distributed value.
  mkDPrim    :: V.Vector e -> Dist e

  -- | Unpack an immutable distributed value back into a vector.
  unDPrim    :: Dist e -> V.Vector e

  -- | Make a mutable distributed value.
  mkMDPrim   :: MV.STVector s e -> MDist e s

  -- | Unpack a mutable distributed value back into a vector.
  unMDPrim   :: MDist e s -> MV.STVector s e


-- | Get the member corresponding to a thread index.
primIndexD :: DPrim a => Dist a -> Int -> a
{-# INLINE primIndexD #-}
primIndexD = (V.!) . unDPrim


-- | Create a new distributed value, having as many members as threads
--   in the given 'Gang'.
primNewMD :: DPrim a => Gang -> ST s (MDist a s)
{-# INLINE primNewMD #-}
primNewMD = liftM mkMDPrim . MV.new . gangSize


-- | Read the member of a distributed value corresponding to the given thread index.
primReadMD :: DPrim a => MDist a s -> Int -> ST s a
{-# INLINE primReadMD #-}
primReadMD = MV.read . unMDPrim


-- | Write the member of a distributed value corresponding to the given thread index.
primWriteMD :: DPrim a => MDist a s -> Int -> a -> ST s ()
{-# INLINE primWriteMD #-}
primWriteMD = MV.write . unMDPrim


-- | Freeze a mutable distributed value to an immutable one.
--   You promise not to update the mutable one any further.
primUnsafeFreezeMD :: DPrim a => MDist a s -> ST s (Dist a)
{-# INLINE primUnsafeFreezeMD #-}
primUnsafeFreezeMD = liftM mkDPrim . V.unsafeFreeze . unMDPrim


-- | Get the size of a distributed value, that is, the number of threads
--   in the gang that it was created for.
primSizeD :: DPrim a => Dist a -> Int
{-# INLINE primSizeD #-}
primSizeD = V.length . unDPrim


-- | Get the size of a distributed mutable value, that is, the number of threads
--   in the gang it was created for.
primSizeMD :: DPrim a => MDist a s -> Int
{-# INLINE primSizeMD #-}
primSizeMD = MV.length . unMDPrim


-- Unit -----------------------------------------------------------------------
instance DT () where
  data Dist ()    = DUnit  !Int
  data MDist () s = MDUnit !Int

  indexD (DUnit n) i        = check (here "indexD[()]") n i $ ()
  newMD                     = return . MDUnit . gangSize
  readMD   (MDUnit n) i     = check (here "readMD[()]")  n i $ return ()
  writeMD  (MDUnit n) i ()  = check (here "writeMD[()]") n i $ return ()
  unsafeFreezeMD (MDUnit n) = return $ DUnit n

  -- The member count is stored in the constructor, so report it rather
  -- than failing. This keeps 'checkGangD', 'checkGangMD', 'zipD',
  -- 'debugD' and 'show' usable on unit distributions.
  sizeD  (DUnit n)  = n
  sizeMD (MDUnit n) = n


-- | Yield a distributed unit.
unitD :: Gang -> Dist ()
{-# INLINE_DIST unitD #-}
unitD = DUnit . gangSize


-- Bool -----------------------------------------------------------------------
instance DPrim Bool where
  mkDPrim           = DBool
  unDPrim (DBool a) = a

  mkMDPrim            = MDBool
  unMDPrim (MDBool a) = a


instance DT Bool where
  data Dist  Bool   = DBool  !(V.Vector    Bool)
  data MDist Bool s = MDBool !(MV.STVector s Bool)

  indexD         = primIndexD
  newMD          = primNewMD
  readMD         = primReadMD
  writeMD        = primWriteMD
  unsafeFreezeMD = primUnsafeFreezeMD
  sizeD          = primSizeD
  sizeMD         = primSizeMD


-- Char -----------------------------------------------------------------------
instance DPrim Char where
  mkDPrim           = DChar
  unDPrim (DChar a) = a

  mkMDPrim            = MDChar
  unMDPrim (MDChar a) = a


instance DT Char where
  data Dist  Char   = DChar  !(V.Vector    Char)
  data MDist Char s = MDChar !(MV.STVector s Char)

  indexD         = primIndexD
  newMD          = primNewMD
  readMD         = primReadMD
  writeMD        = primWriteMD
  unsafeFreezeMD = primUnsafeFreezeMD
  sizeD          = primSizeD
  sizeMD         = primSizeMD


-- Int ------------------------------------------------------------------------
instance DPrim Int where
  mkDPrim          = DInt
  unDPrim (DInt a) = a

  mkMDPrim           = MDInt
  unMDPrim (MDInt a) = a


instance DT Int where
  data Dist  Int   = DInt  !(V.Vector    Int)
  data MDist Int s = MDInt !(MV.STVector s Int)

  indexD         = primIndexD
  newMD          = primNewMD
  readMD         = primReadMD
  writeMD        = primWriteMD
  unsafeFreezeMD = primUnsafeFreezeMD
  sizeD          = primSizeD
  sizeMD         = primSizeMD

  measureD n = "Int " ++ show n


-- Word8 ----------------------------------------------------------------------
instance DPrim Word8 where
  mkDPrim            = DWord8
  unDPrim (DWord8 a) = a

  mkMDPrim             = MDWord8
  unMDPrim (MDWord8 a) = a


instance DT Word8 where
  data Dist  Word8   = DWord8  !(V.Vector    Word8)
  data MDist Word8 s = MDWord8 !(MV.STVector s Word8)

  indexD         = primIndexD
  newMD          = primNewMD
  readMD         = primReadMD
  writeMD        = primWriteMD
  unsafeFreezeMD = primUnsafeFreezeMD
  sizeD          = primSizeD
  sizeMD         = primSizeMD


-- Float ----------------------------------------------------------------------
instance DPrim Float where
  mkDPrim            = DFloat
  unDPrim (DFloat a) = a

  mkMDPrim             = MDFloat
  unMDPrim (MDFloat a) = a


instance DT Float where
  data Dist  Float   = DFloat  !(V.Vector    Float)
  data MDist Float s = MDFloat !(MV.STVector s Float)

  indexD         = primIndexD
  newMD          = primNewMD
  readMD         = primReadMD
  writeMD        = primWriteMD
  unsafeFreezeMD = primUnsafeFreezeMD
  sizeD          = primSizeD
  sizeMD         = primSizeMD


-- Double ---------------------------------------------------------------------
instance DPrim Double where
  mkDPrim             = DDouble
  unDPrim (DDouble a) = a

  mkMDPrim              = MDDouble
  unMDPrim (MDDouble a) = a


instance DT Double where
  data Dist  Double   = DDouble  !(V.Vector    Double)
  data MDist Double s = MDDouble !(MV.STVector s Double)

  indexD         = primIndexD
  newMD          = primNewMD
  readMD         = primReadMD
  writeMD        = primWriteMD
  unsafeFreezeMD = primUnsafeFreezeMD
  sizeD          = primSizeD
  sizeMD         = primSizeMD


-- Pairs ----------------------------------------------------------------------
instance (DT a, DT b) => DT (a,b) where
  data Dist  (a,b)   = DProd  !(Dist a)    !(Dist b)
  data MDist (a,b) s = MDProd !(MDist a s) !(MDist b s)

  indexD d i               = (fstD d `indexD` i, sndD d `indexD` i)
  newMD g                  = liftM2 MDProd (newMD g) (newMD g)
  readMD  (MDProd xs ys) i = liftM2 (,) (readMD xs i) (readMD ys i)
  writeMD (MDProd xs ys) i (x,y)
                           = writeMD xs i x >> writeMD ys i y
  unsafeFreezeMD (MDProd xs ys)
                           = liftM2 DProd (unsafeFreezeMD xs)
                                          (unsafeFreezeMD ys)

  {-# INLINE deepSeqD #-}
  deepSeqD (x,y) z = deepSeqD x (deepSeqD y z)

  sizeD  (DProd  x _) = sizeD  x
  sizeMD (MDProd x _) = sizeMD x

  measureD (x,y) = "Pair " ++ "(" ++ measureD x ++ ") (" ++ measureD y ++ ")"


-- | Pairing of distributed values.
--   /The two values must belong to the same/ 'Gang'.
zipD :: (DT a, DT b) => Dist a -> Dist b -> Dist (a,b)
{-# INLINE [0] zipD #-}
zipD !x !y
  = checkEq (here "zipD") "Size mismatch" (sizeD x) (sizeD y) $
    DProd x y


-- | Unpairing of distributed values.
unzipD :: (DT a, DT b) => Dist (a,b) -> (Dist a, Dist b)
{-# INLINE_DIST unzipD #-}
unzipD (DProd dx dy) = (dx,dy)


-- | Extract the first elements of a distributed pair.
fstD :: (DT a, DT b) => Dist (a,b) -> Dist a
{-# INLINE_DIST fstD #-}
fstD = fst . unzipD


-- | Extract the second elements of a distributed pair.
sndD :: (DT a, DT b) => Dist (a,b) -> Dist b
{-# INLINE_DIST sndD #-}
sndD = snd . unzipD


-- Maybe ----------------------------------------------------------------------
-- A distributed Maybe is stored as a distributed flag per member plus a
-- distributed payload; the payload slot is only read when the flag is True.
instance DT a => DT (Maybe a) where
  data Dist  (Maybe a)   = DMaybe  !(Dist  Bool)   !(Dist  a)
  data MDist (Maybe a) s = MDMaybe !(MDist Bool s) !(MDist a s)

  indexD (DMaybe bs as) i
    | bs `indexD` i = Just $ as `indexD` i
    | otherwise     = Nothing

  newMD g = liftM2 MDMaybe (newMD g) (newMD g)

  readMD (MDMaybe bs as) i
    = do
        b <- readMD bs i
        if b
          then liftM Just $ readMD as i
          else return Nothing

  -- Writing Nothing only clears the flag; the payload slot is left as-is.
  writeMD (MDMaybe bs _)  i Nothing  = writeMD bs i False
  writeMD (MDMaybe bs as) i (Just x) = writeMD bs i True >> writeMD as i x

  unsafeFreezeMD (MDMaybe bs as)
    = liftM2 DMaybe (unsafeFreezeMD bs) (unsafeFreezeMD as)

  {-# INLINE deepSeqD #-}
  deepSeqD Nothing  z = z
  deepSeqD (Just x) z = deepSeqD x z

  sizeD  (DMaybe  b _) = sizeD  b
  sizeMD (MDMaybe b _) = sizeMD b

  measureD Nothing  = "Nothing"
  measureD (Just x) = "Just (" ++ measureD x ++ ")"


-- Vector ---------------------------------------------------------------------
-- Each member is kept in a boxed vector of unboxed vectors; the lengths are
-- additionally tracked in a distributed Int so that 'lengthD' is cheap.
instance Unbox a => DT (V.Vector a) where
  data Dist  (Vector a)   = DVector  !(Dist  Int)   !(BV.Vector      (Vector a))
  data MDist (Vector a) s = MDVector !(MDist Int s) !(MBV.STVector s (Vector a))

  indexD (DVector _ a) i = a BV.! i

  newMD g
    = liftM2 MDVector (newMD g)
                      (MBV.replicate (gangSize g)
                                     (error "MDist (Vector a) - uninitalised"))

  readMD (MDVector _ marr) = MBV.read marr

  -- Record the length alongside the vector, and store the vector itself
  -- in WHNF so no thunk is placed in the boxed array.
  writeMD (MDVector mlen marr) i a
    = do
        writeMD mlen i (V.length a)
        MBV.write marr i $! a

  unsafeFreezeMD (MDVector len a)
    = liftM2 DVector (unsafeFreezeMD len) (BV.unsafeFreeze a)

  sizeD  (DVector  _ a) = BV.length  a
  sizeMD (MDVector _ a) = MBV.length a

  measureD xs = "Vector " ++ show (V.length xs)


-- | Yield the distributed length of a distributed array.
--   Each member is the length of the vector held by the corresponding thread.
lengthD :: Unbox a => Dist (Vector a) -> Dist Int
{-# INLINE_DIST lengthD #-}
lengthD (DVector l _) = l


-- USegd ----------------------------------------------------------------------
-- A distributed segment descriptor is stored component-wise: one distributed
-- value each for the lengths vector, the indices vector and the element count.
instance DT USegd where
  data Dist USegd     = DUSegd  !(Dist (Vector Int))
                                !(Dist (Vector Int))
                                !(Dist Int)

  data MDist USegd s  = MDUSegd !(MDist (Vector Int) s)
                                !(MDist (Vector Int) s)
                                !(MDist Int        s)

  indexD (DUSegd lens idxs eles) i
    = mkUSegd (indexD lens i) (indexD idxs i) (indexD eles i)

  newMD g
    = liftM3 MDUSegd (newMD g) (newMD g) (newMD g)

  readMD (MDUSegd lens idxs eles) i
    = liftM3 mkUSegd (readMD lens i) (readMD idxs i) (readMD eles i)

  writeMD (MDUSegd lens idxs eles) i segd
    = do
        writeMD lens i (lengthsUSegd  segd)
        writeMD idxs i (indicesUSegd  segd)
        writeMD eles i (elementsUSegd segd)

  unsafeFreezeMD (MDUSegd lens idxs eles)
    = liftM3 DUSegd (unsafeFreezeMD lens)
                    (unsafeFreezeMD idxs)
                    (unsafeFreezeMD eles)

  -- Force every component, matching the pair and Maybe instances.
  {-# INLINE deepSeqD #-}
  deepSeqD segd z
    = deepSeqD (lengthsUSegd  segd)
    $ deepSeqD (indicesUSegd  segd)
    $ deepSeqD (elementsUSegd segd) z

  sizeD  (DUSegd  _ _ eles) = sizeD  eles
  sizeMD (MDUSegd _ _ eles) = sizeMD eles

  measureD segd
    = "Segd " ++ show (lengthUSegd segd) ++ " " ++ show (elementsUSegd segd)


-- | For each thread, the length of its segment-lengths vector.
lengthUSegdD :: Dist USegd -> Dist Int
{-# INLINE_DIST lengthUSegdD #-}
lengthUSegdD (DUSegd lens _ _) = lengthD lens


-- | The distributed segment-lengths component of a distributed 'USegd'.
lengthsUSegdD :: Dist USegd -> Dist (Vector Int)
{-# INLINE_DIST lengthsUSegdD #-}
lengthsUSegdD (DUSegd lens _ _) = lens


-- | The distributed segment-indices component of a distributed 'USegd'.
indicesUSegdD :: Dist USegd -> Dist (Vector Int)
{-# INLINE_DIST indicesUSegdD #-}
indicesUSegdD (DUSegd _ idxs _) = idxs


-- | The distributed element-count component of a distributed 'USegd'.
elementsUSegdD :: Dist USegd -> Dist Int
{-# INLINE_DIST elementsUSegdD #-}
elementsUSegdD (DUSegd _ _ dns) = dns