#include "fusion-phases.h"
module Data.Array.Parallel.Unlifted.Distributed.Types (
DT, Dist, MDist, DPrim(..),
indexD, unitD, zipD, unzipD, fstD, sndD, lengthD,
newD,
deepSeqD,
lengthUSegdD, lengthsUSegdD, indicesUSegdD, elementsUSegdD,
newMD, readMD, writeMD, unsafeFreezeMD,
checkGangD, checkGangMD,
sizeD, sizeMD, measureD, debugD
) where
import Data.Array.Parallel.Unlifted.Distributed.Gang (
Gang, gangSize )
import Data.Array.Parallel.Unlifted.Sequential.Vector ( Unbox, Vector )
import qualified Data.Array.Parallel.Unlifted.Sequential.Vector as V
import Data.Array.Parallel.Unlifted.Sequential.Segmented
import Data.Array.Parallel.Base
import qualified Data.Vector.Unboxed as V
import qualified Data.Vector.Unboxed.Mutable as MV
import qualified Data.Vector as BV
import qualified Data.Vector.Mutable as MBV
import Data.Word (Word8)
import Control.Monad (liftM, liftM2, liftM3)
import Data.List ( intercalate )
-- Fixity for infix uses of 'indexD' (e.g. @d `indexD` i@): left-associative
-- at precedence 9, the tightest level available to infix operators.
infixl 9 `indexD`
-- | Qualify a local name with this module's name. Used to build the
-- location strings passed to 'check' and 'checkEq' in this module.
here :: String -> String
here s = "Data.Array.Parallel.Unlifted.Distributed.Types." ++ s
-- | Types whose values can be distributed over a 'Gang', one value per
-- gang member.
--
-- @'Dist' a@ is the immutable distributed representation and
-- @'MDist' a s@ the mutable one, manipulated inside @'ST' s@.
class DT a where
  -- | Immutable distributed value.
  data Dist a
  -- | Mutable distributed value; the extra parameter is the 'ST' state
  -- thread.
  data MDist a :: * -> *

  -- | The element held at the given position.
  indexD :: Dist a -> Int -> a

  -- | Allocate a fresh mutable distributed value for the gang (the
  -- instances in this module size it with 'gangSize').
  newMD :: Gang -> ST s (MDist a s)

  -- | Read the element at the given position.
  readMD :: MDist a s -> Int -> ST s a

  -- | Overwrite the element at the given position.
  writeMD :: MDist a s -> Int -> a -> ST s ()

  -- | Turn a mutable distributed value into an immutable one. The instances
  -- in this module freeze without copying, so the 'MDist' should not be
  -- modified afterwards.
  unsafeFreezeMD :: MDist a s -> ST s (Dist a)

  -- | Evaluate the first argument before returning the second. The default
  -- is plain 'seq' (weak head normal form); instances may force deeper.
  deepSeqD :: a -> b -> b
  deepSeqD x y = x `seq` y

  -- | Number of elements in an immutable distributed value.
  sizeD :: Dist a -> Int

  -- | Number of elements in a mutable distributed value.
  sizeMD :: MDist a s -> Int

  -- | Short description of a value, used by 'debugD'. Yields @\"None\"@
  -- unless an instance overrides it.
  measureD :: a -> String
  measureD = const "None"
-- Show a distributed value as the list of its elements, in index order.
instance (Show a, DT a) => Show (Dist a) where
  -- Valid indices are 0 .. sizeD d - 1; the upper bound must subtract one
  -- (the previous text applied 'sizeD' to a spurious second argument).
  show d = show (Prelude.map (indexD d) [0 .. sizeD d - 1])
-- | Use 'checkEq' to compare the gang's size with the number of elements
-- in a distributed value, then continue with the last argument. The
-- first argument names the call site; mismatches are tagged
-- @\"Wrong gang\"@.
checkGangD :: DT a => String -> Gang -> Dist a -> b -> b
checkGangD loc gang dist = checkEq loc "Wrong gang" expected actual
  where
    expected = gangSize gang
    actual   = sizeD dist
-- | Mutable counterpart of 'checkGangD'. Uses 'checkEq' to compare the
-- gang's size with 'sizeMD' of the value, then continues with the last
-- argument.
checkGangMD :: DT a => String -> Gang -> MDist a s -> b -> b
checkGangMD loc gang mdist = checkEq loc "Wrong gang" expected actual
  where
    expected = gangSize gang
    actual   = sizeMD mdist
