module Data.Array.Accelerate.Prelude (
zip, unzip,
foldAll, fold1All,
prescanl, postscanl, prescanr, postscanr,
scanlSeg, scanlSeg', scanl1Seg, prescanlSeg, postscanlSeg,
scanrSeg, scanrSeg', scanr1Seg, prescanrSeg, postscanrSeg
) where
import Prelude hiding (replicate, zip, unzip, map, scanl, scanl1, scanr, scanr1, zipWith,
filter, max, min, not, fst, snd, curry, uncurry)
import qualified Prelude
import Data.Array.Accelerate.Array.Sugar hiding ((!), ignore, shape, size, index)
import Data.Array.Accelerate.Language
zip :: (Shape sh, Elt a, Elt b)
=> Acc (Array sh a)
-> Acc (Array sh b)
-> Acc (Array sh (a, b))
zip = zipWith (curry lift)
unzip :: (Shape sh, Elt a, Elt b)
=> Acc (Array sh (a, b))
-> (Acc (Array sh a), Acc (Array sh b))
unzip arr = (map fst arr, map snd arr)
foldAll :: (Shape sh, Elt a)
=> (Exp a -> Exp a -> Exp a)
-> Exp a
-> Acc (Array sh a)
-> Acc (Scalar a)
foldAll f e arr = fold f e (reshape (index1 $ size arr) arr)
fold1All :: (Shape sh, Elt a)
=> (Exp a -> Exp a -> Exp a)
-> Acc (Array sh a)
-> Acc (Scalar a)
fold1All f arr = fold1 f (reshape (index1 $ size arr) arr)
prescanl :: Elt a
=> (Exp a -> Exp a -> Exp a)
-> Exp a
-> Acc (Vector a)
-> Acc (Vector a)
prescanl f e = Prelude.fst . scanl' f e
postscanl :: Elt a
=> (Exp a -> Exp a -> Exp a)
-> Exp a
-> Acc (Vector a)
-> Acc (Vector a)
postscanl f e = map (e `f`) . scanl1 f
prescanr :: Elt a
=> (Exp a -> Exp a -> Exp a)
-> Exp a
-> Acc (Vector a)
-> Acc (Vector a)
prescanr f e = Prelude.fst . scanr' f e
postscanr :: Elt a
=> (Exp a -> Exp a -> Exp a)
-> Exp a
-> Acc (Vector a)
-> Acc (Vector a)
postscanr f e = map (`f` e) . scanr1 f
scanlSeg :: Elt a
=> (Exp a -> Exp a -> Exp a)
-> Exp a
-> Acc (Vector a)
-> Acc Segments
-> Acc (Vector a)
scanlSeg f e arr seg = scans
where
scans = scanl1Seg f idInjArr seg'
idInjArr = zipWith (\h x -> h ==* 1 ? (fst x, snd x)) headFlags $ zip idsArr arrShifted
headFlags = permute (+) zerosArr' (\ix -> index1 $ segOffsets' ! ix)
$ generate (shape seg) (const 1)
arrShifted = backpermute nSh (\ix -> index1 $ shiftCoords ! ix) arr
idsArr = generate nSh (const e)
shiftCoords = permute (+) zerosArr' (ilift1 $ \i -> i + (offsetArr ! index1 i) + 1) coords
coords = Prelude.fst $ scanl' (+) 0 onesArr
offsetArr = scanl1 max $ permute (+) zerosArr (\ix -> index1 $ segOffsets ! ix) segIxs
segIxs = Prelude.fst $ scanl' (+) 0 $ generate (index1 $ size seg) (const 1)
segOffsets' = Prelude.fst $ scanl' (+) 0 seg'
segOffsets = Prelude.fst $ scanl' (+) 0 seg
nSh = index1 $ size arr + size seg
seg' = map (+ 1) seg
onesArr = generate (shape arr) (const 1)
zerosArr = generate (shape arr) (const 0)
zerosArr' = generate nSh (const 0)
scanlSeg' :: Elt a
=> (Exp a -> Exp a -> Exp a)
-> Exp a
-> Acc (Vector a)
-> Acc Segments
-> (Acc (Vector a), Acc (Vector a))
scanlSeg' f e arr seg = (scans, sums)
where
scans = scanl1Seg f idInjArr seg
idInjArr = zipWith (\h x -> h ==* 1 ? (fst x, snd x)) headFlags $ zip idsArr arrShifted
headFlags = permute (+) zerosArr (\ix -> index1 $ segOffsets ! ix)
$ generate (shape seg) (const (1 :: Exp Int))
segOffsets = Prelude.fst $ scanl' (+) 0 seg
arrShifted = backpermute (shape arr) (ilift1 $ \i -> i ==* 0 ? (i, i 1)) arr
idsArr = generate (shape arr) (const e)
zerosArr = generate (shape arr) (const 0)
sums = map (`f` e)
$ backpermute (shape seg) (\ix -> index1 $ sumOffsets ! ix)
$ scanl1Seg f arr seg
sumOffsets = map (subtract 1) $ scanl1 (+) seg
scanl1Seg :: Elt a
=> (Exp a -> Exp a -> Exp a)
-> Acc (Vector a)
-> Acc Segments
-> Acc (Vector a)
scanl1Seg f arr seg = map snd $ scanl1 (mkSegApply f) $ zip (mkHeadFlags seg) arr
prescanlSeg :: Elt a
=> (Exp a -> Exp a -> Exp a)
-> Exp a
-> Acc (Vector a)
-> Acc Segments
-> Acc (Vector a)
prescanlSeg f e arr seg = Prelude.fst $ scanlSeg' f e arr seg
postscanlSeg :: Elt a
=> (Exp a -> Exp a -> Exp a)
-> Exp a
-> Acc (Vector a)
-> Acc Segments
-> Acc (Vector a)
postscanlSeg f e arr seg = map (e `f`) $ scanl1Seg f arr seg
scanrSeg :: Elt a
=> (Exp a -> Exp a -> Exp a)
-> Exp a
-> Acc (Vector a)
-> Acc Segments
-> Acc (Vector a)
scanrSeg f e arr seg = scans
where
scans = scanr1Seg f idInjArr seg'
idInjArr = zipWith (\h x -> h ==* 1 ? (fst x, snd x)) tailFlags $ zip idsArr arrShifted
tailFlags = permute (+) zerosArr' (\ix -> index1 $ (segOffsets' ! ix) 1)
$ generate (shape seg) (const 1)
arrShifted = backpermute nSh (\ix -> index1 $ shiftCoords ! ix) arr
idsArr = generate nSh (const e)
shiftCoords = permute (+) zerosArr' (ilift1 $ \i -> i + (offsetArr ! index1 i)) coords
coords = Prelude.fst $ scanl' (+) 0 onesArr
offsetArr = scanl1 max $ permute (+) zerosArr (\ix -> index1 $ segOffsets ! ix) segIxs
segIxs = Prelude.fst $ scanl' (+) 0 $ generate (shape seg) (const 1)
segOffsets' = scanl1 (+) seg'
segOffsets = Prelude.fst $ scanl' (+) 0 seg
nSh = index1 $ size arr + size seg
seg' = map (+ 1) seg
onesArr = generate (shape arr) (const 1)
zerosArr = generate (shape arr) (const 0)
zerosArr' = generate nSh (const 0)
scanrSeg' :: Elt a
=> (Exp a -> Exp a -> Exp a)
-> Exp a
-> Acc (Vector a)
-> Acc Segments
-> (Acc (Vector a), Acc (Vector a))
scanrSeg' f e arr seg = (scans, sums)
where
scans = scanr1Seg f idInjArr seg
idInjArr = zipWith (\t x -> t ==* 1 ? (fst x, snd x)) tailFlags $ zip idsArr arrShifted
tailFlags = permute (+) zerosArr (\ix -> index1 $ (segOffsets ! ix) 1)
$ generate (shape seg) (const (1 :: Exp Int))
segOffsets = scanl1 (+) seg
arrShifted = backpermute (shape arr) (ilift1 $ \i -> i ==* (size arr 1) ? (i, i + 1)) arr
idsArr = generate (shape arr) (const e)
zerosArr = generate (shape arr) (const 0)
sums = map (`f` e) $ backpermute (shape seg) (\ix -> index1 $ sumOffsets ! ix)
$ scanr1Seg f arr seg
sumOffsets = Prelude.fst $ scanl' (+) 0 seg
scanr1Seg :: Elt a
=> (Exp a -> Exp a -> Exp a)
-> Acc (Vector a)
-> Acc Segments
-> Acc (Vector a)
scanr1Seg f arr seg = map snd $ scanr1 (mkSegApply f) $ zip (mkTailFlags seg) arr
prescanrSeg :: Elt a
=> (Exp a -> Exp a -> Exp a)
-> Exp a
-> Acc (Vector a)
-> Acc Segments
-> Acc (Vector a)
prescanrSeg f e arr seg = Prelude.fst $ scanrSeg' f e arr seg
postscanrSeg :: Elt a
=> (Exp a -> Exp a -> Exp a)
-> Exp a
-> Acc (Vector a)
-> Acc Segments
-> Acc (Vector a)
postscanrSeg f e arr seg = map (`f` e) $ scanr1Seg f arr seg
mkHeadFlags :: Acc (Array DIM1 Int) -> Acc (Array DIM1 Int)
mkHeadFlags seg = permute (\_ _ -> 1) zerosArr (\ix -> index1 (segOffsets ! ix)) segOffsets
where
(segOffsets, len) = scanl' (+) 0 seg
zerosArr = generate (index1 $ the len) (const 0)
mkTailFlags :: Acc (Array DIM1 Int) -> Acc (Array DIM1 Int)
mkTailFlags seg
= permute (\_ _ -> 1) zerosArr (ilift1 $ \i -> (segOffsets ! index1 i) 1) segOffsets
where
segOffsets = scanl1 (+) seg
len = segOffsets ! index1 (size seg 1)
zerosArr = generate (index1 len) (const 0)
mkSegApply :: (Elt e)
=> (Exp e -> Exp e -> Exp e)
-> (Exp (Int, e) -> Exp (Int, e) -> Exp (Int, e))
mkSegApply op = apply
where
apply a b = lift (((aF ==* 1) ||* (bF ==* 1)) ? (1, 0), (bF ==* 1) ? (bV, aV `op` bV))
where
aF = fst a
aV = snd a
bF = fst b
bV = snd b