#include "fusion-phases.h"
module Data.Array.Parallel.Unlifted.Distributed.Arrays (
lengthD, splitLenD, splitLenIdxD,
splitAsD, splitD, joinLengthD, joinD, splitJoinD, joinDM,
splitSegdD, splitSegdD', splitSD,
permuteD, bpermuteD, atomicUpdateD,
Distribution, balanced, unbalanced
) where
import Data.Array.Parallel.Base ( ST, runST)
import Data.Array.Parallel.Unlifted.Sequential.Vector as Seq
import Data.Array.Parallel.Unlifted.Sequential.Segmented
import Data.Array.Parallel.Unlifted.Distributed.Gang (
Gang, gangSize, seqGang)
import Data.Array.Parallel.Unlifted.Distributed.DistST (
DistST, stToDistST, myIndex )
import Data.Array.Parallel.Unlifted.Distributed.Types (
DT, Dist, mkDPrim, indexD, lengthD, newD, writeMD, zipD, unzipD, fstD, sndD,
elementsUSegdD,
checkGangD)
import Data.Array.Parallel.Unlifted.Distributed.Basics
import Data.Array.Parallel.Unlifted.Distributed.Combinators
import Data.Array.Parallel.Unlifted.Distributed.Scalars (
sumD)
import Data.Bits ( shiftR )
import Control.Monad ( when )
import GHC.Base ( quotInt, remInt )
-- | Prefix a local name with this module's fully qualified name; used to
--   build the error messages and 'checkGangD' labels in this module.
here :: String -> String
here s = "Data.Array.Parallel.Unlifted.Distributed.Arrays." Prelude.++ s
-- | Tag (presumably describing how data is spread over a gang) with no
--   constructors.  The only values are the two placeholders below, and
--   'splitD' / 'joinD' in this module ignore the argument of this type.
data Distribution

-- | Placeholder 'Distribution'; forcing it raises an error.
balanced :: Distribution
balanced = error (here "balanced: touched")

-- | Placeholder 'Distribution'; forcing it raises an error.
unbalanced :: Distribution
unbalanced = error (here "unbalanced: touched")
-- | Divide @n@ elements among the gang as evenly as possible.  With @p@ the
--   gang size, entry @i@ is @n `quotInt` p + 1@ for the first
--   @n `remInt` p@ indices and @n `quotInt` p@ for the rest.
splitLenD :: Gang -> Int -> Dist Int
splitLenD g n = generateD_cheap g sizeFor
  where
    !workers = gangSize g
    !base    = n `quotInt` workers
    !extra   = n `remInt`  workers

    sizeFor i = if i < extra then base + 1 else base
-- | Like 'splitLenD', but entry @i@ is the pair @(length, start)@, where
--   @start@ is the sum of the lengths of all entries before @i@.
splitLenIdxD :: Gang -> Int -> Dist (Int,Int)
splitLenIdxD g n = generateD_cheap g sizeAndStart
  where
    !workers = gangSize g
    !base    = n `quotInt` workers
    !extra   = n `remInt`  workers

    sizeAndStart i
      | i >= extra = (base, base * i + extra)
      | otherwise  = let sz = base + 1 in (sz, sz * i)
-- | Total element count of a distributed array: the 'sumD' over the gang
--   of the per-thread vector lengths.
joinLengthD :: Unbox a => Gang -> Dist (Vector a) -> Int
joinLengthD g darr = sumD g (lengthD darr)
-- | Distribute an array over the gang so that thread @i@ receives as many
--   elements as the @i@-th entry of @dlen@.  The chunk for each thread is a
--   'Seq.slice' of @arr@ whose start comes from scanning @dlen@ with @(+)@
--   from 0 (presumably an exclusive prefix sum, so chunks are contiguous).
splitAsD :: Unbox a => Gang -> Dist Int -> Vector a -> Dist (Vector a)
splitAsD g dlen !arr = zipWithD (seqGang g) takeSlice starts dlen
  where
    starts = fst (scanD g (+) 0 dlen)

    takeSlice start len = Seq.slice arr start len
-- | Split an array evenly over the gang (delegates to 'splitD_impl').
--   The 'Distribution' argument is ignored.
splitD :: Unbox a => Gang -> Distribution -> Vector a -> Dist (Vector a)
splitD gang _dist vec = splitD_impl gang vec
-- | Worker for 'splitD'.  With @p@ the gang size and @n@ the array length,
--   entry @i@ is a slice of @arr@: the first @n `remInt` p@ slices hold
--   @n `quotInt` p + 1@ elements, the rest @n `quotInt` p@, laid out
--   back to back in index order.
splitD_impl :: Unbox a => Gang -> Vector a -> Dist (Vector a)
splitD_impl g !arr = generateD_cheap g sliceFor
  where
    n        = Seq.length arr
    !workers = gangSize g
    !base    = n `quotInt` workers
    !extra   = n `remInt`  workers

    sliceFor i = Seq.slice arr (startOf i) (sizeOf i)

    startOf i
      | i >= extra = base * i + extra
      | otherwise  = (base + 1) * i

    sizeOf i
      | i >= extra = base
      | otherwise  = base + 1
