module Data.Array.Accelerate.Prelude (
indexed,
imap,
zipWith3, zipWith4, zipWith5, zipWith6, zipWith7, zipWith8, zipWith9,
izipWith, izipWith3, izipWith4, izipWith5, izipWith6, izipWith7, izipWith8, izipWith9,
zip, zip3, zip4, zip5, zip6, zip7, zip8, zip9,
unzip, unzip3, unzip4, unzip5, unzip6, unzip7, unzip8, unzip9,
foldAll, fold1All,
all, any, and, or, sum, product, minimum, maximum,
prescanl, postscanl, prescanr, postscanr,
scanlSeg, scanl'Seg, scanl1Seg, prescanlSeg, postscanlSeg,
scanrSeg, scanr'Seg, scanr1Seg, prescanrSeg, postscanrSeg,
flatten,
fill, enumFromN, enumFromStepN,
(++),
filter,
scatter, scatterIf,
gather, gatherIf,
reverse, transpose,
init, tail, take, drop, slit,
compute,
IfThenElse(..),
(?|),
(?), caseof,
iterate,
sfoldl,
Lift(..), Unlift(..),
lift1, lift2, lift3, ilift1, ilift2, ilift3,
fst, afst, snd, asnd, curry, uncurry,
index0, index1, unindex1, index2, unindex2, index3, unindex3,
the, null, length,
) where
import Data.Typeable ( gcast )
import GHC.Base ( Constraint )
import Prelude ( (.), ($), Maybe(..), const, id, fromInteger, flip, undefined, fail )
import qualified Prelude as P
import Data.Array.Accelerate.Analysis.Match
import Data.Array.Accelerate.Array.Sugar hiding ( (!), ignore, shape, size, intersect, toIndex, fromIndex )
import Data.Array.Accelerate.Classes
import Data.Array.Accelerate.Language
import Data.Array.Accelerate.Lift
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Data.Bits
-- | Pair every element of an array with its own index.
indexed :: (Shape sh, Elt a) => Acc (Array sh a) -> Acc (Array sh (sh, a))
indexed xs = zip indices xs
  where
    -- Array of the same shape as @xs@ in which each element is its index.
    indices = generate (shape xs) id
-- | Apply a function to every element of an array together with its index.
imap :: (Shape sh, Elt a, Elt b)
     => (Exp sh -> Exp a -> Exp b)
     -> Acc (Array sh a)
     -> Acc (Array sh b)
imap f xs = zipWith f indices xs
  where
    -- Array of the same shape as @xs@ in which each element is its index.
    indices = generate (shape xs) id
-- | Combine three arrays element-wise with the given function. The extent of
-- the result is the intersection of the argument extents.
zipWith3 :: (Shape sh, Elt a, Elt b, Elt c, Elt d)
         => (Exp a -> Exp b -> Exp c -> Exp d)
         -> Acc (Array sh a)
         -> Acc (Array sh b)
         -> Acc (Array sh c)
         -> Acc (Array sh d)
zipWith3 f as bs cs = generate extent combine
  where
    extent     = shape as `intersect` shape bs `intersect` shape cs
    combine ix = f (as ! ix) (bs ! ix) (cs ! ix)
-- | Combine four arrays element-wise with the given function. The extent of
-- the result is the intersection of the argument extents.
zipWith4 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e)
         => (Exp a -> Exp b -> Exp c -> Exp d -> Exp e)
         -> Acc (Array sh a)
         -> Acc (Array sh b)
         -> Acc (Array sh c)
         -> Acc (Array sh d)
         -> Acc (Array sh e)
zipWith4 f as bs cs ds = generate extent combine
  where
    extent     = shape as `intersect` shape bs `intersect` shape cs
                          `intersect` shape ds
    combine ix = f (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix)
-- | Combine five arrays element-wise with the given function. The extent of
-- the result is the intersection of the argument extents.
zipWith5 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f)
         => (Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f)
         -> Acc (Array sh a)
         -> Acc (Array sh b)
         -> Acc (Array sh c)
         -> Acc (Array sh d)
         -> Acc (Array sh e)
         -> Acc (Array sh f)
zipWith5 f as bs cs ds es = generate extent combine
  where
    extent     = shape as `intersect` shape bs `intersect` shape cs
                          `intersect` shape ds `intersect` shape es
    combine ix = f (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix) (es ! ix)
-- | Combine six arrays element-wise with the given function. The extent of
-- the result is the intersection of the argument extents.
zipWith6 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g)
         => (Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f -> Exp g)
         -> Acc (Array sh a)
         -> Acc (Array sh b)
         -> Acc (Array sh c)
         -> Acc (Array sh d)
         -> Acc (Array sh e)
         -> Acc (Array sh f)
         -> Acc (Array sh g)
zipWith6 f as bs cs ds es fs = generate extent combine
  where
    extent     = shape as `intersect` shape bs `intersect` shape cs
                          `intersect` shape ds `intersect` shape es
                          `intersect` shape fs
    combine ix = f (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix) (es ! ix) (fs ! ix)
-- | Combine seven arrays element-wise with the given function. The extent of
-- the result is the intersection of the argument extents.
zipWith7 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h)
         => (Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f -> Exp g -> Exp h)
         -> Acc (Array sh a)
         -> Acc (Array sh b)
         -> Acc (Array sh c)
         -> Acc (Array sh d)
         -> Acc (Array sh e)
         -> Acc (Array sh f)
         -> Acc (Array sh g)
         -> Acc (Array sh h)
zipWith7 f as bs cs ds es fs gs = generate extent combine
  where
    extent     = shape as `intersect` shape bs `intersect` shape cs
                          `intersect` shape ds `intersect` shape es
                          `intersect` shape fs `intersect` shape gs
    combine ix = f (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix)
                   (es ! ix) (fs ! ix) (gs ! ix)
-- | Combine eight arrays element-wise with the given function. The extent of
-- the result is the intersection of the argument extents.
zipWith8 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i)
         => (Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f -> Exp g -> Exp h -> Exp i)
         -> Acc (Array sh a)
         -> Acc (Array sh b)
         -> Acc (Array sh c)
         -> Acc (Array sh d)
         -> Acc (Array sh e)
         -> Acc (Array sh f)
         -> Acc (Array sh g)
         -> Acc (Array sh h)
         -> Acc (Array sh i)
zipWith8 f as bs cs ds es fs gs hs = generate extent combine
  where
    extent     = shape as `intersect` shape bs `intersect` shape cs
                          `intersect` shape ds `intersect` shape es
                          `intersect` shape fs `intersect` shape gs
                          `intersect` shape hs
    combine ix = f (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix)
                   (es ! ix) (fs ! ix) (gs ! ix) (hs ! ix)
-- | Combine nine arrays element-wise with the given function. The extent of
-- the result is the intersection of the argument extents.
zipWith9 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i, Elt j)
         => (Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f -> Exp g -> Exp h -> Exp i -> Exp j)
         -> Acc (Array sh a)
         -> Acc (Array sh b)
         -> Acc (Array sh c)
         -> Acc (Array sh d)
         -> Acc (Array sh e)
         -> Acc (Array sh f)
         -> Acc (Array sh g)
         -> Acc (Array sh h)
         -> Acc (Array sh i)
         -> Acc (Array sh j)
zipWith9 f as bs cs ds es fs gs hs is = generate extent combine
  where
    extent     = shape as `intersect` shape bs `intersect` shape cs
                          `intersect` shape ds `intersect` shape es
                          `intersect` shape fs `intersect` shape gs
                          `intersect` shape hs `intersect` shape is
    combine ix = f (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix) (es ! ix)
                   (fs ! ix) (gs ! ix) (hs ! ix) (is ! ix)
