{-# LANGUAGE EmptyDataDecls, ScopedTypeVariables #-} {-# LANGUAGE CPP #-} #include "fusion-phases.h" -- | Operations on distributed arrays. module Data.Array.Parallel.Unlifted.Distributed.Arrays ( lengthD, splitLenD, splitLenIdxD, splitAsD, splitD, joinLengthD, joinD, splitJoinD, joinDM, splitSegdD, splitSegdD', splitSD, permuteD, bpermuteD, atomicUpdateD, Distribution, balanced, unbalanced ) where import Data.Array.Parallel.Base ( ST, runST) import Data.Array.Parallel.Unlifted.Sequential.Vector as Seq import Data.Array.Parallel.Unlifted.Sequential.Segmented import Data.Array.Parallel.Unlifted.Distributed.Gang ( Gang, gangSize, seqGang) import Data.Array.Parallel.Unlifted.Distributed.DistST ( DistST, stToDistST, myIndex ) import Data.Array.Parallel.Unlifted.Distributed.Types ( DT, Dist, mkDPrim, indexD, lengthD, newD, writeMD, zipD, unzipD, fstD, sndD, elementsUSegdD, checkGangD) import Data.Array.Parallel.Unlifted.Distributed.Basics import Data.Array.Parallel.Unlifted.Distributed.Combinators import Data.Array.Parallel.Unlifted.Distributed.Scalars ( sumD) import Data.Bits ( shiftR ) import Control.Monad ( when ) import GHC.Base ( quotInt, remInt ) here s = "Data.Array.Parallel.Unlifted.Distributed.Arrays." Prelude.++ s -- Distribution --------------------------------------------------------------- -- | This is a phantom parameter used to record whether a distributed value -- is balanced evenly among the threads. It's used to signal this property -- between RULES, but the actual value is never used. data Distribution balanced :: Distribution {-# NOINLINE balanced #-} balanced = error $ here "balanced: touched" unbalanced :: Distribution {-# NOINLINE unbalanced #-} unbalanced = error $ here "unbalanced: touched" -- Splitting and Joining array lengths ---------------------------------------- -- | Distribute an array length over a 'Gang'. -- Each thread holds the number of elements it's reponsible for. splitLenD :: Gang -> Int -> Dist Int {-# INLINE splitLenD #-} splitLenD g n = generateD_cheap g len where !p = gangSize g !l = n `quotInt` p !m = n `remInt` p {-# INLINE [0] len #-} len i | i < m = l+1 | otherwise = l -- | Distribute an array length over a 'Gang'. -- Each thread holds the number of elements it's responsible for, -- and the index of the start of its chunk. splitLenIdxD :: Gang -> Int -> Dist (Int,Int) {-# INLINE splitLenIdxD #-} splitLenIdxD g n = generateD_cheap g len_idx where !p = gangSize g !l = n `quotInt` p !m = n `remInt` p {-# INLINE [0] len_idx #-} len_idx i | i < m = (l+1, i*(l+1)) | otherwise = (l, i*l + m) -- | Get the overall length of a distributed array. -- We ask each thread for its chunk length, and sum them all up. joinLengthD :: Unbox a => Gang -> Dist (Vector a) -> Int {-# INLINE joinLengthD #-} joinLengthD g = sumD g . lengthD -- Splitting and Joining arrays ----------------------------------------------- -- | Distribute an array over a 'Gang' such that each threads gets the given -- number of elements. splitAsD :: Unbox a => Gang -> Dist Int -> Vector a -> Dist (Vector a) {-# INLINE_DIST splitAsD #-} splitAsD g dlen !arr = zipWithD (seqGang g) (Seq.slice arr) is dlen where is = fst $ scanD g (+) 0 dlen -- | Distribute an array over a 'Gang'. -- -- NOTE: This is defined in terms of splitD_impl to avoid introducing loops -- through RULES. Without it, splitJoinD would be a loop breaker. splitD :: Unbox a => Gang -> Distribution -> Vector a -> Dist (Vector a) {-# INLINE_DIST splitD #-} splitD g _ arr = splitD_impl g arr splitD_impl :: Unbox a => Gang -> Vector a -> Dist (Vector a) {-# INLINE_DIST splitD_impl #-} splitD_impl g !arr = generateD_cheap g (\i -> Seq.slice arr (idx i) (len i)) where n = Seq.length arr !p = gangSize g !l = n `quotInt` p !m = n `remInt` p {-# INLINE [0] idx #-} idx i | i < m = (l+1)*i | otherwise = l*i + m {-# INLINE [0] len #-} len i | i < m = l+1 | otherwise = l -- | Join a distributed array. -- -- NOTE: This is defined in terms of joinD_impl to avoid introducing loops -- through RULES. Without it, splitJoinD would be a loop breaker. joinD :: Unbox a => Gang -> Distribution -> Dist (Vector a) -> Vector a {-# INLINE CONLIKE [1] joinD #-} joinD g _ darr = joinD_impl g darr joinD_impl :: forall a. Unbox a => Gang -> Dist (Vector a) -> Vector a {-# INLINE_DIST joinD_impl #-} joinD_impl g !darr = checkGangD (here "joinD") g darr $ Seq.new n (\ma -> zipWithDST_ g (copy ma) di darr) where (!di,!n) = scanD g (+) 0 $ lengthD darr copy :: forall s. MVector s a -> Int -> Vector a -> DistST s () copy ma i arr = stToDistST (Seq.copy (mslice i (Seq.length arr) ma) arr) -- | Split a vector over a gang, run a distributed computation, then -- join the pieces together again. splitJoinD :: (Unbox a, Unbox b) => Gang -> (Dist (Vector a) -> Dist (Vector b)) -> Vector a -> Vector b {-# INLINE_DIST splitJoinD #-} splitJoinD g f !xs = joinD_impl g (f (splitD_impl g xs)) -- | Join a distributed array, yielding a mutable global array joinDM :: Unbox a => Gang -> Dist (Vector a) -> ST s (MVector s a) {-# INLINE joinDM #-} joinDM g darr = checkGangD (here "joinDM") g darr $ do marr <- Seq.newM n zipWithDST_ g (copy marr) di darr return marr where (!di,!n) = scanD g (+) 0 $ lengthD darr -- copy ma i arr = stToDistST (Seq.copy (mslice i (Seq.length arr) ma) arr) {-# RULES "splitD[unbalanced]/joinD" forall g b da. splitD g unbalanced (joinD g b da) = da "splitD[balanced]/joinD" forall g da. splitD g balanced (joinD g balanced da) = da "splitD/splitJoinD" forall g b f xs. splitD g b (splitJoinD g f xs) = f (splitD g b xs) "splitJoinD/joinD" forall g b f da. splitJoinD g f (joinD g b da) = joinD g b (f da) "splitJoinD/splitJoinD" forall g f1 f2 xs. splitJoinD g f1 (splitJoinD g f2 xs) = splitJoinD g (f1 . f2) xs #-} {-# RULES "Seq.zip/joinD[1]" forall g xs ys. Seq.zip (joinD g balanced xs) ys = joinD g balanced (zipWithD g Seq.zip xs (splitD g balanced ys)) "Seq.zip/joinD[2]" forall g xs ys. Seq.zip xs (joinD g balanced ys) = joinD g balanced (zipWithD g Seq.zip (splitD g balanced xs) ys) "Seq.zip/splitJoinD" forall gang f g xs ys. Seq.zip (splitJoinD gang (imapD gang f) xs) (splitJoinD gang (imapD gang g) ys) = splitJoinD gang (imapD gang (\i zs -> let (as,bs) = Seq.unzip zs in Seq.zip (f i as) (g i bs))) (Seq.zip xs ys) #-} -- Permutation ---------------------------------------------------------------- -- | Permute for distributed arrays. permuteD :: forall a. Unbox a => Gang -> Dist (Vector a) -> Dist (Vector Int) -> Vector a {-# INLINE_DIST permuteD #-} permuteD g darr dis = Seq.new n (\ma -> zipWithDST_ g (permute ma) darr dis) where n = joinLengthD g darr permute :: forall s. MVector s a -> Vector a -> Vector Int -> DistST s () permute ma arr is = stToDistST (Seq.mpermute ma arr is) -- NOTE: The bang is necessary because the array must be fully evaluated -- before we pass it to the parallel computation. bpermuteD :: Unbox a => Gang -> Vector a -> Dist (Vector Int) -> Dist (Vector a) {-# INLINE bpermuteD #-} bpermuteD g !as ds = mapD g (Seq.bpermute as) ds -- Update --------------------------------------------------------------------- -- NB: This does not (and cannot) try to prevent two threads from writing to -- the same position. We probably want to consider this an (unchecked) user -- error. atomicUpdateD :: forall a. Unbox a => Gang -> Dist (Vector a) -> Dist (Vector (Int,a)) -> Vector a {-# INLINE atomicUpdateD #-} atomicUpdateD g darr upd = runST ( do marr <- joinDM g darr mapDST_ g (update marr) upd Seq.unsafeFreeze marr ) where update :: forall s. MVector s a -> Vector (Int,a) -> DistST s () update marr arr = stToDistST (Seq.mupdate marr arr) --- Splitting and Joining segment descriptors --------------------------------- splitSegdD :: Gang -> USegd -> Dist USegd {-# NOINLINE splitSegdD #-} splitSegdD g !segd = mapD g lengthsToUSegd $ splitAsD g d lens where !d = snd . mapAccumLD g chunk 0 . splitLenD g $ elementsUSegd segd n = lengthUSegd segd lens = lengthsUSegd segd chunk !i !k = let !j = go i k in (j,j-i) go !i !k | i >= n = i | m == 0 = go (i+1) k | k <= 0 = i | otherwise = go (i+1) (k-m) where m = lens ! i search :: Int -> Vector Int -> Int search !x ys = go 0 (Seq.length ys) where go i n | n <= 0 = i | (ys!mid) < x = go (mid+1) (n-half-1) | otherwise = go i half where half = n `shiftR` 1 mid = i + half chunk :: USegd -> Int -> Int -> Bool -> (# Vector Int, Int, Int #) chunk !segd !di !dn is_last = (# lens', k-left_len, left_off #) where !lens' = runST (do mlens' <- Seq.newM n' when (left /= 0) $ Seq.write mlens' 0 left Seq.copy (Seq.mdrop left_len mlens') (Seq.slice lens k (k'-k)) when (right /= 0) $ Seq.write mlens' (n' - 1) right Seq.unsafeFreeze mlens') lens = lengthsUSegd segd idxs = indicesUSegd segd n = Seq.length lens k = search di idxs k' | is_last = n | otherwise = search (di+dn) idxs left | k == n = dn | otherwise = min ((idxs!k) - di) dn right | k' == k = 0 | otherwise = di + dn - (idxs ! (k'-1)) left_len | left == 0 = 0 | otherwise = 1 left_off | left == 0 = 0 | otherwise = di - idxs ! (k-1) n' = left_len + (k'-k) splitSegdD' :: Gang -> USegd -> Dist ((USegd,Int),Int) {-# INLINE splitSegdD' #-} splitSegdD' g !segd = imapD g mk (splitLenIdxD g (elementsUSegd segd)) where !p = gangSize g mk i (dn,di) = case chunk segd di dn (i == p-1) of (# lens, l, o #) -> ((lengthsToUSegd lens,l),o) joinSegD :: Gang -> Dist USegd -> USegd {-# INLINE_DIST joinSegD #-} joinSegD g = lengthsToUSegd . joinD g unbalanced . mapD g lengthsUSegd splitSD :: Unbox a => Gang -> Dist USegd -> Vector a -> Dist (Vector a) {-# INLINE_DIST splitSD #-} splitSD g dsegd xs = splitAsD g (elementsUSegdD dsegd) xs {-# RULES "splitSD/splitJoinD" forall g d f xs. splitSD g d (splitJoinD g f xs) = f (splitSD g d xs) "splitSD/Seq.zip" forall g d xs ys. splitSD g d (Seq.zip xs ys) = zipWithD g Seq.zip (splitSD g d xs) (splitSD g d ys) #-}