-- | Concatenate the per-thread vectors of a distributed array into one
--   vector (delegates to 'joinD_impl').  The 'Distribution' argument is
--   ignored.
joinD :: Unbox a => Gang -> Distribution -> Dist (Vector a) -> Vector a
joinD gang _dist dvec = joinD_impl gang dvec
-- | Worker for 'joinD'.  Allocates the result with 'Seq.new' and has each
--   thread copy its local vector into it.  Write offsets and the total size
--   come from scanning the local lengths with @(+)@ from 0.  The whole
--   computation is wrapped in 'checkGangD' with the gang and the input
--   (presumably a consistency check between the two).
joinD_impl :: forall a. Unbox a => Gang -> Dist (Vector a) -> Vector a
joinD_impl g !darr = checkGangD (here "joinD") g darr $
    Seq.new total (\dst -> zipWithDST_ g (copyPiece dst) offsets darr)
  where
    -- Per-thread write offsets, and the length of the result.
    (!offsets, !total) = scanD g (+) 0 (lengthD darr)

    copyPiece :: forall s. MVector s a -> Int -> Vector a -> DistST s ()
    copyPiece dst off piece =
      stToDistST $ Seq.copy (mslice off (Seq.length piece) dst) piece
-- | Split an array evenly over the gang, transform the distributed pieces
--   with @f@, and concatenate the results back into a single vector.
splitJoinD
  :: (Unbox a, Unbox b)
  => Gang
  -> (Dist (Vector a) -> Dist (Vector b))
  -> Vector a
  -> Vector b
splitJoinD g f !xs = joinD_impl g . f . splitD_impl g $ xs
-- | Like 'joinD_impl', but leaves the concatenation in a freshly allocated
--   mutable vector in 'ST' instead of returning an immutable one.
joinDM :: Unbox a => Gang -> Dist (Vector a) -> ST s (MVector s a)
joinDM g darr = checkGangD (here "joinDM") g darr $ do
    dst <- Seq.newM total
    zipWithDST_ g (copyPiece dst) offsets darr
    return dst
  where
    -- Per-thread write offsets, and the length of the result.
    (!offsets, !total) = scanD g (+) 0 (lengthD darr)

    copyPiece dst off piece =
      stToDistST $ Seq.copy (mslice off (Seq.length piece) dst) piece
-- | Allocate a vector of length @joinLengthD g darr@ and let every thread
--   apply 'Seq.mpermute' with its local elements and local index vector.
--   This is presumably a scatter (element @j@ goes to position @is ! j@);
--   confirm against 'Seq.mpermute'.
permuteD :: forall a. Unbox a => Gang -> Dist (Vector a) -> Dist (Vector Int) -> Vector a
permuteD g darr dis = Seq.new total (\dst -> zipWithDST_ g (scatter dst) darr dis)
  where
    total = joinLengthD g darr

    scatter :: forall s. MVector s a -> Vector a -> Vector Int -> DistST s ()
    scatter dst xs ixs = stToDistST $ Seq.mpermute dst xs ixs
-- | Each thread applies 'Seq.bpermute' to the shared source vector @src@
--   using its own local index vector.
bpermuteD :: Unbox a => Gang -> Vector a -> Dist (Vector Int) -> Dist (Vector a)
bpermuteD g !src dixs = mapD g (\ixs -> Seq.bpermute src ixs) dixs
-- | Concatenate @darr@ into a fresh mutable vector (via 'joinDM'), let every
--   thread apply its update vector to it with 'Seq.mupdate' (entries are
--   presumably @(index, value)@ pairs), then freeze the result.
atomicUpdateD :: forall a. Unbox a
              => Gang -> Dist (Vector a) -> Dist (Vector (Int,a)) -> Vector a
atomicUpdateD g darr upd = runST (do
    target <- joinDM g darr
    mapDST_ g (applyUpdates target) upd
    Seq.unsafeFreeze target)
  where
    applyUpdates :: forall s. MVector s a -> Vector (Int,a) -> DistST s ()
    applyUpdates target pairs = stToDistST $ Seq.mupdate target pairs
