module Data.Array.Accelerate.Prelude (
zipWith3, zipWith4,
zip, zip3, zip4,
unzip, unzip3, unzip4,
foldAll, fold1All,
all, any, and, or, sum, product, minimum, maximum,
prescanl, postscanl, prescanr, postscanr,
scanlSeg, scanl'Seg, scanl1Seg, prescanlSeg, postscanlSeg,
scanrSeg, scanr'Seg, scanr1Seg, prescanrSeg, postscanrSeg,
flatten,
fill, enumFromN, enumFromStepN,
filter,
scatter, scatterIf,
gather, gatherIf,
reverse, transpose,
init, tail, take, drop, slit
) where
import Data.Bits
import Data.Bool
import Prelude ((.), ($), (+), (-), (*), const, subtract, id)
import qualified Prelude as P

import Data.Array.Accelerate.Array.Sugar hiding ((!), ignore, shape, size)
import Data.Array.Accelerate.Language
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Type
-- | Combine three arrays element-wise with a ternary function.
--
-- The inputs are first packed into an array of triples by 'zip3'; each triple
-- is then taken apart again and its components handed to @f@.
zipWith3
  :: (Shape sh, Elt a, Elt b, Elt c, Elt d)
  => (Exp a -> Exp b -> Exp c -> Exp d)
  -> Acc (Array sh a)
  -> Acc (Array sh b)
  -> Acc (Array sh c)
  -> Acc (Array sh d)
zipWith3 f as bs cs = map applyTriple (zip3 as bs cs)
  where
    -- Unpack one zipped element and apply the user function to its parts.
    applyTriple abc =
      let (a, b, c) = unlift abc
      in  f a b c
-- | Combine four arrays element-wise with a four-argument function.
--
-- The inputs are first packed into an array of 4-tuples by 'zip4'; each tuple
-- is then taken apart again and its components handed to @f@.
zipWith4
  :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e)
  => (Exp a -> Exp b -> Exp c -> Exp d -> Exp e)
  -> Acc (Array sh a)
  -> Acc (Array sh b)
  -> Acc (Array sh c)
  -> Acc (Array sh d)
  -> Acc (Array sh e)
zipWith4 f as bs cs ds = map applyQuad (zip4 as bs cs ds)
  where
    -- Unpack one zipped element and apply the user function to its parts.
    applyQuad abcd =
      let (a, b, c, d) = unlift abcd
      in  f a b c d
-- | Pair up the elements of two arrays, producing an array of tuples.
--
-- Implemented as a 'zipWith' whose combining function lifts the two scalar
-- expressions into a single tuple expression.
zip
  :: (Shape sh, Elt a, Elt b)
  => Acc (Array sh a)
  -> Acc (Array sh b)
  -> Acc (Array sh (a, b))
zip as bs = zipWith (\a b -> lift (a, b)) as bs
-- | Combine the elements of three arrays into an array of triples.
--
-- The second and third arrays are paired with 'zip' first; the first array is
-- then zipped against that result, flattening the nested pair into a triple.
zip3 :: forall sh. forall a. forall b. forall c. (Shape sh, Elt a, Elt b, Elt c)
     => Acc (Array sh a)
     -> Acc (Array sh b)
     -> Acc (Array sh c)
     -> Acc (Array sh (a, b, c))
zip3 as bs cs = zipWith attach as (zip bs cs)
  where
    -- Prepend @a@ to an already-zipped pair. The annotation pins the
    -- component types of the unlifted pair.
    attach a bc =
      let (b, c) = unlift bc :: (Exp b, Exp c)
      in  lift (a, b, c)
-- | Combine the elements of four arrays into an array of 4-tuples.
--
-- The last three arrays are combined with 'zip3' first; the first array is
-- then zipped against that result, flattening the nesting into a 4-tuple.
zip4 :: forall sh. forall a. forall b. forall c. forall d. (Shape sh, Elt a, Elt b, Elt c, Elt d)
     => Acc (Array sh a)
     -> Acc (Array sh b)
     -> Acc (Array sh c)
     -> Acc (Array sh d)
     -> Acc (Array sh (a, b, c, d))
