module Data.Array.Accelerate.Prelude (
zip, zip3, zip4,
unzip, unzip3, unzip4,
foldAll, fold1All,
prescanl, postscanl, prescanr, postscanr,
scanlSeg, scanl'Seg, scanl1Seg, prescanlSeg, postscanlSeg,
scanrSeg, scanr'Seg, scanr1Seg, prescanrSeg, postscanrSeg,
flatten,
fill, enumFromN, enumFromStepN,
gather, gatherIf, scatter, scatterIf,
init, tail, take, drop, slit
) where
import Prelude hiding (
replicate, zip, zip3, unzip, unzip3, map, zipWith, scanl, scanl1, scanr,
scanr1, init, tail, take, drop, filter, max, min, not, fst, snd, curry,
uncurry, fromIntegral, abs, pred )
import qualified Prelude
import Data.Array.Accelerate.Array.Sugar hiding ((!), ignore, shape, size, index)
import Data.Array.Accelerate.Language
import Data.Array.Accelerate.Type
-- | Pair up the elements found at corresponding indices of two arrays.
zip :: (Shape sh, Elt a, Elt b)
    => Acc (Array sh a)
    -> Acc (Array sh b)
    -> Acc (Array sh (a, b))
zip xs ys = zipWith (\x y -> lift (x, y)) xs ys
-- | Combine three arrays element-wise into an array of triples.
--
-- The second and third arrays are first paired with 'zip'; each pair is then
-- unpacked and rebuilt into a triple together with the first element.
zip3 :: forall sh a b c. (Shape sh, Elt a, Elt b, Elt c)
     => Acc (Array sh a)
     -> Acc (Array sh b)
     -> Acc (Array sh c)
     -> Acc (Array sh (a, b, c))
zip3 xs ys zs = zipWith toTriple xs (zip ys zs)
  where
    toTriple :: Exp a -> Exp (b, c) -> Exp (a, b, c)
    toTriple x yz =
      let (y, z) = unlift yz :: (Exp b, Exp c)
      in  lift (x, y, z)
-- | Combine four arrays element-wise into an array of 4-tuples.
--
-- The last three arrays are first combined with 'zip3'; each triple is then
-- unpacked and rebuilt into a 4-tuple together with the first element.
zip4 :: forall sh a b c d. (Shape sh, Elt a, Elt b, Elt c, Elt d)
     => Acc (Array sh a)
     -> Acc (Array sh b)
     -> Acc (Array sh c)
     -> Acc (Array sh d)
     -> Acc (Array sh (a, b, c, d))
zip4 ws xs ys zs = zipWith toQuad ws (zip3 xs ys zs)
  where
    toQuad :: Exp a -> Exp (b, c, d) -> Exp (a, b, c, d)
    toQuad w xyz =
      let (x, y, z) = unlift xyz :: (Exp b, Exp c, Exp d)
      in  lift (w, x, y, z)
-- | Split an array of pairs into the array of first components and the
-- array of second components.
unzip :: (Shape sh, Elt a, Elt b)
      => Acc (Array sh (a, b))
      -> (Acc (Array sh a), Acc (Array sh b))
unzip pairs = (firsts, seconds)
  where
    firsts  = map fst pairs
    seconds = map snd pairs
-- | Split an array of triples into three arrays, one per component.
--
-- Each triple is regrouped as @(a, (b, c))@ so that two applications of
-- 'unzip' can take it apart.
unzip3 :: forall sh a b c. (Shape sh, Elt a, Elt b, Elt c)
       => Acc (Array sh (a, b, c))
       -> (Acc (Array sh a), Acc (Array sh b), Acc (Array sh c))
unzip3 triples = (xs, ys, zs)
  where
    (xs, yzs) = unzip (map nest triples)
    (ys, zs)  = unzip yzs

    nest :: Exp (a, b, c) -> Exp (a, (b, c))
    nest t =
      let (x, y, z) = unlift t :: (Exp a, Exp b, Exp c)
      in  lift (x, lift (y, z) :: Exp (b, c))