-- | Combine two arrays element-wise with a function that also receives the
-- index. The extent of the result is the intersection of the argument extents.
izipWith :: (Shape sh, Elt a, Elt b, Elt c)
         => (Exp sh -> Exp a -> Exp b -> Exp c)
         -> Acc (Array sh a)
         -> Acc (Array sh b)
         -> Acc (Array sh c)
izipWith f as bs = generate extent combine
  where
    extent     = shape as `intersect` shape bs
    combine ix = f ix (as ! ix) (bs ! ix)
-- | Combine three arrays element-wise with a function that also receives the
-- index. The extent of the result is the intersection of the argument extents.
izipWith3 :: (Shape sh, Elt a, Elt b, Elt c, Elt d)
          => (Exp sh -> Exp a -> Exp b -> Exp c -> Exp d)
          -> Acc (Array sh a)
          -> Acc (Array sh b)
          -> Acc (Array sh c)
          -> Acc (Array sh d)
izipWith3 f as bs cs = generate extent combine
  where
    extent     = shape as `intersect` shape bs `intersect` shape cs
    combine ix = f ix (as ! ix) (bs ! ix) (cs ! ix)
-- | Combine four arrays element-wise with a function that also receives the
-- index. The extent of the result is the intersection of the argument extents.
izipWith4 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e)
          => (Exp sh -> Exp a -> Exp b -> Exp c -> Exp d -> Exp e)
          -> Acc (Array sh a)
          -> Acc (Array sh b)
          -> Acc (Array sh c)
          -> Acc (Array sh d)
          -> Acc (Array sh e)
izipWith4 f as bs cs ds = generate extent combine
  where
    extent     = shape as `intersect` shape bs `intersect` shape cs
                          `intersect` shape ds
    combine ix = f ix (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix)
-- | Combine five arrays element-wise with a function that also receives the
-- index. The extent of the result is the intersection of the argument extents.
izipWith5 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f)
          => (Exp sh -> Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f)
          -> Acc (Array sh a)
          -> Acc (Array sh b)
          -> Acc (Array sh c)
          -> Acc (Array sh d)
          -> Acc (Array sh e)
          -> Acc (Array sh f)
izipWith5 f as bs cs ds es = generate extent combine
  where
    extent     = shape as `intersect` shape bs `intersect` shape cs
                          `intersect` shape ds `intersect` shape es
    combine ix = f ix (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix) (es ! ix)
-- | Combine six arrays element-wise with a function that also receives the
-- index. The extent of the result is the intersection of the argument extents.
izipWith6 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g)
          => (Exp sh -> Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f -> Exp g)
          -> Acc (Array sh a)
          -> Acc (Array sh b)
          -> Acc (Array sh c)
          -> Acc (Array sh d)
          -> Acc (Array sh e)
          -> Acc (Array sh f)
          -> Acc (Array sh g)
izipWith6 f as bs cs ds es fs = generate extent combine
  where
    extent     = shape as `intersect` shape bs `intersect` shape cs
                          `intersect` shape ds `intersect` shape es
                          `intersect` shape fs
    combine ix = f ix (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix) (es ! ix) (fs ! ix)
-- | Combine seven arrays element-wise with a function that also receives the
-- index. The extent of the result is the intersection of the argument extents.
izipWith7 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h)
          => (Exp sh -> Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f -> Exp g -> Exp h)
          -> Acc (Array sh a)
          -> Acc (Array sh b)
          -> Acc (Array sh c)
          -> Acc (Array sh d)
          -> Acc (Array sh e)
          -> Acc (Array sh f)
          -> Acc (Array sh g)
          -> Acc (Array sh h)
izipWith7 f as bs cs ds es fs gs = generate extent combine
  where
    extent     = shape as `intersect` shape bs `intersect` shape cs
                          `intersect` shape ds `intersect` shape es
                          `intersect` shape fs `intersect` shape gs
    combine ix = f ix (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix)
                      (es ! ix) (fs ! ix) (gs ! ix)
-- | Combine eight arrays element-wise with a function that also receives the
-- index. The extent of the result is the intersection of the argument extents.
izipWith8 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i)
          => (Exp sh -> Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f -> Exp g -> Exp h -> Exp i)
          -> Acc (Array sh a)
          -> Acc (Array sh b)
          -> Acc (Array sh c)
          -> Acc (Array sh d)
          -> Acc (Array sh e)
          -> Acc (Array sh f)
          -> Acc (Array sh g)
          -> Acc (Array sh h)
          -> Acc (Array sh i)
izipWith8 f as bs cs ds es fs gs hs = generate extent combine
  where
    extent     = shape as `intersect` shape bs `intersect` shape cs
                          `intersect` shape ds `intersect` shape es
                          `intersect` shape fs `intersect` shape gs
                          `intersect` shape hs
    combine ix = f ix (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix)
                      (es ! ix) (fs ! ix) (gs ! ix) (hs ! ix)
-- | Combine nine arrays element-wise with a function that also receives the
-- index. The extent of the result is the intersection of the argument extents.
izipWith9 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i, Elt j)
          => (Exp sh -> Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f -> Exp g -> Exp h -> Exp i -> Exp j)
          -> Acc (Array sh a)
          -> Acc (Array sh b)
          -> Acc (Array sh c)
          -> Acc (Array sh d)
          -> Acc (Array sh e)
          -> Acc (Array sh f)
          -> Acc (Array sh g)
          -> Acc (Array sh h)
          -> Acc (Array sh i)
          -> Acc (Array sh j)
izipWith9 f as bs cs ds es fs gs hs is = generate extent combine
  where
    extent     = shape as `intersect` shape bs `intersect` shape cs
                          `intersect` shape ds `intersect` shape es
                          `intersect` shape fs `intersect` shape gs
                          `intersect` shape hs `intersect` shape is
    combine ix = f ix (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix) (es ! ix)
                      (fs ! ix) (gs ! ix) (hs ! ix) (is ! ix)
-- | Combine the elements of two arrays pairwise into an array of pairs.
-- Implemented via 'zipWith'.
zip :: (Shape sh, Elt a, Elt b)
    => Acc (Array sh a)
    -> Acc (Array sh b)
    -> Acc (Array sh (a, b))
zip as bs = zipWith (curry lift) as bs
-- | Combine the elements of three arrays into an array of triples.
-- Implemented via 'zipWith3'.
zip3 :: (Shape sh, Elt a, Elt b, Elt c)
     => Acc (Array sh a)
     -> Acc (Array sh b)
     -> Acc (Array sh c)
     -> Acc (Array sh (a, b, c))
zip3 as bs cs = zipWith3 (\x y z -> lift (x, y, z)) as bs cs
-- | Combine the elements of four arrays into an array of 4-tuples.
-- Implemented via 'zipWith4'.
zip4 :: (Shape sh, Elt a, Elt b, Elt c, Elt d)
     => Acc (Array sh a)
     -> Acc (Array sh b)
     -> Acc (Array sh c)
     -> Acc (Array sh d)
     -> Acc (Array sh (a, b, c, d))
zip4 as bs cs ds = zipWith4 (\w x y z -> lift (w, x, y, z)) as bs cs ds
-- | Combine the elements of five arrays into an array of 5-tuples.
-- Implemented via 'zipWith5'.
zip5 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e)
     => Acc (Array sh a)
     -> Acc (Array sh b)
     -> Acc (Array sh c)
     -> Acc (Array sh d)
     -> Acc (Array sh e)
     -> Acc (Array sh (a, b, c, d, e))