zip4 as bs cs ds = zipWith attach as (zip3 bs cs ds)
  where
    -- Prepend @a@ to an already-zipped triple. The annotation pins the
    -- component types of the unlifted triple.
    attach a bcd =
      let (b, c, d) = unlift bcd :: (Exp b, Exp c, Exp d)
      in  lift (a, b, c, d)
-- | Split an array of pairs into a pair of arrays.
--
-- Each result is an independent 'map' over the input that projects out one
-- component of every element.
unzip
  :: (Shape sh, Elt a, Elt b)
  => Acc (Array sh (a, b))
  -> (Acc (Array sh a), Acc (Array sh b))
unzip arr = (lefts, rights)
  where
    lefts  = map fst arr
    rights = map snd arr
-- | Split an array of triples into a triple of arrays.
--
-- Each result is an independent 'map' over the input that projects out one
-- component of every element.
unzip3
  :: (Shape sh, Elt a, Elt b, Elt c)
  => Acc (Array sh (a, b, c))
  -> (Acc (Array sh a), Acc (Array sh b), Acc (Array sh c))
unzip3 xs = (map sel1 xs, map sel2 xs, map sel3 xs)
  where
    -- The pattern signatures on the discarded components fix the types that
    -- 'unlift' must produce for them.
    sel1 :: forall a b c. (Elt a, Elt b, Elt c) => Exp (a,b,c) -> Exp a
    sel1 t = let (p, _ :: Exp b, _ :: Exp c) = unlift t in p

    sel2 :: forall a b c. (Elt a, Elt b, Elt c) => Exp (a,b,c) -> Exp b
    sel2 t = let (_ :: Exp a, q, _ :: Exp c) = unlift t in q

    sel3 :: forall a b c. (Elt a, Elt b, Elt c) => Exp (a,b,c) -> Exp c
    sel3 t = let (_ :: Exp a, _ :: Exp b, r) = unlift t in r
-- | Split an array of 4-tuples into a 4-tuple of arrays.
--
-- Each result is an independent 'map' over the input that projects out one
-- component of every element.
unzip4
  :: (Shape sh, Elt a, Elt b, Elt c, Elt d)
  => Acc (Array sh (a, b, c, d))
  -> (Acc (Array sh a), Acc (Array sh b), Acc (Array sh c), Acc (Array sh d))
unzip4 xs = (map sel1 xs, map sel2 xs, map sel3 xs, map sel4 xs)
  where
    -- The pattern signatures on the discarded components fix the types that
    -- 'unlift' must produce for them.
    sel1 :: forall a b c d. (Elt a, Elt b, Elt c, Elt d) => Exp (a,b,c,d) -> Exp a
    sel1 t = let (p, _ :: Exp b, _ :: Exp c, _ :: Exp d) = unlift t in p

    sel2 :: forall a b c d. (Elt a, Elt b, Elt c, Elt d) => Exp (a,b,c,d) -> Exp b
    sel2 t = let (_ :: Exp a, q, _ :: Exp c, _ :: Exp d) = unlift t in q

    sel3 :: forall a b c d. (Elt a, Elt b, Elt c, Elt d) => Exp (a,b,c,d) -> Exp c
    sel3 t = let (_ :: Exp a, _ :: Exp b, r, _ :: Exp d) = unlift t in r

    sel4 :: forall a b c d. (Elt a, Elt b, Elt c, Elt d) => Exp (a,b,c,d) -> Exp d
    sel4 t = let (_ :: Exp a, _ :: Exp b, _ :: Exp c, s) = unlift t in s
-- | Reduce an array of any shape to a single scalar value.
--
-- The array is reshaped into a vector with 'flatten' and then reduced with
-- 'fold', using @f@ as the combining function and @e@ as the initial value.
foldAll
  :: (Shape sh, Elt a)
  => (Exp a -> Exp a -> Exp a)
  -> Exp a
  -> Acc (Array sh a)
  -> Acc (Scalar a)