-- | Build an immutable distributed value. This allocates a mutable one
-- for the gang, runs the supplied initialiser on it, and freezes the
-- result, all inside a single 'runST'.
newD :: DT a => Gang -> (forall s . MDist a s -> ST s ()) -> Dist a
newD g fill = runST (newMD g >>= \mdt -> fill mdt >> unsafeFreezeMD mdt)
-- | Render a distributed value for debugging: the 'measureD' description
-- of each element, comma-separated inside square brackets.
debugD :: DT a => Dist a -> String
debugD d = "["
        -- Iterate over the valid indices 0 .. sizeD d - 1 (the previous
        -- text referred to an unbound name @d1@).
        ++ intercalate "," [measureD (indexD d i) | i <- [0 .. sizeD d - 1]]
        ++ "]"
-- | Distributed types whose 'Dist' and 'MDist' representations are plain
-- unboxed vectors. Instances supply only the wrapping and unwrapping
-- functions. The @prim*@ helpers in this module derive the 'DT' methods
-- from these functions.
class Unbox e => DPrim e where
  -- | Unwrap an immutable distributed value to its unboxed vector.
  unDPrim  :: Dist e -> V.Vector e
  -- | Wrap an unboxed vector as an immutable distributed value.
  mkDPrim  :: V.Vector e -> Dist e
  -- | Unwrap a mutable distributed value to its mutable unboxed vector.
  unMDPrim :: MDist e s -> MV.STVector s e
  -- | Wrap a mutable unboxed vector as a mutable distributed value.
  mkMDPrim :: MV.STVector s e -> MDist e s
-- | 'indexD' for 'DPrim' types: index the underlying unboxed vector.
primIndexD :: DPrim a => Dist a -> Int -> a
primIndexD d i = unDPrim d V.! i
-- | 'newMD' for 'DPrim' types: allocate a mutable unboxed vector with one
-- slot per gang member ('gangSize').
primNewMD :: DPrim a => Gang -> ST s (MDist a s)
primNewMD g = do
  mv <- MV.new (gangSize g)
  return (mkMDPrim mv)
-- | 'readMD' for 'DPrim' types: read from the underlying mutable vector.
primReadMD :: DPrim a => MDist a s -> Int -> ST s a
primReadMD md i = MV.read (unMDPrim md) i
-- | 'writeMD' for 'DPrim' types: write into the underlying mutable vector.
primWriteMD :: DPrim a => MDist a s -> Int -> a -> ST s ()
primWriteMD md i x = MV.write (unMDPrim md) i x
-- | 'unsafeFreezeMD' for 'DPrim' types: freeze the underlying mutable
-- vector with 'V.unsafeFreeze' (no copy) and rewrap it.
primUnsafeFreezeMD :: DPrim a => MDist a s -> ST s (Dist a)
primUnsafeFreezeMD md = do
  frozen <- V.unsafeFreeze (unMDPrim md)
  return (mkDPrim frozen)
-- | 'sizeD' for 'DPrim' types: length of the underlying vector.
primSizeD :: DPrim a => Dist a -> Int
primSizeD d = V.length (unDPrim d)
-- | 'sizeMD' for 'DPrim' types: length of the underlying mutable vector.
primSizeMD :: DPrim a => MDist a s -> Int
primSizeMD md = MV.length (unMDPrim md)
-- Distributed unit: there is no payload, only the element count, which
-- is stored in the constructor.
instance DT () where
  data Dist ()    = DUnit  !Int
  data MDist () s = MDUnit !Int

  -- Accesses go through 'check' against the stored count n (presumably a
  -- bounds check on i -- confirm against Data.Array.Parallel.Base).
  indexD (DUnit n) i        = check (here "indexD[()]") n i $ ()
  newMD                     = return . MDUnit . gangSize
  readMD (MDUnit n) i       = check (here "readMD[()]") n i $
                                return ()
  writeMD (MDUnit n) i ()   = check (here "writeMD[()]") n i $
                                return ()
  unsafeFreezeMD (MDUnit n) = return $ DUnit n

  -- The count is available in the constructor, so report it rather than
  -- failing. This lets 'checkGangD' / 'checkGangMD' work on unit values.
  sizeD  (DUnit n)  = n
  sizeMD (MDUnit n) = n