-- | Split a segment descriptor over the gang without cutting any segment:
--   each thread receives a contiguous run of whole segments.
--
--   Per-thread element budgets come from 'splitLenD'.  Starting from the
--   segment where the previous thread stopped, a thread takes segments
--   while its remaining budget is positive.  The last segment it takes may
--   overshoot the budget.  Empty segments are always taken by the thread
--   currently collecting.
splitSegdD :: Gang -> USegd -> Dist USegd
splitSegdD g !segd = mapD g lengthsToUSegd
                   $ splitAsD g d lens
  where
    -- Number of segments assigned to each thread.
    !d = snd
       . mapAccumLD g takeSegs 0
       . splitLenD g
       $ elementsUSegd segd

    n    = lengthUSegd segd
    lens = lengthsUSegd segd

    -- Starting at segment @i@ with element budget @k@, return the index of
    -- the first segment not taken and the number of segments taken.
    takeSegs !i !k = let !j = go i k
                     in (j, j - i)

    go !i !k
      | i >= n    = i
      | m == 0    = go (i + 1) k   -- empty segments never end a run
      | k <= 0    = i
      | otherwise = go (i + 1) (k - m)
      where
        m = lens ! i
-- | Binary search in a vector assumed sorted in ascending order.  Returns
--   the smallest index @i@ with @ys ! i >= x@, or @Seq.length ys@ if every
--   element is smaller than @x@.
search :: Int -> Vector Int -> Int
search !x ys = go 0 (Seq.length ys)
  where
    -- The answer lies in the index range [i, i + len].
    go !i !len
      | len <= 0       = i
      | (ys ! mid) < x = go (mid + 1) (len - half - 1)
      | otherwise      = go i half
      where
        half = len `shiftR` 1
        mid  = i + half
-- | Compute the part of a segment descriptor covered by the element range
--   starting at @di@ with @dn@ elements.  Returns an unboxed triple:
--
--   * the segment lengths of the piece; a segment cut at either end of the
--     range contributes only the part of it that lies inside the range;
--
--   * the index (in @segd@) of the piece's first segment;
--
--   * the offset of @di@ within segment @k-1@ when the piece begins with a
--     partial segment, and 0 otherwise.
--
--   When @is_last@ is set, the range end is taken to be the last segment
--   (@k' = n@), so every remaining segment is included.
chunk :: USegd -> Int -> Int -> Bool -> (# Vector Int, Int, Int #)
chunk !segd !di !dn is_last
  = (# lens', k - left_len, left_off #)
  where
    !lens' = runST (do
               mlens' <- Seq.newM n'
               -- Partial leading segment (the tail of segment k-1).
               when (left /= 0) $ Seq.write mlens' 0 left
               -- Segments k .. k'-1 with their original lengths.
               Seq.copy (Seq.mdrop left_len mlens')
                        (Seq.slice lens k (k' - k))
               -- Partial trailing segment overrides the copied full length.
               when (right /= 0) $ Seq.write mlens' (n' - 1) right
               Seq.unsafeFreeze mlens')

    lens = lengthsUSegd segd
    idxs = indicesUSegd segd
    n    = Seq.length lens

    -- First segment starting at or after di; for a non-final range, the
    -- first segment starting at or after the range end.
    k  = search di idxs
    k' | is_last   = n
       | otherwise = search (di + dn) idxs

    -- Elements of the range lying before segment k begins.
    left  | k == n    = dn
          | otherwise = min ((idxs ! k) - di) dn

    -- Elements of the range belonging to segment k'-1.
    right | k' == k   = 0
          | otherwise = di + dn - (idxs ! (k' - 1))

    left_len | left == 0 = 0
             | otherwise = 1

    left_off | left == 0 = 0
             | otherwise = di - idxs ! (k - 1)

    n' = left_len + (k' - k)
-- | Split a segment descriptor over the gang by elements rather than by
--   whole segments.  Entry @i@ covers the element range chosen by
--   'splitLenIdxD', so a segment may be cut between entries.  Each entry is
--   @((segd, first_segment), offset)@ as computed by 'chunk'.
splitSegdD' :: Gang -> USegd -> Dist ((USegd,Int),Int)
splitSegdD' g !segd = imapD g mk
                        (splitLenIdxD g (elementsUSegd segd))
  where
    !p = gangSize g

    -- The final entry (index p - 1) is flagged so 'chunk' includes every
    -- remaining segment.
    mk i (dn, di) = case chunk segd di dn (i == p - 1) of
                      (# lens, l, o #) -> ((lengthsToUSegd lens, l), o)
-- | Rebuild one segment descriptor from per-thread descriptors by joining
--   their segment-length vectors and converting the result back with
--   'lengthsToUSegd'.
joinSegD :: Gang -> Dist USegd -> USegd
joinSegD g dsegd = lengthsToUSegd (joinD g unbalanced (mapD g lengthsUSegd dsegd))
-- | Split a flat data array to match a distributed segment descriptor,
--   using the per-thread element counts from 'elementsUSegdD' as the chunk
--   sizes for 'splitAsD'.
splitSD :: Unbox a => Gang -> Dist USegd -> Vector a -> Dist (Vector a)
splitSD g dsegd xs = splitAsD g perThread xs
  where
    perThread = elementsUSegdD dsegd