zip5 as bs cs ds es =
  zipWith5 (\x1 x2 x3 x4 x5 -> lift (x1, x2, x3, x4, x5)) as bs cs ds es
-- | Combine the elements of six arrays into an array of 6-tuples.
-- Implemented via 'zipWith6'.
zip6 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f)
     => Acc (Array sh a)
     -> Acc (Array sh b)
     -> Acc (Array sh c)
     -> Acc (Array sh d)
     -> Acc (Array sh e)
     -> Acc (Array sh f)
     -> Acc (Array sh (a, b, c, d, e, f))
zip6 as bs cs ds es fs =
  zipWith6 (\x1 x2 x3 x4 x5 x6 -> lift (x1, x2, x3, x4, x5, x6))
           as bs cs ds es fs
-- | Combine the elements of seven arrays into an array of 7-tuples.
-- Implemented via 'zipWith7'.
zip7 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g)
     => Acc (Array sh a)
     -> Acc (Array sh b)
     -> Acc (Array sh c)
     -> Acc (Array sh d)
     -> Acc (Array sh e)
     -> Acc (Array sh f)
     -> Acc (Array sh g)
     -> Acc (Array sh (a, b, c, d, e, f, g))
zip7 as bs cs ds es fs gs =
  zipWith7 (\x1 x2 x3 x4 x5 x6 x7 -> lift (x1, x2, x3, x4, x5, x6, x7))
           as bs cs ds es fs gs
-- | Combine the elements of eight arrays into an array of 8-tuples.
-- Implemented via 'zipWith8'.
zip8 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h)
     => Acc (Array sh a)
     -> Acc (Array sh b)
     -> Acc (Array sh c)
     -> Acc (Array sh d)
     -> Acc (Array sh e)
     -> Acc (Array sh f)
     -> Acc (Array sh g)
     -> Acc (Array sh h)
     -> Acc (Array sh (a, b, c, d, e, f, g, h))
zip8 as bs cs ds es fs gs hs =
  zipWith8 (\x1 x2 x3 x4 x5 x6 x7 x8 -> lift (x1, x2, x3, x4, x5, x6, x7, x8))
           as bs cs ds es fs gs hs
-- | Combine the elements of nine arrays into an array of 9-tuples.
-- Implemented via 'zipWith9'.
zip9 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i)
     => Acc (Array sh a)
     -> Acc (Array sh b)
     -> Acc (Array sh c)
     -> Acc (Array sh d)
     -> Acc (Array sh e)
     -> Acc (Array sh f)
     -> Acc (Array sh g)
     -> Acc (Array sh h)
     -> Acc (Array sh i)
     -> Acc (Array sh (a, b, c, d, e, f, g, h, i))
zip9 as bs cs ds es fs gs hs is =
  zipWith9 (\x1 x2 x3 x4 x5 x6 x7 x8 x9 -> lift (x1, x2, x3, x4, x5, x6, x7, x8, x9))
           as bs cs ds es fs gs hs is
-- | Split an array of pairs into a pair of arrays, by mapping 'fst' and 'snd'
-- over the input.
unzip :: (Shape sh, Elt a, Elt b)
      => Acc (Array sh (a, b))
      -> (Acc (Array sh a), Acc (Array sh b))
unzip arr = (firsts, seconds)
  where
    firsts  = map fst arr
    seconds = map snd arr
-- | Split an array of triples into three arrays, one per tuple component.
unzip3 :: (Shape sh, Elt a, Elt b, Elt c)
       => Acc (Array sh (a, b, c))
       -> (Acc (Array sh a), Acc (Array sh b), Acc (Array sh c))
unzip3 xs = (map get1 xs, map get2 xs, map get3 xs)
  where
    -- Component projections, built on 'untup3'.
    get1 t = case untup3 t of (a, _, _) -> a
    get2 t = case untup3 t of (_, b, _) -> b
    get3 t = case untup3 t of (_, _, c) -> c
-- | Split an array of 4-tuples into four arrays, one per tuple component.
unzip4 :: (Shape sh, Elt a, Elt b, Elt c, Elt d)
       => Acc (Array sh (a, b, c, d))
       -> (Acc (Array sh a), Acc (Array sh b), Acc (Array sh c), Acc (Array sh d))
unzip4 xs = (map get1 xs, map get2 xs, map get3 xs, map get4 xs)
  where
    -- Component projections, built on 'untup4'.
    get1 t = case untup4 t of (a, _, _, _) -> a
    get2 t = case untup4 t of (_, b, _, _) -> b
    get3 t = case untup4 t of (_, _, c, _) -> c
    get4 t = case untup4 t of (_, _, _, d) -> d
-- | Split an array of 5-tuples into five arrays, one per tuple component.
unzip5 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e)
       => Acc (Array sh (a, b, c, d, e))
       -> (Acc (Array sh a), Acc (Array sh b), Acc (Array sh c), Acc (Array sh d), Acc (Array sh e))
unzip5 xs = (map get1 xs, map get2 xs, map get3 xs, map get4 xs, map get5 xs)
  where
    -- Component projections, built on 'untup5'.
    get1 t = case untup5 t of (a, _, _, _, _) -> a
    get2 t = case untup5 t of (_, b, _, _, _) -> b
    get3 t = case untup5 t of (_, _, c, _, _) -> c
    get4 t = case untup5 t of (_, _, _, d, _) -> d
    get5 t = case untup5 t of (_, _, _, _, e) -> e
-- | Split an array of 6-tuples into six arrays, one per tuple component.
unzip6 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f)
       => Acc (Array sh (a, b, c, d, e, f))
       -> ( Acc (Array sh a), Acc (Array sh b), Acc (Array sh c)
          , Acc (Array sh d), Acc (Array sh e), Acc (Array sh f))
unzip6 xs =
  ( map get1 xs, map get2 xs, map get3 xs
  , map get4 xs, map get5 xs, map get6 xs )
  where
    -- Component projections, built on 'untup6'.
    get1 t = case untup6 t of (a, _, _, _, _, _) -> a
    get2 t = case untup6 t of (_, b, _, _, _, _) -> b
    get3 t = case untup6 t of (_, _, c, _, _, _) -> c
    get4 t = case untup6 t of (_, _, _, d, _, _) -> d
    get5 t = case untup6 t of (_, _, _, _, e, _) -> e
    get6 t = case untup6 t of (_, _, _, _, _, f) -> f
-- | Split an array of 7-tuples into seven arrays, one per tuple component.
unzip7 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g)
       => Acc (Array sh (a, b, c, d, e, f, g))
       -> ( Acc (Array sh a), Acc (Array sh b), Acc (Array sh c)
          , Acc (Array sh d), Acc (Array sh e), Acc (Array sh f)
          , Acc (Array sh g))
unzip7 xs =
  ( map get1 xs, map get2 xs, map get3 xs
  , map get4 xs, map get5 xs, map get6 xs
  , map get7 xs )
  where
    -- Component projections, built on 'untup7'.
    get1 t = case untup7 t of (a, _, _, _, _, _, _) -> a
    get2 t = case untup7 t of (_, b, _, _, _, _, _) -> b
    get3 t = case untup7 t of (_, _, c, _, _, _, _) -> c
    get4 t = case untup7 t of (_, _, _, d, _, _, _) -> d
    get5 t = case untup7 t of (_, _, _, _, e, _, _) -> e
    get6 t = case untup7 t of (_, _, _, _, _, f, _) -> f
    get7 t = case untup7 t of (_, _, _, _, _, _, g) -> g