-- | A distributed unit value with one element per gang member.
unitD :: Gang -> Dist ()
unitD g = DUnit (gangSize g)
-- 'Bool' is distributed as an unboxed vector; all 'DT' methods delegate to
-- the generic @prim*@ helpers.
instance DPrim Bool where
  unDPrim  (DBool v)   = v
  mkDPrim  v           = DBool v
  unMDPrim (MDBool mv) = mv
  mkMDPrim mv          = MDBool mv

instance DT Bool where
  data Dist Bool    = DBool  !(V.Vector Bool)
  data MDist Bool s = MDBool !(MV.STVector s Bool)

  indexD d i        = primIndexD d i
  newMD g           = primNewMD g
  readMD md i       = primReadMD md i
  writeMD md i x    = primWriteMD md i x
  unsafeFreezeMD md = primUnsafeFreezeMD md
  sizeD d           = primSizeD d
  sizeMD md         = primSizeMD md
-- 'Char' is distributed as an unboxed vector; all 'DT' methods delegate to
-- the generic @prim*@ helpers.
instance DPrim Char where
  unDPrim  (DChar v)   = v
  mkDPrim  v           = DChar v
  unMDPrim (MDChar mv) = mv
  mkMDPrim mv          = MDChar mv

instance DT Char where
  data Dist Char    = DChar  !(V.Vector Char)
  data MDist Char s = MDChar !(MV.STVector s Char)

  indexD d i        = primIndexD d i
  newMD g           = primNewMD g
  readMD md i       = primReadMD md i
  writeMD md i x    = primWriteMD md i x
  unsafeFreezeMD md = primUnsafeFreezeMD md
  sizeD d           = primSizeD d
  sizeMD md         = primSizeMD md
-- 'Int' is distributed as an unboxed vector; all 'DT' methods delegate to
-- the generic @prim*@ helpers, and 'measureD' shows the value itself.
instance DPrim Int where
  unDPrim  (DInt v)   = v
  mkDPrim  v          = DInt v
  unMDPrim (MDInt mv) = mv
  mkMDPrim mv         = MDInt mv

instance DT Int where
  data Dist Int    = DInt  !(V.Vector Int)
  data MDist Int s = MDInt !(MV.STVector s Int)

  indexD d i        = primIndexD d i
  newMD g           = primNewMD g
  readMD md i       = primReadMD md i
  writeMD md i x    = primWriteMD md i x
  unsafeFreezeMD md = primUnsafeFreezeMD md
  sizeD d           = primSizeD d
  sizeMD md         = primSizeMD md
  measureD          = ("Int " ++) . show
-- 'Word8' is distributed as an unboxed vector; all 'DT' methods delegate
-- to the generic @prim*@ helpers.
instance DPrim Word8 where
  unDPrim  (DWord8 v)   = v
  mkDPrim  v            = DWord8 v
  unMDPrim (MDWord8 mv) = mv
  mkMDPrim mv           = MDWord8 mv

instance DT Word8 where
  data Dist Word8    = DWord8  !(V.Vector Word8)
  data MDist Word8 s = MDWord8 !(MV.STVector s Word8)

  indexD d i        = primIndexD d i
  newMD g           = primNewMD g
  readMD md i       = primReadMD md i
  writeMD md i x    = primWriteMD md i x
  unsafeFreezeMD md = primUnsafeFreezeMD md
  sizeD d           = primSizeD d
  sizeMD md         = primSizeMD md
-- 'Float' is distributed as an unboxed vector; all 'DT' methods delegate
-- to the generic @prim*@ helpers.
instance DPrim Float where
  unDPrim  (DFloat v)   = v
  mkDPrim  v            = DFloat v
  unMDPrim (MDFloat mv) = mv
  mkMDPrim mv           = MDFloat mv

instance DT Float where
  data Dist Float    = DFloat  !(V.Vector Float)
  data MDist Float s = MDFloat !(MV.STVector s Float)

  indexD d i        = primIndexD d i
  newMD g           = primNewMD g
  readMD md i       = primReadMD md i
  writeMD md i x    = primWriteMD md i x
  unsafeFreezeMD md = primUnsafeFreezeMD md
  sizeD d           = primSizeD d
  sizeMD md         = primSizeMD md
-- 'Double' is distributed as an unboxed vector; all 'DT' methods delegate
-- to the generic @prim*@ helpers.
instance DPrim Double where
  unDPrim  (DDouble v)   = v
  mkDPrim  v             = DDouble v
  unMDPrim (MDDouble mv) = mv
  mkMDPrim mv            = MDDouble mv