foldAll f e = fold f e . flatten
-- | Variant of 'foldAll' that takes no initial value.
--
-- The array is reshaped into a vector with 'flatten' and then reduced with
-- 'fold1' using @f@ as the combining function.
fold1All
  :: (Shape sh, Elt a)
  => (Exp a -> Exp a -> Exp a)
  -> Acc (Array sh a)
  -> Acc (Scalar a)
fold1All f = fold1 f . flatten
-- | Test whether every element of an array satisfies a predicate.
--
-- The predicate is mapped over the array and the resulting booleans are
-- combined with 'and'.
all
  :: (Shape sh, Elt e)
  => (Exp e -> Exp Bool)
  -> Acc (Array sh e)
  -> Acc (Scalar Bool)
all f arr = and (map f arr)
-- | Test whether at least one element of an array satisfies a predicate.
--
-- The predicate is mapped over the array and the resulting booleans are
-- combined with 'or'.
any
  :: (Shape sh, Elt e)
  => (Exp e -> Exp Bool)
  -> Acc (Array sh e)
  -> Acc (Scalar Bool)
any f arr = or (map f arr)
-- | Conjunction of all elements of a boolean array: a 'foldAll' with '&&*'
-- starting from @True@.
and
  :: Shape sh
  => Acc (Array sh Bool)
  -> Acc (Scalar Bool)
and arr = foldAll (&&*) (constant True) arr
-- | Disjunction of all elements of a boolean array: a 'foldAll' with '||*'
-- starting from @False@.
or
  :: Shape sh
  => Acc (Array sh Bool)
  -> Acc (Scalar Bool)
or arr = foldAll (||*) (constant False) arr
-- | Sum of all elements of an array: a 'foldAll' with @(+)@ starting from @0@.
sum
  :: (Shape sh, Elt e, IsNum e)
  => Acc (Array sh e)
  -> Acc (Scalar e)
sum arr = foldAll (+) 0 arr
-- | Product of all elements of an array: a 'foldAll' with @(*)@ starting
-- from @1@.
product
  :: (Shape sh, Elt e, IsNum e)
  => Acc (Array sh e)
  -> Acc (Scalar e)
product arr = foldAll (*) 1 arr
-- | Smallest element of an array, computed as a 'fold1All' with 'min'.
--
-- No initial value is supplied to the reduction.
minimum
  :: (Shape sh, Elt e, IsScalar e)
  => Acc (Array sh e)
  -> Acc (Scalar e)
minimum arr = fold1All min arr
-- | Largest element of an array, computed as a 'fold1All' with 'max'.
--
-- No initial value is supplied to the reduction.
maximum
  :: (Shape sh, Elt e, IsScalar e)
  => Acc (Array sh e)
  -> Acc (Scalar e)
maximum arr = fold1All max arr
-- | Left-to-right pre-scan.
--
-- Returns only the first component of the pair produced by @scanl'@; the
-- second component is discarded.
prescanl
  :: Elt a
  => (Exp a -> Exp a -> Exp a)
  -> Exp a
  -> Acc (Vector a)
  -> Acc (Vector a)