-- | Split an array of 4-tuples into four arrays, one per component.
--
-- Each 4-tuple is regrouped as @((a, b), (c, d))@ and then taken apart with
-- three applications of 'unzip'.
unzip4 :: forall sh a b c d. (Shape sh, Elt a, Elt b, Elt c, Elt d)
       => Acc (Array sh (a, b, c, d))
       -> (Acc (Array sh a), Acc (Array sh b), Acc (Array sh c), Acc (Array sh d))
unzip4 quads = (ws, xs, ys, zs)
  where
    (wxs, yzs) = unzip (map regroup quads)
    (ws, xs)   = unzip wxs
    (ys, zs)   = unzip yzs

    regroup :: Exp (a, b, c, d) -> Exp ((a, b), (c, d))
    regroup q =
      let (w, x, y, z) = unlift q :: (Exp a, Exp b, Exp c, Exp d)
      in  lift (lift (w, x) :: Exp (a, b), lift (y, z) :: Exp (c, d))
-- | Reduce every element of an array, whatever its rank, to a single value.
--
-- The array is viewed as a one-dimensional vector ('flatten') and reduced
-- with 'fold', using @e@ as the initial value passed to 'fold'.
foldAll :: (Shape sh, Elt a)
        => (Exp a -> Exp a -> Exp a)
        -> Exp a
        -> Acc (Array sh a)
        -> Acc (Scalar a)
foldAll f e arr = fold f e (flatten arr)
-- | Reduce every element of an array, whatever its rank, to a single value
-- without an initial value.
--
-- The array is viewed as a one-dimensional vector ('flatten') and reduced
-- with 'fold1'.
fold1All :: (Shape sh, Elt a)
         => (Exp a -> Exp a -> Exp a)
         -> Acc (Array sh a)
         -> Acc (Scalar a)
fold1All f arr = fold1 f (flatten arr)
-- | Left-to-right prescan: the vector component of 'scanl'' (its reduction
-- total is discarded).
prescanl :: Elt a
         => (Exp a -> Exp a -> Exp a)
         -> Exp a
         -> Acc (Vector a)
         -> Acc (Vector a)
prescanl f e arr = Prelude.fst (scanl' f e arr)
-- | Left-to-right postscan: an inclusive 'scanl1' in which every result is
-- additionally combined with @e@ on the left.
postscanl :: Elt a
          => (Exp a -> Exp a -> Exp a)
          -> Exp a
          -> Acc (Vector a)
          -> Acc (Vector a)
postscanl f e arr = map (\x -> f e x) (scanl1 f arr)
-- | Right-to-left prescan: the vector component of 'scanr'' (its reduction
-- total is discarded).
prescanr :: Elt a
         => (Exp a -> Exp a -> Exp a)
         -> Exp a
         -> Acc (Vector a)
         -> Acc (Vector a)