instance DT Double where
  data Dist Double    = DDouble  !(V.Vector Double)
  data MDist Double s = MDDouble !(MV.STVector s Double)

  indexD d i        = primIndexD d i
  newMD g           = primNewMD g
  readMD md i       = primReadMD md i
  writeMD md i x    = primWriteMD md i x
  unsafeFreezeMD md = primUnsafeFreezeMD md
  sizeD d           = primSizeD d
  sizeMD md         = primSizeMD md
-- Pairs are distributed component-wise: one distributed value for the
-- first components and one for the second. The size is taken from the
-- first component.
instance (DT a, DT b) => DT (a,b) where
  data Dist (a,b)    = DProd  !(Dist a)    !(Dist b)
  data MDist (a,b) s = MDProd !(MDist a s) !(MDist b s)

  indexD d i = (indexD (fstD d) i, indexD (sndD d) i)

  newMD g = do
    xs <- newMD g
    ys <- newMD g
    return (MDProd xs ys)

  readMD (MDProd xs ys) i = do
    x <- readMD xs i
    y <- readMD ys i
    return (x, y)

  -- First component is written before the second.
  writeMD (MDProd xs ys) i (x, y) = do
    writeMD xs i x
    writeMD ys i y

  unsafeFreezeMD (MDProd xs ys) = do
    dxs <- unsafeFreezeMD xs
    dys <- unsafeFreezeMD ys
    return (DProd dxs dys)

  deepSeqD (x, y) z = deepSeqD x $ deepSeqD y z

  sizeD  (DProd  dx _) = sizeD  dx
  sizeMD (MDProd mx _) = sizeMD mx

  measureD (x, y) = concat ["Pair (", measureD x, ") (", measureD y, ")"]
-- | Pair two distributed values. Both arguments are forced (first, then
-- second), and their sizes are compared with 'checkEq' before the pair
-- is built.
zipD :: (DT a, DT b) => Dist a -> Dist b -> Dist (a,b)
zipD x y = x `seq` y `seq` sameSize (DProd x y)
  where
    sameSize = checkEq (here "zipDT") "Size mismatch" (sizeD x) (sizeD y)
-- | Split a distributed pair into its two component distributed values.
unzipD :: (DT a, DT b) => Dist (a,b) -> (Dist a, Dist b)
unzipD d = case d of
  DProd dx dy -> (dx, dy)
-- | The distributed first components of a distributed pair.
fstD :: (DT a, DT b) => Dist (a,b) -> Dist a
fstD (DProd dx _) = dx
-- | The distributed second components of a distributed pair.
sndD :: (DT a, DT b) => Dist (a,b) -> Dist b
sndD (DProd _ dy) = dy
-- 'Maybe' is distributed as a per-element presence flag plus a payload.
-- The payload slot is read only where the flag is 'True'; writing
-- 'Nothing' updates the flag alone and leaves the payload slot untouched.
instance DT a => DT (Maybe a) where
  data Dist (Maybe a)    = DMaybe  !(Dist Bool)    !(Dist a)
  data MDist (Maybe a) s = MDMaybe !(MDist Bool s) !(MDist a s)

  indexD (DMaybe flags vals) i =
    if indexD flags i then Just (indexD vals i) else Nothing

  newMD g = do
    flags <- newMD g
    vals  <- newMD g
    return (MDMaybe flags vals)

  readMD (MDMaybe flags vals) i = do
    present <- readMD flags i
    case present of
      True  -> do
        x <- readMD vals i
        return (Just x)
      False -> return Nothing

  writeMD (MDMaybe flags _) i Nothing = writeMD flags i False
  writeMD (MDMaybe flags vals) i (Just x) = do
    writeMD flags i True
    writeMD vals i x

  unsafeFreezeMD (MDMaybe flags vals) = do
    dflags <- unsafeFreezeMD flags
    dvals  <- unsafeFreezeMD vals
    return (DMaybe dflags dvals)

  deepSeqD m z = maybe z (\x -> deepSeqD x z) m

  sizeD  (DMaybe  flags _) = sizeD  flags
  sizeMD (MDMaybe flags _) = sizeMD flags

  measureD = maybe "Nothing" (\x -> "Just (" ++ measureD x ++ ")")