prescanl f e arr = P.fst (scanl' f e arr)
-- | Left-to-right post-scan.
--
-- Runs 'scanl1' with @f@ and then combines the seed @e@, as the /left/
-- operand of @f@, with every element of that result.
postscanl
  :: Elt a
  => (Exp a -> Exp a -> Exp a)
  -> Exp a
  -> Acc (Vector a)
  -> Acc (Vector a)
postscanl f e arr = map (\x -> f e x) (scanl1 f arr)
-- | Right-to-left pre-scan.
--
-- Returns only the first component of the pair produced by @scanr'@; the
-- second component is discarded.
prescanr
  :: Elt a
  => (Exp a -> Exp a -> Exp a)
  -> Exp a
  -> Acc (Vector a)
  -> Acc (Vector a)
prescanr f e arr = P.fst (scanr' f e arr)
-- | Right-to-left post-scan.
--
-- Runs 'scanr1' with @f@ and then combines the seed @e@, as the /right/
-- operand of @f@, with every element of that result.
postscanr
  :: Elt a
  => (Exp a -> Exp a -> Exp a)
  -> Exp a
  -> Acc (Vector a)
  -> Acc (Vector a)
postscanr f e arr = map (\x -> f x e) (scanr1 f arr)
-- | Segmented left-to-right scan with a seed value.
--
-- Implemented in terms of 'scanl1Seg': a vector filled with the seed @z@ and
-- long enough to hold one extra element per segment is created, the input
-- data is scattered into it, and the result is scanned using a segment
-- descriptor in which every segment is one element longer.
scanlSeg
  :: (Elt a, Elt i, IsIntegral i)
  => (Exp a -> Exp a -> Exp a)
  -> Exp a
  -> Acc (Vector a)
  -> Acc (Segments i)
  -> Acc (Vector a)
scanlSeg f z vec seg = scanl1Seg f seeded widened
  where
    -- Every segment grows by one to make room for its seed element.
    widened = map (+1) seg

    -- Running total of the head flags: how far each input element has to move
    -- to the right in the widened vector.
    shifts  = scanl1 (+) (mkHeadFlags seg)

    -- Start from a vector consisting only of the seed, then overlay the input
    -- data at the shifted positions. Slots that receive no input keep @z@.
    blank   = fill (index1 (size vec + size seg)) z
    seeded  = permute const
                      blank
                      (\ix -> index1' (unindex1' ix + shifts ! ix))
                      vec
-- | Segmented variant of @scanl'@.
--
-- Returns a pair of vectors, both extracted from the result of 'scanlSeg':
-- the first has the shape of the input vector, the second has one element
-- per segment. The pair is lifted into a single 'Acc' term so that both
-- components share the same underlying computation.
scanl'Seg
  :: forall a i. (Elt a, Elt i, IsIntegral i)
  => (Exp a -> Exp a -> Exp a)
  -> Exp a
  -> Acc (Vector a)
  -> Acc (Segments i)
  -> Acc (Vector a, Vector a)
scanl'Seg f z vec seg = lift (body, sums)
  where
    -- Seeded segmented scan; each of its segments is one element longer than
    -- the corresponding input segment.
    scanned = scanlSeg f z vec seg

    -- Index of the last element of every widened segment: the start of the
    -- widened segment plus the original segment length.
    widened = map (+1) seg
    lastIx  = zipWith (+) seg (P.fst (scanl' (+) 0 widened))
    sums    = backpermute (shape seg) (\ix -> index1' (lastIx ! ix)) scanned

    -- Add one at the end offset of every original segment (offsets that
    -- coincide accumulate through @(+)@), then take the running total. Adding
    -- that total to an output index gives the source index in 'scanned' for
    -- the body.
    ends    = scanl1 (+) seg
    marks   = permute (+)
                      (fill (index1 (size vec + 1)) 0)
                      (\ix -> index1' (ends ! ix))
                      (fill (shape seg) (1 :: Exp i))
    skips   = scanl1 (+) marks

    body    = backpermute (shape vec)
                          (\ix -> index1' (unindex1' ix + skips ! ix))
                          scanned
-- | Segmented left-to-right scan without a seed value.
--
-- Every element is paired with its head flag (see 'mkHeadFlags'), the pairs
-- are scanned with 'scanl1' using the flag-aware operator built by
-- 'segmented', and the flags are dropped again afterwards.
scanl1Seg
  :: (Elt a, Elt i, IsIntegral i)
  => (Exp a -> Exp a -> Exp a)
  -> Acc (Vector a)
  -> Acc (Segments i)
  -> Acc (Vector a)
scanl1Seg f vec seg = P.snd (unzip scanned)
  where
    flagged = zip (mkHeadFlags seg) vec
    scanned = scanl1 (segmented f) flagged
-- | Segmented left-to-right pre-scan.
--
-- Unpacks the pair computed by @scanl'Seg@ and returns its first component.
prescanlSeg
  :: (Elt a, Elt i, IsIntegral i)
  => (Exp a -> Exp a -> Exp a)
  -> Exp a
  -> Acc (Vector a)
  -> Acc (Segments i)
  -> Acc (Vector a)
prescanlSeg f e vec seg = P.fst (unatup2 (scanl'Seg f e vec seg))
-- | Segmented left-to-right post-scan.
--
-- Runs 'scanl1Seg' and then combines the seed @e@, as the /left/ operand of
-- @f@, with every element of that result.
postscanlSeg
  :: (Elt a, Elt i, IsIntegral i)
  => (Exp a -> Exp a -> Exp a)
  -> Exp a
  -> Acc (Vector a)
  -> Acc (Segments i)
  -> Acc (Vector a)
postscanlSeg f e vec seg = map (\x -> f e x) (scanl1Seg f vec seg)
-- | Segmented right-to-left scan with a seed value.
--
-- Implemented in terms of 'scanr1Seg': a vector filled with the seed @z@ and
-- long enough to hold one extra element per segment is created, the input
-- data is scattered into it, and the result is scanned using a segment
-- descriptor in which every segment is one element longer.
scanrSeg
  :: (Elt a, Elt i, IsIntegral i)
  => (Exp a -> Exp a -> Exp a)
  -> Exp a
  -> Acc (Vector a)
  -> Acc (Segments i)
  -> Acc (Vector a)
scanrSeg f z vec seg = scanr1Seg f vec' seg'
  where
    -- Running total of the head flags produced by 'mkHeadFlags'.
    inc  = scanl1 (+) (mkHeadFlags seg)

    -- Every segment grows by one to make room for its seed element.
    seg' = map (+1) seg

    -- Start from a vector consisting only of the seed, then overlay the input
    -- data. Each element moves right by one less than the running flag total
    -- at its position, so elements with a total of one stay where they are.
    -- Slots that receive no input keep @z@.
    vec' = permute const
                   (fill (index1 $ size vec + size seg) z)
                   (\ix -> index1' $ unindex1' ix + inc ! ix - 1)
                   vec
-- | Segmented variant of @scanr'@.
--
-- Returns a pair of vectors, both extracted from the result of 'scanrSeg':
-- the first has the shape of the input vector, the second has one element
-- per segment. The pair is lifted into a single 'Acc' term so that both
-- components share the same underlying computation.
scanr'Seg
  :: forall a i. (Elt a, Elt i, IsIntegral i)
  => (Exp a -> Exp a -> Exp a)
  -> Exp a
  -> Acc (Vector a)
  -> Acc (Segments i)
  -> Acc (Vector a, Vector a)
scanr'Seg f z vec seg = lift (body, sums)
  where
    -- Seeded segmented scan; each of its segments is one element longer than
    -- the corresponding input segment.
    scanned = scanrSeg f z vec seg

    -- Index of the first element of every widened segment.
    widened = map (+1) seg
    firstIx = P.fst (scanl' (+) 0 widened)
    sums    = backpermute (shape seg) (\ix -> index1' (firstIx ! ix)) scanned

    -- Running total of the head flags; adding it to an output index gives the
    -- source index in 'scanned' for the body.
    skips   = scanl1 (+) (mkHeadFlags seg)
    body    = backpermute (shape vec)
                          (\ix -> index1' (unindex1' ix + skips ! ix))
                          scanned
-- | Segmented right-to-left scan without a seed value.
--
-- Every element is paired with its tail flag (see 'mkTailFlags'), the pairs
-- are scanned with 'scanr1' using the flag-aware operator built by
-- 'segmented', and the flags are dropped again afterwards.
scanr1Seg
  :: (Elt a, Elt i, IsIntegral i)
  => (Exp a -> Exp a -> Exp a)
  -> Acc (Vector a)
  -> Acc (Segments i)
  -> Acc (Vector a)
scanr1Seg f vec seg = P.snd (unzip scanned)
  where
    flagged = zip (mkTailFlags seg) vec
    scanned = scanr1 (segmented f) flagged
-- | Segmented right-to-left pre-scan.
--
-- Unpacks the pair computed by @scanr'Seg@ and returns its first component.
prescanrSeg
  :: (Elt a, Elt i, IsIntegral i)
  => (Exp a -> Exp a -> Exp a)
  -> Exp a
  -> Acc (Vector a)
  -> Acc (Segments i)
  -> Acc (Vector a)
prescanrSeg f e vec seg = P.fst (unatup2 (scanr'Seg f e vec seg))
-- | Segmented right-to-left post-scan.
--
-- Runs 'scanr1Seg' and then combines the seed @e@, as the /left/ operand of
-- @f@, with every element of that result.
--
-- NOTE(review): the unsegmented 'postscanr' applies the seed as the /right/
-- operand (@f x e@). The two only agree when the operand order does not
-- matter for @f@ and @e@ -- confirm which order is intended here.
postscanrSeg
  :: (Elt a, Elt i, IsIntegral i)
  => (Exp a -> Exp a -> Exp a)
  -> Exp a
  -> Acc (Vector a)
  -> Acc (Segments i)
  -> Acc (Vector a)
postscanrSeg f e vec seg = map (\x -> f e x) (scanr1Seg f vec seg)
-- | Build a head-flags vector from a segment descriptor.
--
-- @scanl'@ over the descriptor yields a vector of per-segment offsets
-- together with a scalar total. A one is added at each offset into a
-- zero-filled vector that is one element longer than that total, and the
-- final element is then removed with 'init'. Because the scatter combines
-- with @(+)@, a position targeted by several segments holds a value greater
-- than one.
mkHeadFlags :: (Elt i, IsIntegral i) => Acc (Segments i) -> Acc (Segments i)
mkHeadFlags seg = init (permute (+) blank (\ix -> index1' (starts ! ix)) marks)
  where
    (starts, total) = scanl' (+) 0 seg
    blank           = fill (index1' (the total + 1)) 0
    marks           = fill (index1 (size starts)) 1
-- | Build a tail-flags vector from a segment descriptor.
--
-- @scanr'@ over the descriptor yields, for every segment, the combined
-- length of the segments to its right, together with a scalar total. A one
-- is added at position @total - 1 - offset@ for every segment into a
-- zero-filled vector that is one element longer than the total, and the final
-- element is then removed with 'init'. Because the scatter combines with
-- @(+)@, a position targeted by several segments holds a value greater than
-- one.
mkTailFlags :: (Elt i, IsIntegral i) => Acc (Segments i) -> Acc (Segments i)
mkTailFlags seg
  = init
  $ permute (+) zeros (\ix -> index1' (the len - 1 - offset ! ix)) ones
  where
    (offset, len) = scanr' (+) 0 seg
    zeros         = fill (index1' $ the len + 1) 0
    ones          = fill (index1  $ size offset) 1
-- | Turn a binary operator on values into one on (flag, value) pairs, for use
-- with the unsegmented scans.
--
-- The resulting flag is the bitwise or of both flags. When the flag of the
-- second operand is non-zero its value is passed through unchanged;
-- otherwise @f@ is applied to the two values.
segmented
  :: (Elt e, Elt i, IsIntegral i)
  => (Exp e -> Exp e -> Exp e)
  -> Exp (i, e) -> Exp (i, e) -> Exp (i, e)
segmented f a b = lift (flagL .|. flagR, (flagR /=* 0) ? (valR, f valL valR))
  where
    (flagL, valL) = unlift a
    (flagR, valR) = unlift b
-- | Build a one-dimensional index from any integral scalar expression; the
-- value is converted with 'fromIntegral' first.
index1' :: (Elt i, IsIntegral i) => Exp i -> Exp DIM1
index1' = lift . (Z :.) . fromIntegral
-- | Extract the component of a one-dimensional index and convert it to the
-- requested integral type with 'fromIntegral'.
unindex1' :: (Elt i, IsIntegral i) => Exp DIM1 -> Exp i
unindex1' ix = fromIntegral i
  where
    Z :. i = unlift ix
-- | View an array of any shape as a vector: the array is reshaped to a
-- one-dimensional array whose length is the array's 'size'.
flatten :: (Shape ix, Elt a) => Acc (Array ix a) -> Acc (Vector a)
flatten arr = reshape linear arr
  where
    linear = index1 (size arr)
-- | Create an array of the given shape in which every element is the same
-- value. Implemented as a 'generate' whose function ignores the index.
fill :: (Shape sh, Elt e) => Exp sh -> Exp e -> Acc (Array sh e)
fill sh c = generate sh (\_ -> c)
-- | Create an array of the given shape holding successive values starting at
-- the given one. This is 'enumFromStepN' with a step of @1@.
enumFromN :: (Shape sh, Elt e, IsNum e) => Exp sh -> Exp e -> Acc (Array sh e)
enumFromN sh start = enumFromStepN sh start 1
-- | Create an array of the given shape holding an arithmetic progression.
--
-- A vector with as many elements as the shape contains is generated, where
-- the element at linear position @k@ is @k * y + x@, and the vector is then
-- reshaped to @sh@.
enumFromStepN
  :: (Shape sh, Elt e, IsNum e)
  => Exp sh   -- ^ shape of the result
  -> Exp e    -- ^ @x@: value of the first element
  -> Exp e    -- ^ @y@: difference between consecutive elements
  -> Acc (Array sh e)
enumFromStepN sh x y = reshape sh progression
  where
    progression = generate (index1 (shapeSize sh)) element
    element ix  = (fromIntegral (unindex1 ix :: Exp Int) * y) + x
-- | Keep the elements of a vector selected by a predicate.
--
-- The predicate results are converted to integer flags with 'boolToInt'.
-- @scanl'@ over the flags supplies a destination index for every element and
-- the length of the result. Elements whose flag is zero are dropped; all
-- others are written to their destination index.
filter
  :: Elt a
  => (Exp a -> Exp Bool)
  -> Acc (Vector a)
  -> Acc (Vector a)
filter p arr = permute const shell destination arr
  where
    flags            = map (boolToInt . p) arr
    (slots, count)   = scanl' (+) 0 flags

    -- Array of the final length that serves as the permutation target. It is
    -- initialised from the input itself.
    shell            = backpermute (index1 (the count)) id arr

    destination ix   = (flags ! ix ==* 0) ? (ignore, index1 (slots ! ix))
-- | Copy elements from an input vector according to a map of source indices.
--
-- The result has the shape of the map; the element at position @ix@ is read
-- from the input at index @mapV ! ix@.
gather
  :: (Elt e)
  => Acc (Vector Int)   -- ^ source index for every output position
  -> Acc (Vector e)     -- ^ input vector
  -> Acc (Vector e)
gather mapV inputV =
  backpermute (shape mapV) (\ix -> lift (Z :. (mapV ! ix))) inputV
-- | Conditional 'gather'.
--
-- The input is gathered through the map and zipped with the default vector.
-- For every position the predicate is evaluated on the corresponding mask
-- element: where it holds the gathered value is selected, otherwise the
-- default value.
gatherIf
  :: (Elt e, Elt e')
  => Acc (Vector Int)       -- ^ source index for every output position
  -> Acc (Vector e)         -- ^ mask vector
  -> (Exp e -> Exp Bool)    -- ^ predicate applied to the mask elements
  -> Acc (Vector e')        -- ^ default values
  -> Acc (Vector e')        -- ^ input vector
  -> Acc (Vector e')
gatherIf mapV maskV pred defaultV inputV =
  zipWith (\keep choices -> keep ? unlift choices) selected candidates
  where
    selected   = map pred maskV
    -- Pairs of (gathered value, default value).
    candidates = zip (gather mapV inputV) defaultV
-- | Copy elements of an input vector to the positions given by a map of
-- destination indices.
--
-- The input element at position @ix@ is written to index @mapV ! ix@ of the
-- result. The default vector supplies the initial contents of the result.
scatter
  :: (Elt e)
  => Acc (Vector Int)   -- ^ destination index for every input position
  -> Acc (Vector e)     -- ^ default values
  -> Acc (Vector e)     -- ^ input vector
  -> Acc (Vector e)
scatter mapV defaultV inputV =
  permute const defaultV (\ix -> lift (Z :. (mapV ! ix))) inputV
-- | Conditional 'scatter'.
--
-- The input element at position @ix@ is written to index @mapV ! ix@ of the
-- result only if the predicate holds for the mask element at @ix@; otherwise
-- that input element is ignored. The default vector supplies the initial
-- contents of the result.
scatterIf
  :: (Elt e, Elt e')
  => Acc (Vector Int)       -- ^ destination index for every input position
  -> Acc (Vector e)         -- ^ mask vector
  -> (Exp e -> Exp Bool)    -- ^ predicate applied to the mask elements
  -> Acc (Vector e')        -- ^ default values
  -> Acc (Vector e')        -- ^ input vector
  -> Acc (Vector e')
scatterIf mapV maskV pred defaultV inputV =
  permute const defaultV destination inputV
  where
    destination ix =
      let wanted = pred (maskV ! ix)
          target = lift (Z :. (mapV ! ix))
      in  wanted ? (target, ignore)
-- | Reverse the elements of a vector.
--
-- The result has the same shape as the input; the element at position @i@ is
-- read from position @len - i - 1@ of the input, where @len@ is the length of
-- the vector.
reverse :: Elt e => Acc (Vector e) -> Acc (Vector e)
reverse xs = backpermute (shape xs) (ilift1 mirror) xs
  where
    len      = unindex1 (shape xs)
    -- Source position for output position @i@.
    mirror i = len - i - 1
-- | Transpose a two-dimensional array.
--
-- Both the extent of the result and the index mapping used to read from the
-- input are obtained by exchanging the two components of a 2D index.
transpose :: Elt e => Acc (Array DIM2 e) -> Acc (Array DIM2 e)
transpose mat = backpermute (swap (shape mat)) swap mat
  where
    swap = lift1 exchange

    exchange :: Z :. Exp Int :. Exp Int -> Z :. Exp Int :. Exp Int
    exchange (Z :. x :. y) = Z :. y :. x
-- | Take a prefix of a vector: the result has length @n@ and reads each
-- element from the same position of the input.
--
-- No check is made here that @n@ lies within the length of the input.
take :: Elt e => Exp Int -> Acc (Vector e) -> Acc (Vector e)
take n arr = backpermute (index1 n') id arr
  where
    -- Route the length through a scalar array before using it as the extent.
    n' = the (unit n)
-- | Drop a prefix of a vector: the result is @n@ elements shorter than the
-- input, and the element at position @i@ is read from position @i + n@.
--
-- No check is made here that @n@ lies within the length of the input.
drop :: Elt e => Exp Int -> Acc (Vector e) -> Acc (Vector e)
drop n arr = backpermute shorter advance arr
  where
    -- Route the count through a scalar array before using it.
    n'      = the (unit n)
    shorter = ilift1 (subtract n') (shape arr)
    advance = ilift1 (+ n')
-- | All elements of a vector except the last: a 'take' of one less than the
-- length of the input.
--
-- The empty vector is not treated specially here.
init :: Elt e => Acc (Vector e) -> Acc (Vector e)
init arr = take (unindex1 (shape arr) - 1) arr
-- | All elements of a vector except the first: the result is one element
-- shorter than the input, and the element at position @i@ is read from
-- position @i + 1@.
--
-- The empty vector is not treated specially here.
tail :: Elt e => Acc (Vector e) -> Acc (Vector e)
tail arr = backpermute shorter next arr
  where
    shorter = ilift1 (subtract 1) (shape arr)
    next    = ilift1 (+1)
-- | Extract a contiguous sub-vector: the result has length @n@, and the
-- element at position @k@ is read from position @k + i@ of the input.
--
-- No check is made here that the requested range lies within the input.
slit :: Elt e => Exp Int -> Exp Int -> Acc (Vector e) -> Acc (Vector e)
slit i n arr = backpermute (index1 n') (ilift1 (+ i')) arr
  where
    -- Route both parameters through scalar arrays before using them.
    i' = the (unit i)
    n' = the (unit n)