-- | Split an array of 8-tuples into eight arrays, one per tuple component.
unzip8 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h)
       => Acc (Array sh (a, b, c, d, e, f, g, h))
       -> ( Acc (Array sh a), Acc (Array sh b), Acc (Array sh c)
          , Acc (Array sh d), Acc (Array sh e), Acc (Array sh f)
          , Acc (Array sh g), Acc (Array sh h) )
unzip8 xs =
  ( map get1 xs, map get2 xs, map get3 xs
  , map get4 xs, map get5 xs, map get6 xs
  , map get7 xs, map get8 xs )
  where
    -- Component projections, built on 'untup8'.
    get1 t = case untup8 t of (a, _, _, _, _, _, _, _) -> a
    get2 t = case untup8 t of (_, b, _, _, _, _, _, _) -> b
    get3 t = case untup8 t of (_, _, c, _, _, _, _, _) -> c
    get4 t = case untup8 t of (_, _, _, d, _, _, _, _) -> d
    get5 t = case untup8 t of (_, _, _, _, e, _, _, _) -> e
    get6 t = case untup8 t of (_, _, _, _, _, f, _, _) -> f
    get7 t = case untup8 t of (_, _, _, _, _, _, g, _) -> g
    get8 t = case untup8 t of (_, _, _, _, _, _, _, h) -> h
-- | Split an array of 9-tuples into nine arrays, one per tuple component.
unzip9 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i)
       => Acc (Array sh (a, b, c, d, e, f, g, h, i))
       -> ( Acc (Array sh a), Acc (Array sh b), Acc (Array sh c)
          , Acc (Array sh d), Acc (Array sh e), Acc (Array sh f)
          , Acc (Array sh g), Acc (Array sh h), Acc (Array sh i))
unzip9 xs =
  ( map get1 xs, map get2 xs, map get3 xs
  , map get4 xs, map get5 xs, map get6 xs
  , map get7 xs, map get8 xs, map get9 xs )
  where
    -- Component projections, built on 'untup9'.
    get1 t = case untup9 t of (a, _, _, _, _, _, _, _, _) -> a
    get2 t = case untup9 t of (_, b, _, _, _, _, _, _, _) -> b
    get3 t = case untup9 t of (_, _, c, _, _, _, _, _, _) -> c
    get4 t = case untup9 t of (_, _, _, d, _, _, _, _, _) -> d
    get5 t = case untup9 t of (_, _, _, _, e, _, _, _, _) -> e
    get6 t = case untup9 t of (_, _, _, _, _, f, _, _, _) -> f
    get7 t = case untup9 t of (_, _, _, _, _, _, g, _, _) -> g
    get8 t = case untup9 t of (_, _, _, _, _, _, _, h, _) -> h
    get9 t = case untup9 t of (_, _, _, _, _, _, _, _, i) -> i
-- | Reduce all elements of an array of arbitrary rank to a single scalar,
-- by flattening the array and folding it with the given function and
-- initial value.
foldAll :: (Shape sh, Elt a)
        => (Exp a -> Exp a -> Exp a)
        -> Exp a
        -> Acc (Array sh a)
        -> Acc (Scalar a)
foldAll f e = fold f e . flatten
-- | Variant of 'foldAll' without an initial value: flattens the array and
-- reduces it with 'fold1'.
fold1All :: (Shape sh, Elt a)
         => (Exp a -> Exp a -> Exp a)
         -> Acc (Array sh a)
         -> Acc (Scalar a)
fold1All f = fold1 f . flatten
-- | Check whether every element of the array satisfies the predicate.
all :: (Shape sh, Elt e)
    => (Exp e -> Exp Bool)
    -> Acc (Array sh e)
    -> Acc (Scalar Bool)
all p xs = and (map p xs)
-- | Check whether at least one element of the array satisfies the predicate.
any :: (Shape sh, Elt e)
    => (Exp e -> Exp Bool)
    -> Acc (Array sh e)
    -> Acc (Scalar Bool)
any p xs = or (map p xs)
-- | Conjunction of all elements of a boolean array ('True' as the seed).
and :: Shape sh
    => Acc (Array sh Bool)
    -> Acc (Scalar Bool)
and flags = foldAll (&&) (constant True) flags
-- | Disjunction of all elements of a boolean array ('False' as the seed).
or :: Shape sh
   => Acc (Array sh Bool)
   -> Acc (Scalar Bool)
or flags = foldAll (||) (constant False) flags
-- | Sum of all elements of an array, starting from zero.
sum :: (Shape sh, Num e)
    => Acc (Array sh e)
    -> Acc (Scalar e)
sum xs = foldAll (+) 0 xs
-- | Product of all elements of an array, starting from one.
product :: (Shape sh, Num e)
        => Acc (Array sh e)
        -> Acc (Scalar e)
product xs = foldAll (*) 1 xs
-- | Smallest element of an array, computed with 'fold1All' (no seed value).
minimum :: (Shape sh, Ord e)
        => Acc (Array sh e)
        -> Acc (Scalar e)
minimum xs = fold1All min xs
-- | Largest element of an array, computed with 'fold1All' (no seed value).
maximum :: (Shape sh, Ord e)
        => Acc (Array sh e)
        -> Acc (Scalar e)
maximum xs = fold1All max xs
-- | Left-to-right pre-scan: the first component of the result of 'scanl''.
prescanl :: (Shape sh, Elt a)
         => (Exp a -> Exp a -> Exp a)
         -> Exp a
         -> Acc (Array (sh:.Int) a)
         -> Acc (Array (sh:.Int) a)