-- Unboxed vectors are distributed as a boxed vector holding one 'Vector'
-- per element, plus a 'Dist Int' recording their lengths (filled in by
-- 'writeMD'). Slots that are never written hold an 'error' thunk.
instance Unbox a => DT (V.Vector a) where
  data Dist (Vector a)    = DVector  !(Dist Int)    !(BV.Vector (Vector a))
  data MDist (Vector a) s = MDVector !(MDist Int s) !(MBV.STVector s (Vector a))

  indexD (DVector _ vs) i = vs BV.! i

  newMD g = do
    lens <- newMD g
    vs   <- MBV.replicate (gangSize g)
              (error "MDist (Vector a) - uninitalised")
    return (MDVector lens vs)

  readMD (MDVector _ mvs) i = MBV.read mvs i

  -- Record the length first, then store the vector forced to WHNF.
  writeMD (MDVector mlens mvs) i v = do
    writeMD mlens i (V.length v)
    v `seq` MBV.write mvs i v

  unsafeFreezeMD (MDVector mlens mvs) = do
    lens <- unsafeFreezeMD mlens
    vs   <- BV.unsafeFreeze mvs
    return (DVector lens vs)

  sizeD  (DVector  _ vs)  = BV.length vs
  sizeMD (MDVector _ mvs) = MBV.length mvs

  measureD v = "Vector " ++ show (V.length v)
-- | The recorded length of each vector in a distributed vector.
lengthD :: Unbox a => Dist (Vector a) -> Dist Int
lengthD d = case d of
  DVector lens _ -> lens
-- A distributed 'USegd' is stored as three distributed components: the
-- lengths vectors, the indices vectors, and the element counts. Each is
-- accessed through the corresponding USegd accessor.
instance DT USegd where
  data Dist USegd    = DUSegd  !(Dist (Vector Int))
                               !(Dist (Vector Int))
                               !(Dist Int)
  data MDist USegd s = MDUSegd !(MDist (Vector Int) s)
                               !(MDist (Vector Int) s)
                               !(MDist Int s)

  indexD (DUSegd lens idxs eles) i =
    mkUSegd (lens `indexD` i) (idxs `indexD` i) (eles `indexD` i)

  newMD g = do
    lens <- newMD g
    idxs <- newMD g
    eles <- newMD g
    return (MDUSegd lens idxs eles)

  readMD (MDUSegd lens idxs eles) i = do
    l <- readMD lens i
    x <- readMD idxs i
    e <- readMD eles i
    return (mkUSegd l x e)

  writeMD (MDUSegd lens idxs eles) i segd =
        writeMD lens i (lengthsUSegd segd)
     >> writeMD idxs i (indicesUSegd segd)
     >> writeMD eles i (elementsUSegd segd)

  unsafeFreezeMD (MDUSegd lens idxs eles) = do
    dlens <- unsafeFreezeMD lens
    didxs <- unsafeFreezeMD idxs
    deles <- unsafeFreezeMD eles
    return (DUSegd dlens didxs deles)

  deepSeqD segd z =
    let lens = lengthsUSegd  segd
        idxs = indicesUSegd  segd
        eles = elementsUSegd segd
    in  deepSeqD lens (deepSeqD idxs (deepSeqD eles z))

  sizeD  (DUSegd  _ _ eles) = sizeD  eles
  sizeMD (MDUSegd _ _ eles) = sizeMD eles

  measureD segd =
    unwords ["Segd", show (lengthUSegd segd), show (elementsUSegd segd)]
-- | For each element of a distributed segment descriptor, the length of
-- its lengths vector, as given by 'lengthD'.
lengthUSegdD :: Dist USegd -> Dist Int
lengthUSegdD d = case d of
  DUSegd lens _ _ -> lengthD lens
-- | The distributed lengths vectors of a distributed segment descriptor.
lengthsUSegdD :: Dist USegd -> Dist (Vector Int)
lengthsUSegdD = \(DUSegd lens _ _) -> lens
-- | The distributed indices vectors of a distributed segment descriptor.
indicesUSegdD :: Dist USegd -> Dist (Vector Int)
indicesUSegdD d = case d of
  DUSegd _ idxs _ -> idxs
-- | The distributed element counts of a distributed segment descriptor.
elementsUSegdD :: Dist USegd -> Dist Int
elementsUSegdD d = case d of
  DUSegd _ _ eles -> eles