prescanr f e arr = Prelude.fst (scanr' f e arr)
-- | Right-to-left postscan: an inclusive 'scanr1' in which every result is
-- additionally combined with @e@ on the right.
postscanr :: Elt a
          => (Exp a -> Exp a -> Exp a)
          -> Exp a
          -> Acc (Vector a)
          -> Acc (Vector a)
postscanr f e arr = map (\x -> f x e) (scanr1 f arr)
-- | Segmented left-to-right scan that also emits the initial value: every
-- segment of @k@ input elements yields @k + 1@ output elements, the first of
-- which is @e@. The result therefore has @size arr + size seg@ elements.
--
-- Strategy: lay the input out in a larger vector that has one extra slot at
-- the head of every segment, put @e@ in those slots, and run 'scanl1Seg' over
-- the enlarged segments (@seg'@, each one element longer).
scanlSeg :: forall a i. (Elt a, Elt i, IsIntegral i)
         => (Exp a -> Exp a -> Exp a)
         -> Exp a
         -> Acc (Vector a)
         -> Acc (Segments i)
         -> Acc (Vector a)
scanlSeg f e arr seg = scans
  where
    -- Inclusive segmented scan over the enlarged layout.
    scans = scanl1Seg f idInjArr seg'
    -- @e@ wherever a head flag is set, the relocated input everywhere else.
    idInjArr = zipWith (\h x -> h ==* 1 ? (fst x, snd x)) headFlags $ zip idsArr arrShifted
    -- 1 at the start offset of every enlarged segment, 0 elsewhere. Each
    -- enlarged segment has length >= 1 (assuming non-negative @seg@), so the
    -- start offsets are distinct and (+) never combines two writes.
    headFlags = permute (+) zerosArr' (\ix -> index1' $ segOffsets' ! ix)
      $ generate (shape seg) (const (1 :: Exp i))
    -- Gather the input into the enlarged layout. Slots that no input element
    -- maps to read index 0; their value is replaced by @e@ above.
    arrShifted = backpermute nSh (\ix -> index1' $ shiftCoords ! ix) arr
    idsArr = generate nSh (const e)
    -- Input index i is written to output position i + (segment of i) + 1:
    -- every segment up to and including its own contributes one extra slot.
    -- Positions that are not written keep 0 (these should be the head slots).
    shiftCoords = permute (+) zerosArr' (ilift1 $ \i -> i + (offsetArr ! index1' i) + 1) coords
    -- [0, 1 .. size arr - 1]
    coords = Prelude.fst $ scanl' (+) 0 onesArr
    -- For each input index, the number of the segment that contains it.
    -- Segment numbers are scattered to their start offsets and then
    -- propagated forward with a max-scan.
    -- NOTE(review): several segments with the same start offset (empty
    -- segments) have their numbers summed by (+), and a trailing empty
    -- segment's offset equals @size arr@ (out of range). Confirm that empty
    -- segments are not expected here.
    offsetArr = scanl1 max $ permute (+) zerosArr (\ix -> index1' $ segOffsets ! ix) segIxs
    -- [0, 1 .. size seg - 1]
    segIxs = Prelude.fst $ scanl' (+) 0 $ generate (index1' $ size seg) (const 1)
    -- Start offsets of the enlarged output segments.
    segOffsets' = Prelude.fst $ scanl' (+) 0 seg'
    -- Start offsets of the input segments.
    segOffsets = Prelude.fst $ scanl' (+) 0 seg
    -- Output length: one extra slot per segment.
    nSh = index1' $ size arr + size seg
    seg' = map (+ 1) seg
    onesArr = generate (shape arr) (const 1)
    zerosArr = generate (shape arr) (const 0)
    zerosArr' = generate nSh (const 0)
-- | Segmented left-to-right exclusive scan, paired with per-segment totals.
--
-- The first component has the same length as @arr@: within each segment,
-- element @j@ is the scan of @e@ followed by the segment's first @j@ elements.
-- The second component has one entry per segment: the segment's inclusive
-- scan at its last element, combined with @e@ on the right.
scanl'Seg :: forall a i. (Elt a, Elt i, IsIntegral i)
          => (Exp a -> Exp a -> Exp a)
          -> Exp a
          -> Acc (Vector a)
          -> Acc (Segments i)
          -> (Acc (Vector a), Acc (Vector a))
scanl'Seg f e arr seg = (scans, sums)
  where
    -- Exclusive scan = inclusive segmented scan of the input shifted right
    -- by one, with @e@ injected at every segment head.
    scans      = scanl1Seg f idInjArr seg
    idInjArr   = zipWith (\h x -> h ==* 1 ? (fst x, snd x)) headFlags $ zip idsArr arrShifted
    -- 1 at the first index of every segment, 0 elsewhere. Empty segments share
    -- a start offset with their successor, so the combiner overwrites with 1
    -- instead of summing. Summing would yield a flag of 2, which fails the
    -- `==* 1` test above.
    headFlags  = permute (\_ _ -> 1) zerosArr (\ix -> index1' $ segOffsets ! ix)
               $ generate (shape seg) (const (1 :: Exp i))
    segOffsets = Prelude.fst $ scanl' (+) 0 seg
    -- Element i reads the input at i - 1. Element 0 reads itself, but index 0
    -- is always a head (segOffsets starts at 0), so that value is replaced.
    arrShifted = backpermute (shape arr) (ilift1 $ \i -> i ==* 0 ? (i, i - 1)) arr
    idsArr     = generate (shape arr) (const e)
    zerosArr   = generate (shape arr) (const 0)
    -- Per-segment totals, read from the inclusive scan at each segment's
    -- last index.
    -- NOTE(review): for an empty segment, sumOffsets points at the previous
    -- segment's last element (or at -1 for a leading empty segment). Confirm
    -- that empty segments are not expected here.
    sums       = map (`f` e)
               $ backpermute (shape seg) (\ix -> index1' $ sumOffsets ! ix)
               $ scanl1Seg f arr seg
    sumOffsets = map (subtract 1) $ scanl1 (+) seg
-- | Segmented inclusive left-to-right scan without an initial value.
--
-- Each element is tagged with a head flag from 'mkHeadFlags'. An ordinary
-- 'scanl1' is then run with the flag-aware operator from 'mkSegApply', so
-- accumulation restarts at every segment head. Finally the flags are dropped.
scanl1Seg :: (Elt a, Elt i, IsIntegral i)
          => (Exp a -> Exp a -> Exp a)
          -> Acc (Vector a)
          -> Acc (Segments i)
          -> Acc (Vector a)
scanl1Seg f arr seg = map snd scanned
  where
    flagged = zip (mkHeadFlags seg) arr
    scanned = scanl1 (mkSegApply f) flagged
-- | Segmented left-to-right prescan: the scan component of 'scanl'Seg' (the
-- per-segment totals are discarded).
prescanlSeg :: (Elt a, Elt i, IsIntegral i)
            => (Exp a -> Exp a -> Exp a)
            -> Exp a
            -> Acc (Vector a)
            -> Acc (Segments i)
            -> Acc (Vector a)
prescanlSeg f e arr seg = Prelude.fst (scanl'Seg f e arr seg)
-- | Segmented left-to-right postscan: 'scanl1Seg' with every result
-- additionally combined with @e@ on the left.
postscanlSeg :: (Elt a, Elt i, IsIntegral i)
             => (Exp a -> Exp a -> Exp a)
             -> Exp a
             -> Acc (Vector a)
             -> Acc (Segments i)
             -> Acc (Vector a)
postscanlSeg f e arr seg = map (\x -> f e x) (scanl1Seg f arr seg)
-- | Segmented right-to-left scan that also emits the initial value: every
-- segment of @k@ input elements yields @k + 1@ output elements, the last of
-- which is @e@. The result therefore has @size arr + size seg@ elements.
--
-- Mirror of 'scanlSeg': the input is laid out in a larger vector with one
-- extra slot at the end of every segment. Those slots receive @e@, and
-- 'scanr1Seg' runs over the enlarged segments (@seg'@, each one element
-- longer).
scanrSeg :: forall a i. (Elt a, Elt i, IsIntegral i)
         => (Exp a -> Exp a -> Exp a)
         -> Exp a
         -> Acc (Vector a)
         -> Acc (Segments i)
         -> Acc (Vector a)
scanrSeg f e arr seg = scans
  where
    -- Inclusive segmented right scan over the enlarged layout.
    scans       = scanr1Seg f idInjArr seg'
    -- @e@ wherever a tail flag is set, the relocated input everywhere else.
    idInjArr    = zipWith (\h x -> h ==* 1 ? (fst x, snd x)) tailFlags $ zip idsArr arrShifted
    -- 1 at the last slot of every enlarged segment, 0 elsewhere.
    -- segOffsets' holds the exclusive end offsets, so the last slot is one
    -- before the end offset.
    tailFlags   = permute (+) zerosArr' (\ix -> index1' $ (segOffsets' ! ix) - 1)
                $ generate (shape seg) (const (1 :: Exp i))
    -- Gather the input into the enlarged layout. Slots that no input element
    -- maps to read index 0; their value is replaced by @e@ above.
    arrShifted  = backpermute nSh (\ix -> index1' $ shiftCoords ! ix) arr
    idsArr      = generate nSh (const e)
    -- Input index i is written to output position i + (segment of i). The
    -- extra slots come after each segment, so only the preceding segments
    -- shift an element.
    shiftCoords = permute (+) zerosArr' (ilift1 $ \i -> i + (offsetArr ! index1' i)) coords
    -- [0, 1 .. size arr - 1]
    coords      = Prelude.fst $ scanl' (+) 0 onesArr
    -- For each input index, the number of the segment that contains it:
    -- segment numbers are scattered to their start offsets, then max-scanned.
    offsetArr   = scanl1 max $ permute (+) zerosArr (\ix -> index1' $ segOffsets ! ix) segIxs
    -- [0, 1 .. size seg - 1]
    segIxs      = Prelude.fst $ scanl' (+) 0 $ generate (shape seg) (const 1)
    -- Exclusive end offsets of the enlarged output segments.
    segOffsets' = scanl1 (+) seg'
    -- Start offsets of the input segments.
    segOffsets  = Prelude.fst $ scanl' (+) 0 seg
    -- Output length: one extra slot per segment.
    nSh         = index1' $ size arr + size seg
    seg'        = map (+ 1) seg
    onesArr     = generate (shape arr) (const 1)
    zerosArr    = generate (shape arr) (const 0)
    zerosArr'   = generate nSh (const 0)
-- | Segmented right-to-left exclusive scan, paired with per-segment totals.
--
-- The first component has the same length as @arr@: within each segment,
-- element @j@ is the right scan of the elements after @j@ followed by @e@.
-- The second component has one entry per segment: the segment's inclusive
-- right scan at its first element, combined with @e@ on the right.
scanr'Seg :: forall a i. (Elt a, Elt i, IsIntegral i)
          => (Exp a -> Exp a -> Exp a)
          -> Exp a
          -> Acc (Vector a)
          -> Acc (Segments i)
          -> (Acc (Vector a), Acc (Vector a))
scanr'Seg f e arr seg = (scans, sums)
  where
    -- Exclusive scan = inclusive segmented right scan of the input shifted
    -- left by one, with @e@ injected at every segment tail.
    scans      = scanr1Seg f idInjArr seg
    idInjArr   = zipWith (\t x -> t ==* 1 ? (fst x, snd x)) tailFlags $ zip idsArr arrShifted
    -- 1 at the last index of every segment, 0 elsewhere. An empty segment
    -- shares its end offset with its predecessor, so the combiner overwrites
    -- with 1 instead of summing. Summing would yield a flag of 2, which fails
    -- the `==* 1` test above.
    tailFlags  = permute (\_ _ -> 1) zerosArr (\ix -> index1' $ (segOffsets ! ix) - 1)
               $ generate (shape seg) (const (1 :: Exp i))
    -- Exclusive end offset of every segment.
    segOffsets = scanl1 (+) seg
    -- Element i reads the input at i + 1. The last element reads itself; it
    -- is the tail of the final segment, so that value is replaced by @e@.
    arrShifted = backpermute (shape arr) (ilift1 $ \i -> i ==* (size arr - 1) ? (i, i + 1)) arr
    idsArr     = generate (shape arr) (const e)
    zerosArr   = generate (shape arr) (const 0)
    -- Per-segment totals, read from the inclusive right scan at each
    -- segment's first index.
    -- NOTE(review): a leading empty segment produces a tail flag target of -1
    -- above. Confirm that empty segments are not expected here.
    sums       = map (`f` e) $ backpermute (shape seg) (\ix -> index1' $ sumOffsets ! ix)
               $ scanr1Seg f arr seg
    -- Start offset of every segment.
    sumOffsets = Prelude.fst $ scanl' (+) 0 seg
-- | Segmented inclusive right-to-left scan without an initial value.
--
-- Each element is tagged with a tail flag from 'mkTailFlags'. An ordinary
-- 'scanr1' is then run with the flag-aware operator from 'mkSegApply', so
-- accumulation restarts at every segment tail. Finally the flags are dropped.
scanr1Seg :: (Elt a, Elt i, IsIntegral i)
          => (Exp a -> Exp a -> Exp a)
          -> Acc (Vector a)
          -> Acc (Segments i)
          -> Acc (Vector a)
scanr1Seg f arr seg = map snd scanned
  where
    flagged = zip (mkTailFlags seg) arr
    scanned = scanr1 (mkSegApply f) flagged
-- | Segmented right-to-left prescan: the scan component of 'scanr'Seg' (the
-- per-segment totals are discarded).
prescanrSeg :: (Elt a, Elt i, IsIntegral i)
            => (Exp a -> Exp a -> Exp a)
            -> Exp a
            -> Acc (Vector a)
            -> Acc (Segments i)
            -> Acc (Vector a)
prescanrSeg f e arr seg = Prelude.fst (scanr'Seg f e arr seg)
-- | Segmented right-to-left postscan: 'scanr1Seg' with every result
-- additionally combined with @e@ on the right.
postscanrSeg :: (Elt a, Elt i, IsIntegral i)
             => (Exp a -> Exp a -> Exp a)
             -> Exp a
             -> Acc (Vector a)
             -> Acc (Segments i)
             -> Acc (Vector a)
postscanrSeg f e arr seg = map (\x -> f x e) (scanr1Seg f arr seg)
-- | Head-flag vector for 'scanl1Seg': 1 at the start offset of every segment,
-- 0 elsewhere.
--
-- Its length is the sum of the segment lengths, i.e. the total computed by
-- 'scanl''.
mkHeadFlags :: (Elt i, IsIntegral i) => Acc (Segments i) -> Acc (Segments i)
mkHeadFlags seg = permute markHead blank toHead offsets
  where
    (offsets, total) = scanl' (+) 0 seg
    blank            = generate (index1' (the total)) (const 0)
    -- Any write sets the flag to 1, even when several segments share an offset.
    markHead _ _     = 1
    toHead ix        = index1' (offsets ! ix)
-- | Tail-flag vector for 'scanr1Seg': 1 at the last index of every segment,
-- 0 elsewhere.
--
-- Its length is the sum of the segment lengths.
mkTailFlags :: (Elt i, IsIntegral i) => Acc (Segments i) -> Acc (Segments i)
mkTailFlags seg
  = permute (\_ _ -> 1) zerosArr (ilift1 $ \i -> fromIntegral (segOffsets ! index1' i) - 1) segOffsets
  where
    -- Inclusive scan: entry k is the exclusive end offset of segment k, so
    -- its last element sits at that offset minus one.
    segOffsets = scanl1 (+) seg
    -- Total number of elements. This is computed by reduction rather than by
    -- reading the last scan entry, so an empty segment descriptor gives an
    -- empty flag vector instead of an index of -1.
    len        = the $ fold (+) 0 seg
    zerosArr   = generate (index1' len) (const 0)
-- | Lift a binary operator to (flag, value) pairs so that an ordinary scan
-- behaves as a segmented one.
--
-- * The resulting flag is 1 when either input flag is 1.
-- * The resulting value is the right value alone when the right flag is 1
--   (a segment boundary). Otherwise it is @op@ applied to the left and right
--   values.
mkSegApply :: (Elt e, Elt i, IsIntegral i)
           => (Exp e -> Exp e -> Exp e)
           -> (Exp (i, e) -> Exp (i, e) -> Exp (i, e))
mkSegApply op = combine
  where
    combine left right =
      let leftFlag  = fst left
          leftVal   = snd left
          rightFlag = fst right
          rightVal  = snd right
          newFlag   = fromIntegral $ boolToInt (leftFlag ==* 1 ||* rightFlag ==* 1)
          newVal    = rightFlag ==* 1 ? (rightVal, leftVal `op` rightVal)
      in  lift (newFlag, newVal)
-- | Build a one-dimensional index from any integral expression by first
-- converting it to 'Int'.
index1' :: (Elt i, IsIntegral i) => Exp i -> Exp (Z :. Int)
index1' i = index1 (fromIntegral i)
-- | View an array of any rank as a one-dimensional vector holding all of its
-- elements.
flatten :: (Shape ix, Elt a) => Acc (Array ix a) -> Acc (Array DIM1 a)
flatten arr = reshape (index1 (size arr)) arr
-- | An array of the given shape in which every element is @c@.
fill :: (Shape sh, Elt e) => Exp sh -> Exp e -> Acc (Array sh e)
fill sh c = generate sh (\_ -> c)
-- | Consecutive values starting at @start@ with step 1, arranged in the given
-- shape (see 'enumFromStepN').
enumFromN :: (Shape sh, Elt e, IsNum e) => Exp sh -> Exp e -> Acc (Array sh e)
enumFromN sh start = enumFromStepN sh start 1
-- | An arithmetic sequence starting at @x@ with step @y@, arranged in the
-- given shape.
--
-- A vector of @shapeSize sh@ elements is generated, where element @k@ is
-- @k * y + x@. It is then reshaped to @sh@.
enumFromStepN :: (Shape sh, Elt e, IsNum e)
              => Exp sh
              -> Exp e
              -> Exp e
              -> Acc (Array sh e)
enumFromStepN sh x y = reshape sh (generate (index1 (shapeSize sh)) valueAt)
  where
    valueAt ix = fromIntegral (unindex1 ix) * y + x
-- | Gather elements by index: the result has the shape of @mapV@, and entry
-- @k@ is @inputV@ read at position @mapV ! k@.
gather :: (Elt e)
       => Acc (Vector Int)
       -> Acc (Vector e)
       -> Acc (Vector e)
gather mapV inputV = backpermute (shape mapV) (\ix -> index1 (mapV ! ix)) inputV
-- | Conditional gather.
--
-- At each position @k@, when @pred (maskV ! k)@ holds the result is the
-- gathered element @inputV ! (mapV ! k)@. Otherwise it is @defaultV ! k@.
gatherIf :: (Elt e, Elt e')
         => Acc (Vector Int)
         -> Acc (Vector e)
         -> (Exp e -> Exp Bool)
         -> Acc (Vector e')
         -> Acc (Vector e')
         -> Acc (Vector e')
gatherIf mapV maskV pred defaultV inputV = zipWith select (map pred maskV) candidates
  where
    candidates       = zip (gather mapV inputV) defaultV
    select keep pair = keep ? unlift pair
-- | Scatter elements by index: @inputV ! k@ is written to position
-- @mapV ! k@ of a copy of @defaultV@. Positions that are not targeted keep
-- their default.
--
-- NOTE(review): collisions are combined with 'const' (the incoming value
-- wins). When several sources hit one target, the surviving value depends on
-- the order in which 'permute' applies the writes.
scatter :: (Elt e)
        => Acc (Vector Int)
        -> Acc (Vector e)
        -> Acc (Vector e)
        -> Acc (Vector e)
scatter mapV defaultV inputV = permute const defaultV (\ix -> index1 (mapV ! ix)) inputV
-- | Conditional scatter.
--
-- Like 'scatter', but element @k@ is written only when @pred (maskV ! k)@
-- holds. Otherwise it is sent to 'ignore' and dropped.
scatterIf :: (Elt e, Elt e')
          => Acc (Vector Int)
          -> Acc (Vector e)
          -> (Exp e -> Exp Bool)
          -> Acc (Vector e')
          -> Acc (Vector e')
          -> Acc (Vector e')
scatterIf mapV maskV pred defaultV inputV = permute const defaultV target inputV
  where
    target ix = pred (maskV ! ix) ? (index1 (mapV ! ix), ignore)
-- | Keep the first @n@ elements of a vector.
--
-- Like "Prelude"'s @take@, @n@ is clamped to @[0, size arr]@: asking for
-- more elements than exist returns the whole vector, and a negative count
-- returns an empty one. Nothing is read outside the input.
take :: Elt e => Exp Int -> Acc (Vector e) -> Acc (Vector e)
take n arr = backpermute (index1 n') id arr
  where
    n' = max 0 (min n (size arr))
-- | Remove the first @n@ elements of a vector.
--
-- Like "Prelude"'s @drop@, @n@ is clamped to @[0, size arr]@: dropping more
-- elements than exist gives an empty vector, and a negative count drops
-- nothing.
drop :: Elt e => Exp Int -> Acc (Vector e) -> Acc (Vector e)
drop n arr = backpermute (index1 (size arr - n')) (ilift1 (+ n')) arr
  where
    n' = max 0 (min n (size arr))
-- | All elements of a vector except the last.
--
-- An empty vector yields an empty vector: the requested length is clamped at
-- 0 rather than becoming -1.
init :: Elt e => Acc (Vector e) -> Acc (Vector e)
init arr = take (max 0 (unindex1 (shape arr) - 1)) arr
-- | All elements of a vector except the first; this is 'drop' of one element.
tail :: Elt e => Acc (Vector e) -> Acc (Vector e)
tail arr = drop 1 arr
-- | Extract a slice of @n@ consecutive elements starting at index @i@.
--
-- Entry @k@ of the result is the input at @k + i@. No bounds clamping is
-- performed here.
slit :: Elt e
     => Exp Int
     -> Exp Int
     -> Acc (Vector e)
     -> Acc (Vector e)
slit i n arr = backpermute (index1 n) (ilift1 (\j -> j + i)) arr