prescanl f e xs = P.fst (scanl' f e xs)
-- | Left-to-right post-scan: an inclusive 'scanl1' whose every result is then
-- combined with the initial value on the left.
postscanl :: (Shape sh, Elt a)
          => (Exp a -> Exp a -> Exp a)
          -> Exp a
          -> Acc (Array (sh:.Int) a)
          -> Acc (Array (sh:.Int) a)
postscanl f e xs = map (f e) (scanl1 f xs)
-- | Right-to-left pre-scan: the first component of the result of 'scanr''.
prescanr :: (Shape sh, Elt a)
         => (Exp a -> Exp a -> Exp a)
         -> Exp a
         -> Acc (Array (sh:.Int) a)
         -> Acc (Array (sh:.Int) a)
prescanr f e xs = P.fst (scanr' f e xs)
-- | Right-to-left post-scan: an inclusive 'scanr1' whose every result is then
-- combined with the initial value on the right.
postscanr :: (Shape sh, Elt a)
          => (Exp a -> Exp a -> Exp a)
          -> Exp a
          -> Acc (Array (sh:.Int) a)
          -> Acc (Array (sh:.Int) a)
postscanr f e xs = map (\x -> f x e) (scanr1 f xs)
-- | Segmented left scan with a seed value along the innermost dimension.
--
-- The seed @z@ is injected at the head of each segment and a segmented
-- inclusive scan ('scanl1Seg') is run over the widened array, whose segment
-- lengths are each one larger than in @seg@. If the input or the head-flag
-- vector is empty, the result is just an array filled with @z@.
scanlSeg
    :: forall sh e i. (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int)
    => (Exp e -> Exp e -> Exp e)
    -> Exp e
    -> Acc (Array (sh:.Int) e)
    -> Acc (Segments i)
    -> Acc (Array (sh:.Int) e)
scanlSeg f z arr seg =
  if null arr || null headFlags
    then fill widened z
    else scanl1Seg f seeded widerSeg
  where
    sh :. sz  = unlift (shape arr) :: Exp sh :. Exp Int

    -- Extent with one extra slot per segment, to hold the seed elements.
    widened   = lift (sh :. sz + length seg)

    headFlags = mkHeadFlags seg

    -- Running total of the head flags: how far each input element has to be
    -- shifted right to make room for the seeds inserted before it.
    shiftBy   = scanl1 (+) headFlags

    -- Every segment grows by one element (its seed).
    widerSeg  = map (+1) seg

    -- Array of seeds, overlaid with the (shifted) input elements.
    seeded    = permute const
                        (fill widened z)
                        (\ix -> let sx :. i = unlift ix :: Exp sh :. Exp Int
                                in  lift (sx :. i + fromIntegral (shiftBy ! index1 i)))
                        (take (length headFlags) arr)
-- | Segmented left scan that returns the scanned body and the per-segment
-- reductions as two separate arrays.
--
-- Built on 'scanlSeg': the per-segment totals are read from the last slot of
-- each (widened) segment, and the body is obtained by dropping those slots
-- again. An empty input yields the input itself and an array of seeds.
scanl'Seg
    :: forall sh e i. (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int)
    => (Exp e -> Exp e -> Exp e)
    -> Exp e
    -> Acc (Array (sh:.Int) e)
    -> Acc (Segments i)
    -> Acc (Array (sh:.Int) e, Array (sh:.Int) e)
scanl'Seg f z arr seg =
  if null arr
    then lift (arr,  fill (lift (indexTail (shape arr) :. length seg)) z)
    else lift (body, sums)
  where
    -- Seeded scan; every segment is one element longer than in 'seg'.
    arr'   = scanlSeg f z arr seg

    -- Index of the final element of every widened segment.
    seg'   = map (+1) seg
    tails  = zipWith (+) seg . P.fst $ scanl' (+) 0 seg'
    sums   = backpermute
               (lift (indexTail (shape arr') :. length seg))
               (\ix -> let sz:.i = unlift ix :: Exp sh :. Exp Int
                       in  lift (sz :. fromIntegral (tails ! index1 i)))
               arr'

    -- Number of reduction slots to skip before each body element.
    offset = scanl1 (+) seg
    inc    = scanl1 (+)
           $ permute (+) (fill (index1 $ size arr + 1) 0)
                         (\ix -> index1' $ offset ! ix)
                         (fill (shape seg) (1 :: Exp i))

    -- Total length covered by the segments: the last entry of the inclusive
    -- scan of the segment lengths, i.e. at index @length offset - 1@.
    len    = offset ! index1 (length offset - 1)
    body   = backpermute
               (lift (indexTail (shape arr) :. fromIntegral len))
               (\ix -> let sz:.i = unlift ix :: Exp sh :. Exp Int
                       in  lift (sz :. i + fromIntegral (inc ! index1 i)))
               arr'
-- | Segmented inclusive left scan along the innermost dimension.
--
-- Each element is paired with a head flag (replicated over the outer
-- dimensions), the pairs are scanned with the 'segmented' operator, and the
-- flags are discarded again.
scanl1Seg
    :: (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int)
    => (Exp e -> Exp e -> Exp e)
    -> Acc (Array (sh:.Int) e)
    -> Acc (Segments i)
    -> Acc (Array (sh:.Int) e)
scanl1Seg f arr seg = P.snd (unzip scanned)
  where
    headFlags = replicate (lift (indexTail (shape arr) :. All)) (mkHeadFlags seg)
    scanned   = scanl1 (segmented f) (zip headFlags arr)
-- | Segmented left pre-scan: the first component of 'scanl'Seg'.
prescanlSeg
    :: (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int)
    => (Exp e -> Exp e -> Exp e)
    -> Exp e
    -> Acc (Array (sh:.Int) e)
    -> Acc (Segments i)
    -> Acc (Array (sh:.Int) e)
prescanlSeg f e vec seg = afst (scanl'Seg f e vec seg)
-- | Segmented left post-scan: 'scanl1Seg' followed by combining every result
-- with the initial value (as the left argument of @f@).
postscanlSeg
    :: (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int)
    => (Exp e -> Exp e -> Exp e)
    -> Exp e
    -> Acc (Array (sh:.Int) e)
    -> Acc (Segments i)
    -> Acc (Array (sh:.Int) e)
postscanlSeg f e vec seg = map (f e) (scanl1Seg f vec seg)
-- | Segmented right scan with a seed value along the innermost dimension.
--
-- Mirrors 'scanlSeg': the seed @z@ is injected at the /end/ of each segment
-- and a segmented inclusive right scan ('scanr1Seg') is run over the widened
-- array. If the input or the head-flag vector is empty, the result is just an
-- array filled with @z@.
scanrSeg
    :: forall sh e i. (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int)
    => (Exp e -> Exp e -> Exp e)
    -> Exp e
    -> Acc (Array (sh:.Int) e)
    -> Acc (Segments i)
    -> Acc (Array (sh:.Int) e)
scanrSeg f z arr seg =
  if null arr || null flags
    then fill (lift (sh :. sz + length seg)) z
    else scanr1Seg f arr' seg'
  where
    sh :. sz = unlift (shape arr) :: Exp sh :. Exp Int

    flags    = mkHeadFlags seg
    inc      = scanl1 (+) flags

    -- Every segment grows by one element (its seed).
    seg'     = map (+1) seg

    -- Array of seeds, overlaid with the input. Each element moves right by
    -- the number of segments started so far, minus one because the seed sits
    -- at the end of a segment rather than at its head.
    arr'     = permute const
                       (fill (lift (sh :. sz + length seg)) z)
                       (\ix -> let sx :. i = unlift ix :: Exp sh :. Exp Int
                               in  lift (sx :. i + fromIntegral (inc ! index1 i) - 1))
                       -- Only the trailing @length flags@ elements are used.
                       (drop (sz - length flags) arr)
-- | Segmented right scan that returns the scanned body and the per-segment
-- reductions as two separate arrays.
--
-- Built on 'scanrSeg': the per-segment totals are read from the first slot of
-- each (widened) segment, and the body is obtained by skipping those slots.
-- An empty input yields the input itself and an array of seeds.
scanr'Seg
    :: forall sh e i. (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int)
    => (Exp e -> Exp e -> Exp e)
    -> Exp e
    -> Acc (Array (sh:.Int) e)
    -> Acc (Segments i)
    -> Acc (Array (sh:.Int) e, Array (sh:.Int) e)
scanr'Seg f z arr seg =
  if null arr
    then lift (arr,  fill (lift (indexTail (shape arr) :. length seg)) z)
    else lift (body, sums)
  where
    -- Seeded scan; every segment is one element longer than in 'seg'.
    scanned  = scanrSeg f z arr seg

    -- Start index of every widened segment.
    widerSeg = map (+1) seg
    heads    = P.fst (scanl' (+) 0 widerSeg)
    sums     = backpermute
                 (lift (indexTail (shape scanned) :. length seg))
                 (\ix -> let outer :. i = unlift ix :: Exp sh :. Exp Int
                         in  lift (outer :. fromIntegral (heads ! index1 i)))
                 scanned

    -- Number of reduction slots to skip before each body element.
    flags    = mkHeadFlags seg
    skip     = scanl1 (+) flags
    body     = backpermute
                 (lift (indexTail (shape arr) :. indexHead (shape flags)))
                 (\ix -> let outer :. i = unlift ix :: Exp sh :. Exp Int
                         in  lift (outer :. i + fromIntegral (skip ! index1 i)))
                 scanned
-- | Segmented inclusive right scan along the innermost dimension.
--
-- Each element is paired with a tail flag (replicated over the outer
-- dimensions), the pairs are scanned from the right with the flipped
-- 'segmented' operator, and the flags are discarded again.
scanr1Seg
    :: (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int)
    => (Exp e -> Exp e -> Exp e)
    -> Acc (Array (sh:.Int) e)
    -> Acc (Segments i)
    -> Acc (Array (sh:.Int) e)
scanr1Seg f arr seg = P.snd (unzip scanned)
  where
    tailFlags = replicate (lift (indexTail (shape arr) :. All)) (mkTailFlags seg)
    scanned   = scanr1 (flip (segmented f)) (zip tailFlags arr)
-- | Segmented right pre-scan: the first component of 'scanr'Seg'.
prescanrSeg
    :: (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int)
    => (Exp e -> Exp e -> Exp e)
    -> Exp e
    -> Acc (Array (sh:.Int) e)
    -> Acc (Segments i)
    -> Acc (Array (sh:.Int) e)
prescanrSeg f e vec seg = afst (scanr'Seg f e vec seg)
-- | Segmented right post-scan: 'scanr1Seg' followed by combining every result
-- with the initial value (as the left argument of @f@).
postscanrSeg
    :: (Shape sh, Slice sh, Elt e, Integral i, Bits i, FromIntegral i Int)
    => (Exp e -> Exp e -> Exp e)
    -> Exp e
    -> Acc (Array (sh:.Int) e)
    -> Acc (Segments i)
    -> Acc (Array (sh:.Int) e)
postscanrSeg f e vec seg = map (f e) (scanr1Seg f vec seg)
-- | Turn a vector of segment lengths into a vector of head flags.
--
-- A one is accumulated at the start offset of every segment (so several
-- segments starting at the same offset produce an entry greater than one);
-- all other entries are zero. The scratch vector has one extra trailing slot,
-- which is removed by 'init'.
mkHeadFlags
    :: (Integral i, FromIntegral i Int)
    => Acc (Segments i)
    -> Acc (Segments i)
mkHeadFlags seg = init (permute (+) zeros startOf ones)
  where
    -- Exclusive prefix sum of the lengths: start offsets and total length.
    (offset, len) = scanl' (+) 0 seg
    zeros         = fill (index1' (the len + 1)) 0
    ones          = fill (index1 (size offset)) 1
    startOf ix    = index1' (offset ! ix)
-- | Turn a vector of segment lengths into a vector of tail flags.
--
-- A one is accumulated at the position of the last element of every segment;
-- all other entries are zero. The scratch vector has one extra trailing slot,
-- which is removed by 'init'.
mkTailFlags
    :: (Integral i, FromIntegral i Int)
    => Acc (Segments i)
    -> Acc (Segments i)
mkTailFlags seg
  = init
  $ permute (+) zeros (\ix -> index1' (the len - 1 - offset ! ix)) ones
  where
    -- Exclusive right-to-left sum: for each segment, the number of elements
    -- in the segments after it, plus the total length. The last element of a
    -- segment is therefore at @total - 1 - offset@.
    (offset, len) = scanr' (+) 0 seg
    zeros         = fill (index1' $ the len + 1) 0
    ones          = fill (index1  $ size offset) 1
-- | Lift a binary operator to flag/value pairs so that it can drive a
-- segmented scan: the flags are or-ed together, and a non-zero flag on the
-- right operand restarts the accumulation with the right value.
segmented
    :: (Elt e, Num i, Bits i)
    => (Exp e -> Exp e -> Exp e)
    -> Exp (i, e)
    -> Exp (i, e)
    -> Exp (i, e)
segmented f a b = lift (flag, value)
  where
    (aF, aV) = unlift a
    (bF, bV) = unlift b
    flag     = aF .|. bF
    value    = bF /= 0 ? (bV, f aV bV)
-- | Build a rank-1 index from any integral expression, converting it to 'Int'.
index1' :: (Integral i, FromIntegral i Int) => Exp i -> Exp DIM1
index1' = lift . (Z :.) . fromIntegral
-- | Flatten an array of any rank into a vector.
--
-- When the input is already one-dimensional it is returned as is; otherwise
-- it is reshaped to a vector of the same size.
flatten :: forall sh e. (Shape sh, Elt e) => Acc (Array sh e) -> Acc (Vector e)
flatten arr =
  case matchShapeType (undefined::sh) (undefined::DIM1) of
    Just Refl -> arr
    Nothing   -> reshape (index1 (size arr)) arr
-- | Create an array of the given shape in which every element is the same
-- value.
fill :: (Shape sh, Elt e) => Exp sh -> Exp e -> Acc (Array sh e)
fill sh c = generate sh (\_ -> c)
-- | Create an array of the given shape containing the values @x@, @x+1@, ...
-- in linear index order. This is 'enumFromStepN' with a step of one.
enumFromN
    :: (Shape sh, Num e, FromIntegral Int e)
    => Exp sh
    -> Exp e
    -> Acc (Array sh e)
enumFromN sh start = enumFromStepN sh start 1
-- | Create an array of the given shape containing the values @x@, @x+y@,
-- @x+2y@, ... in linear index order.
--
-- The values are generated as a vector of @shapeSize sh@ elements, which is
-- then reshaped to @sh@.
enumFromStepN
    :: (Shape sh, Num e, FromIntegral Int e)
    => Exp sh
    -> Exp e
    -> Exp e
    -> Acc (Array sh e)
enumFromStepN sh x y = reshape sh linear
  where
    linear    = generate (index1 (shapeSize sh)) valueAt
    valueAt ix = (fromIntegral (unindex1 ix :: Exp Int) * y) + x
infixr 5 ++

-- | Concatenate two arrays along the innermost dimension.
--
-- The outer dimensions of the result are the intersection of the outer
-- dimensions of the arguments; the innermost extent is the sum of the two
-- innermost extents.
(++) :: forall sh e. (Slice sh, Shape sh, Elt e)
     => Acc (Array (sh :. Int) e)
     -> Acc (Array (sh :. Int) e)
     -> Acc (Array (sh :. Int) e)
(++) xs ys
  = let sh1 :. n = unlift (shape xs) :: Exp sh :. Exp Int
        sh2 :. m = unlift (shape ys) :: Exp sh :. Exp Int
    in
    generate (lift (intersect sh1 sh2 :. n + m))
             (\ix -> let sh :. i = unlift ix :: Exp sh :. Exp Int
                     in  -- Positions below n come from xs; the rest come
                         -- from ys, re-based to start at zero.
                         i < n ? ( xs ! ix, ys ! lift (sh :. i - n) ))
-- | Keep only the elements that satisfy the predicate.
--
-- Returns the surviving elements as a flat vector, together with the number
-- of survivors along each innermost row (a scalar for vector input). The
-- first equation handles one-dimensional input; the second handles input of
-- higher rank. An empty input yields an empty vector and zero counts.
filter :: forall sh e. (Shape sh, Slice sh, Elt e)
       => (Exp e -> Exp Bool)
       -> Acc (Array (sh:.Int) e)
       -> Acc (Vector e, Array sh Int)
filter p arr
  -- Optimise for the one-dimensional case.
  | Just Refl <- matchShapeType (undefined::sh) (undefined::Z)
  = let
        flags           = map p arr
        -- Exclusive scan of the keep-flags: destination index of every
        -- surviving element, plus the total number of survivors.
        (dest, count)   = scanl' (+) 0 (map boolToInt flags)
        place ix        = if flags ! ix
                            then index1 (dest ! ix)
                            else ignore
        -- Array of the result size whose contents will all be overwritten.
        scratch         = backpermute (index1 (the count)) id arr
        kept            = permute const scratch place arr
    in
    if null arr
      then lift (emptyArray, fill (constant Z) 0)
      else lift (kept, count)
filter p arr
  = let
        sz              = indexTail (shape arr)
        flags           = map p arr
        -- Per-row destination indices and per-row survivor counts.
        (dest, count)   = scanl' (+) 0 (map boolToInt flags)
        -- Where each row starts in the flat output, plus the grand total.
        (rowStart, total) = scanl' (+) 0 (flatten count)
        place ix        = flags ! ix
                        ? ( index1 $ rowStart ! index1 (toIndex sz (indexTail ix)) + dest ! ix
                          , ignore )
        -- Array of the result size whose contents will all be overwritten.
        scratch         = backpermute (index1 (the total)) id (flatten arr)
        kept            = permute const scratch place arr
    in
    if null arr
      then lift (emptyArray, fill sz 0)
      else lift (kept, count)
-- | Read, for every entry of the index array, the element of the input vector
-- at that (linear) position.
gather
    :: (Shape sh, Elt e)
    => Acc (Array sh Int)
    -> Acc (Vector e)
    -> Acc (Array sh e)
gather indices input = map (\i -> input !! i) indices
-- | Conditional gather: where the predicate holds for the mask element, take
-- the gathered input value; otherwise take the corresponding default.
gatherIf
    :: (Elt a, Elt b)
    => Acc (Vector Int)
    -> Acc (Vector a)
    -> (Exp a -> Exp Bool)
    -> Acc (Vector b)
    -> Acc (Vector b)
    -> Acc (Vector b)
gatherIf from maskV pred defaults input = zipWith choose keepV candidates
  where
    keepV            = map pred maskV
    -- Pairs of (gathered value, default value).
    candidates       = zip (gather from input) defaults
    choose keep pair = keep ? unlift pair
-- | Write each input element to the position named by the matching entry of
-- the index vector, on top of a vector of defaults. Only as many inputs as
-- there are indices (and vice versa) are considered.
scatter
    :: Elt e
    => Acc (Vector Int)
    -> Acc (Vector e)
    -> Acc (Vector e)
    -> Acc (Vector e)
scatter to defaults input =
  let -- Input truncated to the common extent of indices and input.
      trimmed     = backpermute (shape to `intersect` shape input) id input
      destination ix = index1 (to ! ix)
  in
  permute const defaults destination trimmed
-- | Conditional scatter: like 'scatter', but an input element is written only
-- when the predicate holds for the corresponding mask element; other
-- elements are ignored.
scatterIf
    :: (Elt e, Elt e')
    => Acc (Vector Int)
    -> Acc (Vector e)
    -> (Exp e -> Exp Bool)
    -> Acc (Vector e')
    -> Acc (Vector e')
    -> Acc (Vector e')
scatterIf to maskV pred defaults input =
  let -- Input truncated to the common extent of indices and input.
      trimmed        = backpermute (shape to `intersect` shape input) id input
      destination ix = if pred (maskV ! ix)
                         then index1 (to ! ix)
                         else ignore
  in
  permute const defaults destination trimmed
-- | Reverse the elements of a vector.
reverse :: Elt e => Acc (Vector e) -> Acc (Vector e)
reverse xs =
  let len  = unindex1 (shape xs)
      -- Result position i reads from the mirrored source position.
      pf i = len - i - 1
  in  backpermute (shape xs) (ilift1 pf) xs
-- | Transpose a two-dimensional array: both the extent and every index have
-- their two components swapped.
transpose :: Elt e => Acc (Array DIM2 e) -> Acc (Array DIM2 e)
transpose mat = backpermute (swap (shape mat)) swap mat
  where
    swap :: Exp DIM2 -> Exp DIM2
    swap = lift1 flipIx

    flipIx :: Z :. Exp Int :. Exp Int -> Z :. Exp Int :. Exp Int
    flipIx (Z :. x :. y) = Z :. y :. x
-- | Keep at most the first @n@ elements along the innermost dimension. The
-- count is capped at the innermost extent of the input.
take :: forall sh e. (Slice sh, Shape sh, Elt e)
     => Exp Int
     -> Acc (Array (sh :. Int) e)
     -> Acc (Array (sh :. Int) e)
take n acc = backpermute (lift (sh :. n')) id acc
  where
    sh :. sz = unlift (shape acc) :: Exp sh :. Exp Int
    -- Routed through a scalar array via 'unit' / 'the'.
    n'       = the (unit (n `min` sz))
-- | Discard the first @n@ elements along the innermost dimension. The
-- innermost extent of the result is clamped at zero.
drop :: forall sh e. (Slice sh, Shape sh, Elt e)
     => Exp Int
     -> Acc (Array (sh :. Int) e)
     -> Acc (Array (sh :. Int) e)
drop n acc =
  let -- Routed through a scalar array via 'unit' / 'the'.
      n'       = the (unit n)
      sh :. sz = unlift (shape acc) :: Exp sh :. Exp Int
      -- Result position i reads from source position i + n.
      index ix = let j :. i = unlift ix :: Exp sh :. Exp Int
                 in  lift (j :. i + n')
  in
  backpermute (lift (sh :. 0 `max` (sz - n'))) index acc
-- | Yield all but the last element along the innermost dimension.
--
-- NOTE(review): unlike 'tail', the new extent is not clamped at zero, so an
-- input whose innermost extent is already zero asks for an extent of @-1@;
-- confirm whether callers rely on never passing empty arrays here.
init :: forall sh e. (Slice sh, Shape sh, Elt e)
     => Acc (Array (sh :. Int) e)
     -> Acc (Array (sh :. Int) e)
init acc =
  let sh :. sz = unlift (shape acc) :: Exp sh :. Exp Int
  in  backpermute (lift (sh :. sz `min` (sz - 1))) id acc
-- | Yield all but the first element along the innermost dimension. The
-- innermost extent of the result is clamped at zero.
tail :: forall sh e. (Slice sh, Shape sh, Elt e)
     => Acc (Array (sh :. Int) e)
     -> Acc (Array (sh :. Int) e)
tail acc =
  let sh :. sz = unlift (shape acc) :: Exp sh :. Exp Int
      -- Result position i reads from source position i + 1.
      index ix = let j :. i = unlift ix :: Exp sh :. Exp Int
                 in  lift (j :. i + 1)
  in
  backpermute (lift (sh :. 0 `max` (sz - 1))) index acc
-- | Extract a window along the innermost dimension: at most @n@ elements
-- starting at position @m@. The window is cut short when fewer than @n@
-- elements remain after position @m@, and is empty when none remain.
slit :: forall sh e. (Slice sh, Shape sh, Elt e)
     => Exp Int
     -> Exp Int
     -> Acc (Array (sh :. Int) e)
     -> Acc (Array (sh :. Int) e)
slit m n acc =
  let -- Routed through scalar arrays via 'unit' / 'the'.
      m'       = the (unit m)
      n'       = the (unit n)
      sh :. sz = unlift (shape acc) :: Exp sh :. Exp Int
      -- Result position i reads from source position i + m.
      index ix = let j :. i = unlift ix :: Exp sh :. Exp Int
                 in  lift (j :. i + m')
  in
  backpermute (lift (sh :. (n' `min` ((sz - m') `max` 0)))) index acc
-- | Pass an array computation through a pipeline ('>->') of two identity
-- stages.
--
-- NOTE(review): presumably this forces the argument to be evaluated as its
-- own stage rather than fused with its consumer -- confirm against the
-- documentation of '>->'.
compute :: Arrays a => Acc a -> Acc a
compute arrs = (id >-> id) arrs
infix 0 ?|

-- | Infix array-level conditional: select the first array computation of the
-- pair when the condition holds, otherwise the second. Wrapper around 'acond'.
(?|) :: Arrays a => Exp Bool -> (Acc a, Acc a) -> Acc a
(?|) p (whenTrue, whenFalse) = acond p whenTrue whenFalse
infix 0 ?

-- | Infix scalar conditional: select the first expression of the pair when the
-- condition holds, otherwise the second. Wrapper around 'cond'.
(?) :: Elt t => Exp Bool -> (Exp t, Exp t) -> Exp t
(?) p (whenTrue, whenFalse) = cond p whenTrue whenFalse
-- | A chain of scalar conditionals. The predicates are tried in order against
-- the scrutinee; the value paired with the first predicate that holds is the
-- result, and the final argument is the result when none holds.
caseof :: (Elt a, Elt b)
       => Exp a
       -> [(Exp a -> Exp Bool, Exp b)]
       -> Exp b
       -> Exp b
caseof x alts fallback = P.foldr branch fallback alts
  where
    branch (p, b) rest = cond (p x) b rest
-- | Types that support a conditional whose condition is an embedded boolean.
-- The associated constraint 'EltT' names the class that the branch type has
-- to belong to.
class IfThenElse t where
  type EltT t a :: Constraint
  ifThenElse :: EltT t a => Exp Bool -> t a -> t a -> t a

-- | Scalar expressions: branches must be 'Elt'; implemented by 'cond'.
instance IfThenElse Exp where
  type EltT Exp t = Elt t
  ifThenElse p t e = cond p t e

-- | Array computations: branches must be 'Arrays'; implemented by 'acond'.
instance IfThenElse Acc where
  type EltT Acc a = Arrays a
  ifThenElse p t e = acond p t e
-- | Apply a function to a starting value @n@ times.
--
-- Implemented as a 'while' loop over a (counter, value) pair that runs as long
-- as the counter is below @n@.
iterate :: forall a. Elt a
        => Exp Int
        -> (Exp a -> Exp a)
        -> Exp a
        -> Exp a
iterate n f z =
  snd (while (\st -> fst st < n) (lift1 step) (lift (constant 0, z)))
  where
    step :: (Exp Int, Exp a) -> (Exp Int, Exp a)
    step (count, val) = (count + 1, f val)
-- | Sequential left fold, inside a scalar expression, over the innermost
-- dimension of an array at the given outer index.
--
-- Implemented as a 'while' loop over a (position, accumulator) pair that runs
-- from position zero up to the innermost extent.
sfoldl :: forall sh a b. (Shape sh, Slice sh, Elt a, Elt b)
       => (Exp a -> Exp b -> Exp a)
       -> Exp a
       -> Exp sh
       -> Acc (Array (sh :. Int) b)
       -> Exp a
sfoldl f z ix xs =
  snd (while (\st -> fst st < n) (lift1 step) (lift (constant 0, z)))
  where
    -- Innermost extent: the number of elements to fold over.
    (_ :. n) = unlift (shape xs) :: Exp sh :. Exp Int

    step :: (Exp Int, Exp a) -> (Exp Int, Exp a)
    step (pos, acc) = (pos + 1, acc `f` (xs ! lift (ix :. pos)))
-- | First component of a scalar pair.
fst :: forall a b. (Elt a, Elt b) => Exp (a, b) -> Exp a
fst e = P.fst (unlift e :: (Exp a, Exp b))
-- | First component of a pair of array computations.
afst :: forall a b. (Arrays a, Arrays b) => Acc (a, b) -> Acc a
afst a = P.fst (unlift a :: (Acc a, Acc b))
-- | Second component of a scalar pair.
snd :: forall a b. (Elt a, Elt b) => Exp (a, b) -> Exp b
snd e = P.snd (unlift e :: (Exp a, Exp b))
-- | Second component of a pair of array computations.
asnd :: forall a b. (Arrays a, Arrays b) => Acc (a, b) -> Acc b
asnd a = P.snd (unlift a :: (Acc a, Acc b))
-- | Turn a function on an embedded pair into a function of two arguments, by
-- lifting the two arguments into a pair.
curry :: Lift f (f a, f b) => (f (Plain (f a), Plain (f b)) -> f c) -> f a -> f b -> f c
curry f x y = f pair
  where
    pair = lift (x, y)
-- | Turn a function of two arguments into a function on an embedded pair, by
-- unlifting the pair into its components.
uncurry :: Unlift f (f a, f b) => (f a -> f b -> f c) -> f (Plain (f a), Plain (f b)) -> f c
uncurry f t = f x y
  where
    (x, y) = unlift t
-- | The single index of a rank-0 (scalar) array.
index0 :: Exp Z
index0 = constantIndex
  where
    constantIndex = lift Z
-- | Wrap a scalar expression as a rank-1 index.
index1 :: Elt i => Exp i -> Exp (Z :. i)
index1 = lift . (Z :.)
-- | Extract the single component of a rank-1 index.
unindex1 :: Elt i => Exp (Z :. i) -> Exp i
unindex1 ix =
  case unlift ix of
    Z :. i -> i
-- | Build a rank-2 index from its two components (outer first, inner second).
index2 :: (Elt i, Slice (Z :. i))
       => Exp i
       -> Exp i
       -> Exp (Z :. i :. i)
index2 outer inner = lift (Z :. outer :. inner)
-- | Split a rank-2 index into a pair of its components (outer first).
unindex2 :: forall i. (Elt i, Slice (Z :. i))
         => Exp (Z :. i :. i)
         -> Exp (i, i)
unindex2 ix = lift (outer, inner)
  where
    Z :. outer :. inner = unlift ix :: Z :. Exp i :. Exp i
-- | Build a rank-3 index from its three components (outermost first).
index3
    :: (Elt i, Slice (Z :. i), Slice (Z :. i :. i))
    => Exp i
    -> Exp i
    -> Exp i
    -> Exp (Z :. i :. i :. i)
index3 outer middle inner = lift (Z :. outer :. middle :. inner)
-- | Split a rank-3 index into a triple of its components (outermost first).
unindex3
    :: forall i. (Elt i, Slice (Z :. i), Slice (Z :. i :. i))
    => Exp (Z :. i :. i :. i)
    -> Exp (i, i, i)
unindex3 ix = lift (outer, middle, inner)
  where
    Z :. outer :. middle :. inner = unlift ix :: Z :. Exp i :. Exp i :. Exp i
-- | Extract the value held by a scalar (rank-0) array.
the :: Elt e => Acc (Scalar e) -> Exp e
the scalar = scalar ! index0
-- | Test whether an array contains no elements, i.e. whether its size is zero.
null :: (Shape sh, Elt e) => Acc (Array sh e) -> Exp Bool
null = (== 0) . size
-- | Number of elements of a vector, read off its shape.
length :: Elt e => Acc (Vector e) -> Exp Int
length xs = unindex1 (shape xs)
-- | An array of the 'empty' shape with no elements, embedded with 'use'.
emptyArray :: (Shape sh, Elt e) => Acc (Array sh e)
emptyArray = use noElements
  where
    noElements = fromList empty []
-- | Decide whether two shape types are the same type, returning a proof of
-- equality if so.
--
-- The element representations are compared first with 'matchTupleType'; only
-- when they agree is the actual equality witness produced via 'gcast'. The
-- value arguments are never inspected -- they only carry the types.
matchShapeType :: forall s t. (Shape s, Shape t) => s -> t -> Maybe (s :~: t)
matchShapeType _ _ =
  case matchTupleType (eltType (undefined::s)) (eltType (undefined::t)) of
    Just Refl -> gcast Refl
    Nothing   -> Nothing