{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverlappingInstances  #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
-- |
-- Module      : Data.Array.Accelerate.Prelude
-- Copyright   : [2009..2014] Manuel M T Chakravarty, Gabriele Keller, Trevor L. McDonell
--               [2010..2011] Ben Lever
-- License     : BSD3
--
-- Maintainer  : Manuel M T Chakravarty
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- Standard functions that are not part of the core set (directly represented in
-- the AST), but are instead implemented in terms of the core set.
--

module Data.Array.Accelerate.Prelude (

  -- * Zipping
  zipWith3, zipWith4, zipWith5, zipWith6, zipWith7, zipWith8, zipWith9,
  zip, zip3, zip4, zip5, zip6, zip7, zip8, zip9,

  -- * Unzipping
  unzip, unzip3, unzip4, unzip5, unzip6, unzip7, unzip8, unzip9,

  -- * Reductions
  foldAll, fold1All,

  -- ** Specialised folds
  all, any, and, or, sum, product, minimum, maximum,

  -- * Scans
  prescanl, postscanl, prescanr, postscanr,

  -- ** Segmented scans
  scanlSeg, scanl'Seg, scanl1Seg, prescanlSeg, postscanlSeg,
  scanrSeg, scanr'Seg, scanr1Seg, prescanrSeg, postscanrSeg,

  -- * Shape manipulation
  flatten,

  -- * Enumeration and filling
  fill, enumFromN, enumFromStepN,

  -- * Concatenation
  (++),

  -- * Working with predicates
  -- ** Filtering
  filter,

  -- ** Scatter / Gather
  scatter, scatterIf, gather, gatherIf,

  -- * Permutations
  reverse, transpose,

  -- * Extracting sub-vectors
  init, tail, take, drop, slit,

  -- * Array-level flow control
  (?|),

  -- * Expression-level flow control
  (?), caseof,

  -- * Scalar iteration
  iterate,

  -- * Scalar reduction
  sfoldl, -- sfoldr,

  -- * Lifting and unlifting
  Lift(..), Unlift(..), lift1, lift2, ilift1, ilift2,

  -- ** Tuple construction and destruction
  fst, afst, snd, asnd, curry, uncurry,

  -- ** Index construction and destruction
  index0, index1, unindex1, index2, unindex2,

  -- * Array operations with a scalar result
  the, null, length,

) where

-- avoid clashes with Prelude functions
--
import Data.Bits
import Data.Bool
import Prelude                                  ((.), ($), (+), (-), (*), const, subtract, id, min, max, Float, Double, Char)
import qualified Prelude                        as P

-- friends
import Data.Array.Accelerate.Array.Sugar        hiding ((!), ignore, shape, size, intersect)
import Data.Array.Accelerate.Language
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Tuple
import Data.Array.Accelerate.Type


-- Map-like composites
-- -------------------

-- | Zip three arrays with the given function, analogous to 'zipWith'.
--
-- The result covers only the indices shared by every argument: its extent is
-- the intersection of the three argument extents.
--
zipWith3 :: (Shape sh, Elt a, Elt b, Elt c, Elt d)
         => (Exp a -> Exp b -> Exp c -> Exp d)
         -> Acc (Array sh a)
         -> Acc (Array sh b)
         -> Acc (Array sh c)
         -> Acc (Array sh d)
zipWith3 f as bs cs = generate common gen
  where
    common = shape as `intersect` shape bs `intersect` shape cs
    gen ix = f (as ! ix) (bs ! ix) (cs ! ix)

-- | Zip four arrays with the given function, analogous to 'zipWith'.
--
-- The result extent is the intersection of the four argument extents.
--
zipWith4 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e)
         => (Exp a -> Exp b -> Exp c -> Exp d -> Exp e)
         -> Acc (Array sh a)
         -> Acc (Array sh b)
         -> Acc (Array sh c)
         -> Acc (Array sh d)
         -> Acc (Array sh e)
zipWith4 f as bs cs ds = generate common gen
  where
    common = shape as `intersect` shape bs `intersect` shape cs
                      `intersect` shape ds
    gen ix = f (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix)

-- | Zip five arrays with the given function, analogous to 'zipWith'.
--
-- The result extent is the intersection of the five argument extents.
--
zipWith5 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f)
         => (Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f)
         -> Acc (Array sh a)
         -> Acc (Array sh b)
         -> Acc (Array sh c)
         -> Acc (Array sh d)
         -> Acc (Array sh e)
         -> Acc (Array sh f)
zipWith5 f as bs cs ds es = generate common gen
  where
    common = shape as `intersect` shape bs `intersect` shape cs
                      `intersect` shape ds `intersect` shape es
    gen ix = f (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix) (es ! ix)

-- | Zip six arrays with the given function, analogous to 'zipWith'.
--
-- The result extent is the intersection of the six argument extents.
--
zipWith6 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g)
         => (Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f -> Exp g)
         -> Acc (Array sh a)
         -> Acc (Array sh b)
         -> Acc (Array sh c)
         -> Acc (Array sh d)
         -> Acc (Array sh e)
         -> Acc (Array sh f)
         -> Acc (Array sh g)
zipWith6 f as bs cs ds es fs = generate common gen
  where
    common = shape as `intersect` shape bs `intersect` shape cs
                      `intersect` shape ds `intersect` shape es
                      `intersect` shape fs
    gen ix = f (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix) (es ! ix) (fs ! ix)

-- | Zip seven arrays with the given function, analogous to 'zipWith'.
--
-- The result extent is the intersection of the seven argument extents.
--
zipWith7 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h)
         => (Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f -> Exp g -> Exp h)
         -> Acc (Array sh a)
         -> Acc (Array sh b)
         -> Acc (Array sh c)
         -> Acc (Array sh d)
         -> Acc (Array sh e)
         -> Acc (Array sh f)
         -> Acc (Array sh g)
         -> Acc (Array sh h)
zipWith7 f as bs cs ds es fs gs = generate common gen
  where
    common = shape as `intersect` shape bs `intersect` shape cs
                      `intersect` shape ds `intersect` shape es
                      `intersect` shape fs `intersect` shape gs
    gen ix = f (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix) (es ! ix) (fs ! ix)
               (gs ! ix)

-- | Zip eight arrays with the given function, analogous to 'zipWith'.
--
-- The result extent is the intersection of the eight argument extents.
--
zipWith8 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i)
         => (Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f -> Exp g -> Exp h -> Exp i)
         -> Acc (Array sh a)
         -> Acc (Array sh b)
         -> Acc (Array sh c)
         -> Acc (Array sh d)
         -> Acc (Array sh e)
         -> Acc (Array sh f)
         -> Acc (Array sh g)
         -> Acc (Array sh h)
         -> Acc (Array sh i)
zipWith8 f as bs cs ds es fs gs hs = generate common gen
  where
    common = shape as `intersect` shape bs `intersect` shape cs
                      `intersect` shape ds `intersect` shape es
                      `intersect` shape fs `intersect` shape gs
                      `intersect` shape hs
    gen ix = f (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix) (es ! ix) (fs ! ix)
               (gs ! ix) (hs ! ix)

-- | Zip nine arrays with the given function, analogous to 'zipWith'.
--
-- The result extent is the intersection of the nine argument extents.
--
zipWith9 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i, Elt j)
         => (Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f -> Exp g -> Exp h -> Exp i -> Exp j)
         -> Acc (Array sh a)
         -> Acc (Array sh b)
         -> Acc (Array sh c)
         -> Acc (Array sh d)
         -> Acc (Array sh e)
         -> Acc (Array sh f)
         -> Acc (Array sh g)
         -> Acc (Array sh h)
         -> Acc (Array sh i)
         -> Acc (Array sh j)
zipWith9 f as bs cs ds es fs gs hs is = generate common gen
  where
    common = shape as `intersect` shape bs `intersect` shape cs
                      `intersect` shape ds `intersect` shape es
                      `intersect` shape fs `intersect` shape gs
                      `intersect` shape hs `intersect` shape is
    gen ix = f (as ! ix) (bs ! ix) (cs ! ix) (ds ! ix) (es ! ix) (fs ! ix)
               (gs ! ix) (hs ! ix) (is ! ix)

-- | Combine the elements of two arrays pairwise. The shape of the result is
-- the intersection of the two argument shapes.
--
zip :: (Shape sh, Elt a, Elt b)
    => Acc (Array sh a)
    -> Acc (Array sh b)
    -> Acc (Array sh (a, b))
zip xs ys = zipWith (\x y -> lift (x, y)) xs ys

-- | Take three arrays and return an array of triples, analogous to zip.
--
zip3 :: (Shape sh, Elt a, Elt b, Elt c)
     => Acc (Array sh a)
     -> Acc (Array sh b)
     -> Acc (Array sh c)
     -> Acc (Array sh (a, b, c))
zip3 xs ys zs = zipWith3 (\x y z -> lift (x, y, z)) xs ys zs

-- | Take four arrays and return an array of quadruples, analogous to zip.
--
zip4 :: (Shape sh, Elt a, Elt b, Elt c, Elt d)
     => Acc (Array sh a)
     -> Acc (Array sh b)
     -> Acc (Array sh c)
     -> Acc (Array sh d)
     -> Acc (Array sh (a, b, c, d))
zip4 xs ys zs ws = zipWith4 (\x y z w -> lift (x, y, z, w)) xs ys zs ws

-- | Take five arrays and return an array of five-tuples, analogous to zip.
--
zip5 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e)
     => Acc (Array sh a)
     -> Acc (Array sh b)
     -> Acc (Array sh c)
     -> Acc (Array sh d)
     -> Acc (Array sh e)
     -> Acc (Array sh (a, b, c, d, e))
zip5 xs ys zs ws vs = zipWith5 (\x y z w v -> lift (x, y, z, w, v)) xs ys zs ws vs

-- | Take six arrays and return an array of six-tuples, analogous to zip.
--
zip6 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f)
     => Acc (Array sh a)
     -> Acc (Array sh b)
     -> Acc (Array sh c)
     -> Acc (Array sh d)
     -> Acc (Array sh e)
     -> Acc (Array sh f)
     -> Acc (Array sh (a, b, c, d, e, f))
zip6 = zipWith6 (\a b c d e f -> lift (a,b,c,d,e,f))

-- | Take seven arrays and return an array of seven-tuples, analogous to zip.
--
zip7 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g)
     => Acc (Array sh a)
     -> Acc (Array sh b)
     -> Acc (Array sh c)
     -> Acc (Array sh d)
     -> Acc (Array sh e)
     -> Acc (Array sh f)
     -> Acc (Array sh g)
     -> Acc (Array sh (a, b, c, d, e, f, g))
zip7 = zipWith7 (\a b c d e f g -> lift (a,b,c,d,e,f,g))

-- | Take eight arrays and return an array of eight-tuples, analogous to zip.
--
zip8 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h)
     => Acc (Array sh a)
     -> Acc (Array sh b)
     -> Acc (Array sh c)
     -> Acc (Array sh d)
     -> Acc (Array sh e)
     -> Acc (Array sh f)
     -> Acc (Array sh g)
     -> Acc (Array sh h)
     -> Acc (Array sh (a, b, c, d, e, f, g, h))
zip8 = zipWith8 (\a b c d e f g h -> lift (a,b,c,d,e,f,g,h))

-- | Take nine arrays and return an array of nine-tuples, analogous to zip.
--
zip9 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i)
     => Acc (Array sh a)
     -> Acc (Array sh b)
     -> Acc (Array sh c)
     -> Acc (Array sh d)
     -> Acc (Array sh e)
     -> Acc (Array sh f)
     -> Acc (Array sh g)
     -> Acc (Array sh h)
     -> Acc (Array sh i)
     -> Acc (Array sh (a, b, c, d, e, f, g, h, i))
zip9 = zipWith9 (\a b c d e f g h i -> lift (a,b,c,d,e,f,g,h,i))

-- | The converse of 'zip', but the shape of the two results is identical to the
-- shape of the argument.
--
unzip :: (Shape sh, Elt a, Elt b)
      => Acc (Array sh (a, b))
      -> (Acc (Array sh a), Acc (Array sh b))
unzip arr = (map fst arr, map snd arr)

-- | Take an array of triples and return three arrays, analogous to unzip.
--
-- Every component array has the same shape as the input; each is produced by
-- its own 'map' that projects one field out of the tuple.
--
unzip3 :: (Shape sh, Elt a, Elt b, Elt c)
       => Acc (Array sh (a, b, c))
       -> (Acc (Array sh a), Acc (Array sh b), Acc (Array sh c))
unzip3 xs =
  ( map (sel1 . untup3) xs
  , map (sel2 . untup3) xs
  , map (sel3 . untup3) xs
  )
  where
    sel1 (a, _, _) = a
    sel2 (_, b, _) = b
    sel3 (_, _, c) = c

-- | Take an array of quadruples and return four arrays, analogous to unzip.
--
-- Each output array projects one field out of the input tuples.
--
unzip4 :: (Shape sh, Elt a, Elt b, Elt c, Elt d)
       => Acc (Array sh (a, b, c, d))
       -> (Acc (Array sh a), Acc (Array sh b), Acc (Array sh c), Acc (Array sh d))
unzip4 xs =
  ( map (sel1 . untup4) xs
  , map (sel2 . untup4) xs
  , map (sel3 . untup4) xs
  , map (sel4 . untup4) xs
  )
  where
    sel1 (a, _, _, _) = a
    sel2 (_, b, _, _) = b
    sel3 (_, _, c, _) = c
    sel4 (_, _, _, d) = d

-- | Take an array of 5-tuples and return five arrays, analogous to unzip.
--
-- Each output array projects one field out of the input tuples.
--
unzip5 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e)
       => Acc (Array sh (a, b, c, d, e))
       -> (Acc (Array sh a), Acc (Array sh b), Acc (Array sh c), Acc (Array sh d), Acc (Array sh e))
unzip5 xs =
  ( map (sel1 . untup5) xs
  , map (sel2 . untup5) xs
  , map (sel3 . untup5) xs
  , map (sel4 . untup5) xs
  , map (sel5 . untup5) xs
  )
  where
    sel1 (a, _, _, _, _) = a
    sel2 (_, b, _, _, _) = b
    sel3 (_, _, c, _, _) = c
    sel4 (_, _, _, d, _) = d
    sel5 (_, _, _, _, e) = e

-- | Take an array of 6-tuples and return six arrays, analogous to unzip.
--
-- Each output array projects one field out of the input tuples.
--
unzip6 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f)
       => Acc (Array sh (a, b, c, d, e, f))
       -> ( Acc (Array sh a), Acc (Array sh b), Acc (Array sh c)
          , Acc (Array sh d), Acc (Array sh e), Acc (Array sh f))
unzip6 xs =
  ( map (sel1 . untup6) xs
  , map (sel2 . untup6) xs
  , map (sel3 . untup6) xs
  , map (sel4 . untup6) xs
  , map (sel5 . untup6) xs
  , map (sel6 . untup6) xs
  )
  where
    sel1 (a, _, _, _, _, _) = a
    sel2 (_, b, _, _, _, _) = b
    sel3 (_, _, c, _, _, _) = c
    sel4 (_, _, _, d, _, _) = d
    sel5 (_, _, _, _, e, _) = e
    sel6 (_, _, _, _, _, f) = f

-- | Take an array of 7-tuples and return seven arrays, analogous to unzip.
--
-- Each output array projects one field out of the input tuples.
--
unzip7 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g)
       => Acc (Array sh (a, b, c, d, e, f, g))
       -> ( Acc (Array sh a), Acc (Array sh b), Acc (Array sh c)
          , Acc (Array sh d), Acc (Array sh e), Acc (Array sh f)
          , Acc (Array sh g))
unzip7 xs =
  ( map (sel1 . untup7) xs
  , map (sel2 . untup7) xs
  , map (sel3 . untup7) xs
  , map (sel4 . untup7) xs
  , map (sel5 . untup7) xs
  , map (sel6 . untup7) xs
  , map (sel7 . untup7) xs
  )
  where
    sel1 (a, _, _, _, _, _, _) = a
    sel2 (_, b, _, _, _, _, _) = b
    sel3 (_, _, c, _, _, _, _) = c
    sel4 (_, _, _, d, _, _, _) = d
    sel5 (_, _, _, _, e, _, _) = e
    sel6 (_, _, _, _, _, f, _) = f
    sel7 (_, _, _, _, _, _, g) = g

-- | Take an array of 8-tuples and return eight arrays, analogous to unzip.
--
unzip8 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h)
       => Acc (Array sh (a, b, c, d, e, f, g, h))
       -> ( Acc (Array sh a), Acc (Array sh b), Acc (Array sh c)
          , Acc (Array sh d), Acc (Array sh e), Acc (Array sh f)
          , Acc (Array sh g), Acc (Array sh h) )
unzip8 xs = ( map get1 xs, map get2 xs, map get3 xs
            , map get4 xs, map get5 xs, map get6 xs
            , map get7 xs, map get8 xs )
  where
    get1 x = let (a,_,_,_,_,_,_,_) = untup8 x in a
    get2 x = let (_,b,_,_,_,_,_,_) = untup8 x in b
    get3 x = let (_,_,c,_,_,_,_,_) = untup8 x in c
    get4 x = let (_,_,_,d,_,_,_,_) = untup8 x in d
    get5 x = let (_,_,_,_,e,_,_,_) = untup8 x in e
    get6 x = let (_,_,_,_,_,f,_,_) = untup8 x in f
    get7 x = let (_,_,_,_,_,_,g,_) = untup8 x in g
    get8 x = let (_,_,_,_,_,_,_,h) = untup8 x in h

-- | Take an array of 9-tuples and return nine arrays, analogous to unzip.
--
unzip9 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i)
       => Acc (Array sh (a, b, c, d, e, f, g, h, i))
       -> ( Acc (Array sh a), Acc (Array sh b), Acc (Array sh c)
          , Acc (Array sh d), Acc (Array sh e), Acc (Array sh f)
          , Acc (Array sh g), Acc (Array sh h), Acc (Array sh i))
unzip9 xs = ( map get1 xs, map get2 xs, map get3 xs
            , map get4 xs, map get5 xs, map get6 xs
            , map get7 xs, map get8 xs, map get9 xs )
  where
    get1 x = let (a,_,_,_,_,_,_,_,_) = untup9 x in a
    get2 x = let (_,b,_,_,_,_,_,_,_) = untup9 x in b
    get3 x = let (_,_,c,_,_,_,_,_,_) = untup9 x in c
    get4 x = let (_,_,_,d,_,_,_,_,_) = untup9 x in d
    get5 x = let (_,_,_,_,e,_,_,_,_) = untup9 x in e
    get6 x = let (_,_,_,_,_,f,_,_,_) = untup9 x in f
    get7 x = let (_,_,_,_,_,_,g,_,_) = untup9 x in g
    get8 x = let (_,_,_,_,_,_,_,h,_) = untup9 x in h
    get9 x = let (_,_,_,_,_,_,_,_,i) = untup9 x in i


-- Reductions
-- ----------

-- | Reduction of an array of arbitrary rank to a single scalar value.
--
-- The array is first flattened to a vector and then reduced with 'fold', so
-- the result is a singleton ('Scalar') array.
--
foldAll :: (Shape sh, Elt a)
        => (Exp a -> Exp a -> Exp a)
        -> Exp a
        -> Acc (Array sh a)
        -> Acc (Scalar a)
foldAll f e arr = fold f e (flatten arr)

-- | Variant of 'foldAll' that requires the reduced array to be non-empty and
-- doesn't need a default value.
--
fold1All :: (Shape sh, Elt a)
         => (Exp a -> Exp a -> Exp a)
         -> Acc (Array sh a)
         -> Acc (Scalar a)
fold1All f arr = fold1 f (flatten arr)


-- Specialised reductions
-- ----------------------
--
-- Leave the results of these as scalar arrays to make it clear that these are
-- array computations, and thus can not be nested.

-- | Check if all elements satisfy a predicate
--
all :: (Shape sh, Elt e)
    => (Exp e -> Exp Bool)
    -> Acc (Array sh e)
    -> Acc (Scalar Bool)
all f = and . map f

-- | Check if any element satisfies the predicate
--
any :: (Shape sh, Elt e)
    => (Exp e -> Exp Bool)
    -> Acc (Array sh e)
    -> Acc (Scalar Bool)
any f = or . map f

-- | Check if all elements are 'True'
--
and :: Shape sh
    => Acc (Array sh Bool)
    -> Acc (Scalar Bool)
and = foldAll (&&*) (constant True)

-- | Check if any element is 'True'
--
or :: Shape sh
   => Acc (Array sh Bool)
   -> Acc (Scalar Bool)
or = foldAll (||*) (constant False)

-- | Compute the sum of elements
--
sum :: (Shape sh, Elt e, IsNum e)
    => Acc (Array sh e)
    -> Acc (Scalar e)
sum = foldAll (+) 0

-- | Compute the product of the elements
--
product :: (Shape sh, Elt e, IsNum e)
        => Acc (Array sh e)
        -> Acc (Scalar e)
product = foldAll (*) 1

-- | Yield the minimum element of an array. The array must not be empty.
--
minimum :: (Shape sh, Elt e, IsScalar e)
        => Acc (Array sh e)
        -> Acc (Scalar e)
minimum = fold1All min

-- | Yield the maximum element of an array. The array must not be empty.
--
maximum :: (Shape sh, Elt e, IsScalar e)
        => Acc (Array sh e)
        -> Acc (Scalar e)
maximum = fold1All max


-- Composite scans
-- ---------------

-- |Left-to-right prescan (aka exclusive scan). As for 'scan', the first argument must be an
-- /associative/ function.
-- Denotationally, we have
--
-- > prescanl f e = Prelude.fst . scanl' f e
--
prescanl :: Elt a
         => (Exp a -> Exp a -> Exp a)
         -> Exp a
         -> Acc (Vector a)
         -> Acc (Vector a)
prescanl f e = P.fst . scanl' f e

-- |Left-to-right postscan, a variant of 'scanl1' with an initial value. Denotationally, we have
--
-- > postscanl f e = map (e `f`) . scanl1 f
--
postscanl :: Elt a
          => (Exp a -> Exp a -> Exp a)
          -> Exp a
          -> Acc (Vector a)
          -> Acc (Vector a)
postscanl f e = map (e `f`) . scanl1 f

-- |Right-to-left prescan (aka exclusive scan). As for 'scan', the first argument must be an
-- /associative/ function. Denotationally, we have
--
-- > prescanr f e = Prelude.fst . scanr' f e
--
prescanr :: Elt a
         => (Exp a -> Exp a -> Exp a)
         -> Exp a
         -> Acc (Vector a)
         -> Acc (Vector a)
prescanr f e = P.fst . scanr' f e

-- |Right-to-left postscan, a variant of 'scanr1' with an initial value. Denotationally, we have
--
-- > postscanr f e = map (`f` e) . scanr1 f
--
-- The seed is combined on the /right/ of every partial result, because it
-- conceptually sits after the last element of the vector.
--
postscanr :: Elt a
          => (Exp a -> Exp a -> Exp a)
          -> Exp a
          -> Acc (Vector a)
          -> Acc (Vector a)
postscanr f e = map (`f` e) . scanr1 f


-- Segmented scans
-- ---------------

-- |Segmented version of 'scanl'
--
scanlSeg :: (Elt a, Elt i, IsIntegral i)
         => (Exp a -> Exp a -> Exp a)
         -> Exp a
         -> Acc (Vector a)
         -> Acc (Segments i)
         -> Acc (Vector a)
scanlSeg f z vec seg = scanl1Seg f vec' seg'
  where
    -- Segmented exclusive scan is implemented by first injecting the seed
    -- element at the head of each segment, and then performing a segmented
    -- inclusive scan.
    --
    -- This is done by creating a vector entirely of the seed element, and
    -- overlaying the input data in all places other than at the start of a
    -- segment.
    --
    seg'        = map (+1) seg
    vec'        = permute const
                          (fill (index1 $ size vec + size seg) z)
                          (\ix -> index1' $ unindex1' ix + inc ! ix)
                          vec

    -- Each element in the segments must be shifted to the right one additional
    -- place for each successive segment, to make room for the seed element.
    -- Here, we make use of the fact that the vector returned by 'mkHeadFlags'
    -- contains non-unit entries, which indicate zero length segments.
    --
    flags       = mkHeadFlags seg
    inc         = scanl1 (+) flags


-- |Segmented version of 'scanl''
--
-- The first element of the resulting tuple is a vector of scanned values. The
-- second element is a vector of segment scan totals and has the same size as
-- the segment vector.
--
scanl'Seg :: forall a i. (Elt a, Elt i, IsIntegral i)
          => (Exp a -> Exp a -> Exp a)
          -> Exp a
          -> Acc (Vector a)
          -> Acc (Segments i)
          -> Acc (Vector a, Vector a)
scanl'Seg f z vec seg = result
  where
    -- Return the result combined, so that the sub-calculations are shared
    -- should the user require both results.
    --
    result      = lift (body, sums)

    -- Segmented scan' is implemented by deconstructing a segmented exclusive
    -- scan, to separate the final value and scan body.
    --
    -- TLM: Segmented scans, and this version in particular, expend a lot of
    --      effort scanning flag arrays. On inspection it appears that several
    --      of these operations are duplicated, but this will not be picked up
    --      by sharing _observation_. Perhaps a global CSE-style pass would be
    --      beneficial.
    --
    vec'        = scanlSeg f z vec seg

    -- Extract the final reduction value for each segment, which is at the last
    -- index of each segment.
    --
    seg'        = map (+1) seg
    tails       = zipWith (+) seg . P.fst $ scanl' (+) 0 seg'
    sums        = backpermute (shape seg) (\ix -> index1' $ tails ! ix) vec'

    -- Slice out the body of each segment.
    --
    -- Build a head-flags representation based on the original segment
    -- descriptor. This contains the target length of each of the body segments,
    -- which is one fewer element than the actual bodies stored in vec'. Thus,
    -- the flags align with the last element of each body section, and when
    -- scanned, this element will be incremented over.
    --
    offset      = scanl1 (+) seg
    inc         = scanl1 (+)
                $ permute (+) (fill (index1 $ size vec + 1) 0)
                              (\ix -> index1' $ offset ! ix)
                              (fill (shape seg) (1 :: Exp i))

    body        = backpermute (shape vec)
                              (\ix -> index1' $ unindex1' ix + inc ! ix)
                              vec'


-- |Segmented version of 'scanl1'.
--
scanl1Seg :: (Elt a, Elt i, IsIntegral i)
          => (Exp a -> Exp a -> Exp a)
          -> Acc (Vector a)
          -> Acc (Segments i)
          -> Acc (Vector a)
scanl1Seg f vec seg
  = P.snd
  . unzip
  . scanl1 (segmented f)
  $ zip (mkHeadFlags seg) vec

-- |Segmented version of 'prescanl'.
--
prescanlSeg :: (Elt a, Elt i, IsIntegral i)
            => (Exp a -> Exp a -> Exp a)
            -> Exp a
            -> Acc (Vector a)
            -> Acc (Segments i)
            -> Acc (Vector a)
prescanlSeg f e vec seg
  = P.fst
  . unatup2
  $ scanl'Seg f e vec seg

-- |Segmented version of 'postscanl'.
--
postscanlSeg :: (Elt a, Elt i, IsIntegral i)
             => (Exp a -> Exp a -> Exp a)
             -> Exp a
             -> Acc (Vector a)
             -> Acc (Segments i)
             -> Acc (Vector a)
postscanlSeg f e vec seg
  = map (f e)
  $ scanl1Seg f vec seg

-- |Segmented version of 'scanr'.
--
scanrSeg :: (Elt a, Elt i, IsIntegral i)
         => (Exp a -> Exp a -> Exp a)
         -> Exp a
         -> Acc (Vector a)
         -> Acc (Segments i)
         -> Acc (Vector a)
scanrSeg f z vec seg = scanr1Seg f vec' seg'
  where
    -- Using the technique described for 'scanlSeg', except that here we
    -- intersperse the array with the seed element at the /end/ of each segment
    -- (note the extra @- 1@ in the target index), and then perform an inclusive
    -- segmented right scan.
    --
    inc         = scanl1 (+) (mkHeadFlags seg)
    seg'        = map (+1) seg
    vec'        = permute const
                          (fill (index1 $ size vec + size seg) z)
                          (\ix -> index1' $ unindex1' ix + inc ! ix - 1)
                          vec


-- | Segmented version of 'scanr''.
--
scanr'Seg :: forall a i. (Elt a, Elt i, IsIntegral i)
          => (Exp a -> Exp a -> Exp a)
          -> Exp a
          -> Acc (Vector a)
          -> Acc (Segments i)
          -> Acc (Vector a, Vector a)
scanr'Seg f z vec seg = result
  where
    -- Using technique described for scanl'Seg
    --
    result      = lift (body, sums)
    vec'        = scanrSeg f z vec seg

    -- reduction values
    seg'        = map (+1) seg
    heads       = P.fst $ scanl' (+) 0 seg'
    sums        = backpermute (shape seg) (\ix -> index1' $ heads ! ix) vec'

    -- body segments
    inc         = scanl1 (+) $ mkHeadFlags seg
    body        = backpermute (shape vec)
                              (\ix -> index1' $ unindex1' ix + inc ! ix)
                              vec'


-- |Segmented version of 'scanr1'.
--
scanr1Seg :: (Elt a, Elt i, IsIntegral i)
          => (Exp a -> Exp a -> Exp a)
          -> Acc (Vector a)
          -> Acc (Segments i)
          -> Acc (Vector a)
scanr1Seg f vec seg
  = P.snd
  . unzip
  . scanr1 (segmented f)
  $ zip (mkTailFlags seg) vec

-- |Segmented version of 'prescanr'.
--
prescanrSeg :: (Elt a, Elt i, IsIntegral i)
            => (Exp a -> Exp a -> Exp a)
            -> Exp a
            -> Acc (Vector a)
            -> Acc (Segments i)
            -> Acc (Vector a)
prescanrSeg f e vec seg
  = P.fst
  . unatup2
  $ scanr'Seg f e vec seg

-- |Segmented version of 'postscanr'.
--
-- As in 'postscanr', the seed is combined on the /right/ of each inclusive
-- result (@x `f` e@), matching the placement of the seed at the end of each
-- segment in 'scanrSeg'. This only matters for non-commutative operators.
--
postscanrSeg :: (Elt a, Elt i, IsIntegral i)
             => (Exp a -> Exp a -> Exp a)
             -> Exp a
             -> Acc (Vector a)
             -> Acc (Segments i)
             -> Acc (Vector a)
postscanrSeg f e vec seg
  = map (`f` e)
  $ scanr1Seg f vec seg


-- Segmented scan helpers
-- ----------------------

-- |Compute head flags vector from segment vector for left-scans.
--
-- The vector will be full of zeros in the body of a segment, and non-zero
-- otherwise. The "flag" value, if greater than one, indicates that several
-- empty segments are represented by this single flag entry. This additional
-- data is used by exclusive segmented scan.
--
mkHeadFlags :: (Elt i, IsIntegral i) => Acc (Segments i) -> Acc (Segments i)
mkHeadFlags seg
  = init
  $ permute (+) zeros (\ix -> index1' (offset ! ix)) ones
  where
    (offset, len)       = scanl' (+) 0 seg
    zeros               = fill (index1' $ the len + 1) 0
    ones                = fill (index1  $ size offset) 1

-- |Compute tail flags vector from segment vector for right-scans. That is, the
-- flag is placed at the last place in each segment.
--
mkTailFlags :: (Elt i, IsIntegral i) => Acc (Segments i) -> Acc (Segments i)
mkTailFlags seg
  = init
  $ permute (+) zeros (\ix -> index1' (the len - 1 - offset ! ix)) ones
  where
    (offset, len)       = scanr' (+) 0 seg
    zeros               = fill (index1' $ the len + 1) 0
    ones                = fill (index1  $ size offset) 1

-- |Construct a segmented version of a function from a non-segmented version.
-- The segmented apply operates on a head-flag value tuple, and follows the
-- procedure of Sengupta et. al.
--
-- A non-zero flag on the right operand marks a segment boundary, so its value
-- is passed through unchanged instead of being combined with the left operand.
--
segmented :: (Elt e, Elt i, IsIntegral i)
          => (Exp e -> Exp e -> Exp e)
          -> Exp (i, e) -> Exp (i, e) -> Exp (i, e)
segmented f a b =
  let (aF, aV) = unlift a
      (bF, bV) = unlift b
  in
  lift (aF .|. bF, bF /=* 0 ? (bV, f aV bV))

-- |Index construction and destruction generalised to integral types.
--
-- We generalise the segment descriptor to integral types because some
-- architectures, such as GPUs, have poor performance for 64-bit types. So,
-- there is a tension between performance and requiring 64-bit indices for some
-- applications, and we would not like to restrict ourselves to either one.
--
-- As we don't yet support non-Int dimensions in shapes, we will need to convert
-- back to concrete Int. However, don't put these generalised forms into the
-- base library, because it results in too many ambiguity errors.
--
index1' :: (Elt i, IsIntegral i) => Exp i -> Exp DIM1
index1' i = lift (Z :. fromIntegral i)

unindex1' :: (Elt i, IsIntegral i) => Exp DIM1 -> Exp i
unindex1' ix = let Z :. i = unlift ix in fromIntegral i


-- Reshaping of arrays
-- -------------------

-- | Flattens a given array of arbitrary dimension.
--
flatten :: (Shape ix, Elt a) => Acc (Array ix a) -> Acc (Vector a)
flatten a = reshape (index1 $ size a) a

-- Enumeration and filling
-- -----------------------

-- | Create an array where all elements are the same value.
--
fill :: (Shape sh, Elt e) => Exp sh -> Exp e -> Acc (Array sh e)
fill sh c = generate sh (const c)

-- | Create an array of the given shape containing the values x, x+1, etc (in
-- row-major order).
--
enumFromN :: (Shape sh, Elt e, IsNum e) => Exp sh -> Exp e -> Acc (Array sh e)
enumFromN sh start = enumFromStepN sh start 1

-- | Create an array of the given shape containing the values @x@, @x+y@,
-- @x+y+y@ etc. (in row-major order).
--
-- The values are generated over a flat vector of the same total size and the
-- result is then reshaped to @sh@.
--
enumFromStepN :: (Shape sh, Elt e, IsNum e)
              => Exp sh
              -> Exp e    -- ^ x: start
              -> Exp e    -- ^ y: step
              -> Acc (Array sh e)
enumFromStepN sh x y = reshape sh (generate (index1 $ shapeSize sh) valueAt)
  where
    valueAt ix = (fromIntegral (unindex1 ix :: Exp Int) * y) + x

-- Concatenation
-- -------------

-- | Concatenate outermost component of two arrays. The extent of the lower
--   dimensional component is the intersection of the two arrays.
--
infixr 5 ++
(++) :: forall sh e. (Slice sh, Shape sh, Elt e)
     => Acc (Array (sh :. Int) e)
     -> Acc (Array (sh :. Int) e)
     -> Acc (Array (sh :. Int) e)
(++) xs ys = generate (lift (intersect sh1 sh2 :. n + m)) pick
  where
    sh1 :. n = unlift (shape xs) :: Exp sh :. Exp Int
    sh2 :. m = unlift (shape ys) :: Exp sh :. Exp Int

    -- indices below @n@ on the outer dimension come from @xs@; the rest are
    -- shifted down by @n@ and read from @ys@
    pick ix =
      let sh :. i = unlift ix :: Exp sh :. Exp Int
      in  i <* n ? ( xs ! ix, ys ! lift (sh :. i-n) )

-- Filtering
-- ---------

-- | Drop elements that do not satisfy the predicate
--
filter :: Elt a
       => (Exp a -> Exp Bool)
       -> Acc (Vector a)
       -> Acc (Vector a)
filter p arr = permute const dest target arr
  where
    -- 0/1 keep-flags; their exclusive prefix sum is each kept element's
    -- position in the output, and the total is the output length
    flags            = map (boolToInt . p) arr
    (targetIdx, len) = scanl' (+) 0 flags
    dest             = backpermute (index1 $ the len) id arr
    target ix        = flags!ix ==* 0 ? (ignore, index1 $ targetIdx!ix)
  -- FIXME: This is abusing 'permute' in that the first two arguments are
  --        only justified because we know the permutation function will
  --        write to each location in the target exactly once.
  --        Instead, we should have a primitive that directly encodes the
  --        compaction pattern of the permutation function.

{-# NOINLINE filter #-}
{-# RULES
  "ACC filter/filter" forall f g arr.
    filter f (filter g arr) = filter (\x -> g x &&* f x) arr
 #-}


-- Gather operations
-- -----------------

-- | Copy elements from source array to destination array according to a map. This
--   is a backpermute operation where a 'map' vector encodes the output to input
--   index mapping.
--
--   For example:
--
--  > input  = [1, 9, 6, 4, 4, 2, 0, 1, 2]
--  > from   = [1, 3, 7, 2, 5, 3]
--  >
--  > output = [9, 4, 1, 6, 2, 4]
--
gather
    :: Elt e
    => Acc (Vector Int)         -- ^index mapping
    -> Acc (Vector e)           -- ^input
    -> Acc (Vector e)           -- ^output
gather from input = backpermute (shape from) (index1 . (from !)) input


-- | Conditionally copy elements from source array to destination array according
--   to an index mapping. This is a backpermute operation where a 'from' vector
--   encodes the output to input index mapping. In addition, there is a 'mask'
--   vector, and an associated predication function, that specifies whether an
--   element will be copied. If not copied, the output array assumes the default
--   vector's value.
--
--   For example:
--
--  > default = [6, 6, 6, 6, 6, 6]
--  > from    = [1, 3, 7, 2, 5, 3]
--  > mask    = [3, 4, 9, 2, 7, 5]
--  > pred    = (>* 4)
--  > input   = [1, 9, 6, 4, 4, 2, 0, 1, 2]
--  >
--  > output  = [6, 6, 1, 6, 2, 4]
--
gatherIf
    :: (Elt e, Elt e')
    => Acc (Vector Int)         -- ^index mapping
    -> Acc (Vector e)           -- ^mask
    -> (Exp e -> Exp Bool)      -- ^predicate
    -> Acc (Vector e')          -- ^default
    -> Acc (Vector e')          -- ^input
    -> Acc (Vector e')          -- ^output
gatherIf from maskV pred defaults input = zipWith choose (map pred maskV) candidates
  where
    -- pair each gathered value with its default, then pick one per the mask
    candidates = zip (gather from input) defaults
    choose p g = p ? unlift g


-- Scatter operations
-- ------------------

-- | Copy elements from source array to destination array according to an index
--   mapping. This is a forward-permute operation where a 'to' vector encodes an
--   input to output index mapping. Output elements for indices that are not
--   mapped assume the default vector's value.
--
--   For example:
--
--  > default = [0, 0, 0, 0, 0, 0, 0, 0, 0]
--  > to      = [1, 3, 7, 2, 5, 8]
--  > input   = [1, 9, 6, 4, 4, 2, 5]
--  >
--  > output  = [0, 1, 4, 9, 0, 4, 0, 6, 2]
--
--   Note if the same index appears in the index mapping more than once, the
--   result is undefined. It does not make sense for the 'to' vector to be
--   larger than the 'input' vector.
--
scatter
    :: Elt e
    => Acc (Vector Int)         -- ^index mapping
    -> Acc (Vector e)           -- ^default
    -> Acc (Vector e)           -- ^input
    -> Acc (Vector e)           -- ^output
scatter to defaults input = permute const defaults pf input'
  where
    pf ix       = index1 (to ! ix)
    -- restrict the input to the indices shared with 'to', so that every
    -- element that is permuted has a target index
    input'      = backpermute (shape to `intersect` shape input) id input


-- | Conditionally copy elements from source array to destination array according
--   to an index mapping. This is a forward-permute operation where a 'to'
--   vector encodes an input to output index mapping. In addition, there is a
--   'mask' vector, and an associated predicate function. The mapping will only
--   occur if the predicate function applied to the mask at that position
--   resolves to 'True'. If not copied, the output array assumes the default
--   vector's value.
--
--   For example:
--
--  > default = [0, 0, 0, 0, 0, 0, 0, 0, 0]
--  > to      = [1, 3, 7, 2, 5, 8]
--  > mask    = [3, 4, 9, 2, 7, 5]
--  > pred    = (>* 4)
--  > input   = [1, 9, 6, 4, 4, 2, 5]
--  >
--  > output  = [0, 0, 0, 0, 0, 4, 0, 6, 2]
--
--   Note if the same index appears in the mapping more than once, the result is
--   undefined. The 'to' and 'mask' vectors must be the same length. It does not
--   make sense for these to be larger than the 'input' vector.
--
scatterIf
    :: (Elt e, Elt e')
    => Acc (Vector Int)         -- ^index mapping
    -> Acc (Vector e)           -- ^mask
    -> (Exp e -> Exp Bool)      -- ^predicate
    -> Acc (Vector e')          -- ^default
    -> Acc (Vector e')          -- ^input
    -> Acc (Vector e')          -- ^output
scatterIf to maskV pred defaults input = permute const defaults pf input'
  where
    -- elements whose mask fails the predicate are sent to 'ignore'
    pf ix       = pred (maskV ! ix) ? ( index1 (to ! ix), ignore )
    input'      = backpermute (shape to `intersect` shape input) id input


-- Permutations
-- ------------

-- | Reverse the elements of a vector.
--
reverse :: Elt e => Acc (Vector e) -> Acc (Vector e)
reverse xs =
  let len       = unindex1 (shape xs)
      pf i      = len - i - 1
  in  backpermute (shape xs) (ilift1 pf) xs

-- | Transpose the rows and columns of a matrix.
--
transpose :: Elt e => Acc (Array DIM2 e) -> Acc (Array DIM2 e)
transpose mat =
  let swap = lift1 $ \(Z:.x:.y) -> Z:.y:.x :: Z:.Exp Int:.Exp Int
  in
  backpermute (swap $ shape mat) swap mat


-- Extracting sub-vectors
-- ----------------------

-- | Yield the first @n@ elements of the input vector. The vector must contain
-- no fewer than @n@ elements.
--
take :: Elt e => Exp Int -> Acc (Vector e) -> Acc (Vector e)
take n =
  -- NOTE(review): @n@ is routed through 'unit'/'the', presumably so it is
  -- evaluated once as a scalar array rather than inlined — confirm.
  let n' = the (unit n)
  in  backpermute (index1 n') id


-- | Yield all but the first @n@ elements of the input vector. The vector must
--   contain no fewer than @n@ elements.
--
drop :: Elt e => Exp Int -> Acc (Vector e) -> Acc (Vector e)
drop n arr =
  let n' = the (unit n)
  in  backpermute (ilift1 (subtract n') (shape arr)) (ilift1 (+ n')) arr


-- | Yield all but the last element of the input vector. The vector must not be
--   empty.
--
init :: Elt e => Acc (Vector e) -> Acc (Vector e)
init arr = backpermute (ilift1 (subtract 1) (shape arr)) id arr


-- | Yield all but the first element of the input vector. The vector must not be
--   empty.
--
tail :: Elt e => Acc (Vector e) -> Acc (Vector e)
tail arr = backpermute (ilift1 (subtract 1) (shape arr)) (ilift1 (+1)) arr


-- | Yield a slit (slice) from the vector. The vector must contain at least
--   @i + n@ elements. Denotationally, we have:
--
-- > slit i n = take n . drop i
--
slit :: Elt e => Exp Int -> Exp Int -> Acc (Vector e) -> Acc (Vector e)
slit i n =
  let i' = the (unit i)
      n' = the (unit n)
  in  backpermute (index1 n') (ilift1 (+ i'))


-- Flow control
-- ------------

-- | Infix version of 'acond'.
-- If the predicate evaluates to 'True', the first
-- component of the tuple is returned, else the second.
--
-- The pair is purely syntactic: both branches are passed to 'acond' as-is.
--
infix 0 ?|
(?|) :: (Arrays a) => Exp Bool -> (Acc a, Acc a) -> Acc a
c ?| (t, e) = acond c t e

-- | An infix version of 'cond'. If the predicate evaluates to 'True', the first
-- component of the tuple is returned, else the second.
--
infix 0 ?
(?) :: Elt t => Exp Bool -> (Exp t, Exp t) -> Exp t
c ? (t, e) = cond c t e

-- | A case-like control structure
--
-- The case list is walked at the Haskell level, producing a chain of nested
-- 'cond's. Cases are tested in list order, so the first predicate that holds
-- for the subject selects its branch. The default value is returned when no
-- predicate holds, or when the list is empty.
--
caseof :: (Elt a, Elt b)
       => Exp a                         -- ^ case subject
       -> [(Exp a -> Exp Bool, Exp b)]  -- ^ list of cases to attempt
       -> Exp b                         -- ^ default value
       -> Exp b
caseof _ []        e = e
caseof x ((p,b):l) e = cond (p x) b (caseof x l e)


-- Scalar iteration
-- ----------------

-- | Repeatedly apply a function a fixed number of times
--
-- The loop state is a (counter, value) pair lifted into 'Exp'. The counter
-- starts at 0 and the loop continues while it is below @n@. Given that 'while'
-- tests its condition before each step, @f@ is applied @n@ times, and not at
-- all when @n <= 0@.
--
iterate :: forall a. Elt a
        => Exp Int
        -> (Exp a -> Exp a)
        -> Exp a
        -> Exp a
iterate n f z
  = let -- advance the counter and apply f once
        step :: (Exp Int, Exp a) -> (Exp Int, Exp a)
        step (i, acc) = ( i+1, f acc )
    in
    snd $ while (\v -> fst v <* n) (lift1 step) (lift (constant 0, z))


-- Scalar bulk operations
-- ----------------------

-- | Reduce along an innermost slice of an array /sequentially/, by applying a
-- binary operator to a starting value and the array from left to right.
--
-- @ix@ selects the slice. The fold reads @xs ! (ix :. i)@ for @i@ counting up
-- from 0 while @i@ is below @n@, the extent of the innermost dimension of
-- @xs@.
--
sfoldl :: forall sh a b. (Shape sh, Slice sh, Elt a, Elt b)
       => (Exp a -> Exp b -> Exp a)
       -> Exp a
       -> Exp sh
       -> Acc (Array (sh :. Int) b)
       -> Exp a
sfoldl f z ix xs
  = let -- combine the accumulator with element i of the selected slice
        step :: (Exp Int, Exp a) -> (Exp Int, Exp a)
        step (i, acc) = ( i+1, acc `f` (xs ! lift (ix :. i)) )
        -- extent of the innermost dimension
        (_ :. n) = unlift (shape xs) :: Exp sh :. Exp Int
    in
    snd $ while (\v -> fst v <* n) (lift1 step) (lift (constant 0, z))


-- Lifting surface expressions
-- ---------------------------

-- | The class of types @e@ which can be lifted into @c@.
class Lift c e where
  -- | An associated-type (i.e. a type-level function) that strips all
  --   instances of surface type constructors @c@ from the input type @e@.
  --
  --   For example, the tuple types @(Exp Int, Int)@ and @(Int, Exp
  --   Int)@ have the same \"Plain\" representation.  That is, the
  --   following type equality holds:
  --
  --    @Plain (Exp Int, Int) ~ (Int,Int) ~ Plain (Int, Exp Int)@
  type Plain e

  -- | Lift the given value into a surface type 'c' --- either 'Exp' for scalar
  -- expressions or 'Acc' for array computations. The value may already contain
  -- subexpressions in 'c'.
  --
  lift :: e -> c (Plain e)

-- | A limited subset of types which can be lifted, can also be unlifted.
class Lift c e => Unlift c e where

  -- | Unlift the outermost constructor through the surface type. This is only
  -- possible if the constructor is fully determined by its type - i.e., it is a
  -- singleton.
  --
  unlift :: c (Plain e) -> e


-- instances for indices
--
-- A plain 'Int' or 'All' index component is embedded as a constant ('Const').
-- An 'Exp e' component is already an expression and is used directly.

instance Lift Exp () where
  type Plain () = ()
  lift _ = Exp $ Tuple NilTup

instance Unlift Exp () where
  unlift _ = ()

instance Lift Exp Z where
  type Plain Z = Z
  lift _ = Exp $ IndexNil

instance Unlift Exp Z where
  unlift _ = Z

instance (Slice (Plain ix), Lift Exp ix) => Lift Exp (ix :. Int) where
  type Plain (ix :. Int) = Plain ix :. Int
  lift (ix:.i) = Exp $ IndexCons (lift ix) (Exp $ Const i)

instance (Slice (Plain ix), Lift Exp ix) => Lift Exp (ix :. All) where
  type Plain (ix :. All) = Plain ix :. All
  lift (ix:.i) = Exp $ IndexCons (lift ix) (Exp $ Const i)

instance (Elt e, Slice (Plain ix), Lift Exp ix) => Lift Exp (ix :. Exp e) where
  type Plain (ix :. Exp e) = Plain ix :. e
  lift (ix:.i) = Exp $ IndexCons (lift ix) i

-- Takes the head off with 'IndexHead' and recursively unlifts the remaining
-- index ('IndexTail').
instance (Elt e, Slice (Plain ix), Unlift Exp ix) => Unlift Exp (ix :. Exp e) where
  unlift e = unlift (Exp $ IndexTail e) :. Exp (IndexHead e)

-- Stops after one level: the tail is returned as an 'Exp ix' expression and
-- is not unlifted further.
instance (Elt e, Slice ix) => Unlift Exp (Exp ix :. Exp e) where
  unlift e = (Exp $ IndexTail e) :. Exp (IndexHead e)

-- 'Any' is matched with no fields, so lifting it just produces 'IndexAny'.
instance Shape sh => Lift Exp (Any sh) where
  type Plain (Any sh) = Any sh
  lift Any = Exp $ IndexAny

-- instances for numeric types
--
-- Each plain scalar value is embedded into the expression as a constant
-- ('Const').

instance Lift Exp Int where
  type Plain Int = Int
  lift = Exp . Const

instance Lift Exp Int8 where
  type Plain Int8 = Int8
  lift = Exp . Const

instance Lift Exp Int16 where
  type Plain Int16 = Int16
  lift = Exp . Const

instance Lift Exp Int32 where
  type Plain Int32 = Int32
  lift = Exp . Const

instance Lift Exp Int64 where
  type Plain Int64 = Int64
  lift = Exp . Const

instance Lift Exp Word where
  type Plain Word = Word
  lift = Exp . Const

instance Lift Exp Word8 where
  type Plain Word8 = Word8
  lift = Exp . Const

instance Lift Exp Word16 where
  type Plain Word16 = Word16
  lift = Exp . Const

instance Lift Exp Word32 where
  type Plain Word32 = Word32
  lift = Exp . Const

instance Lift Exp Word64 where
  type Plain Word64 = Word64
  lift = Exp . Const

instance Lift Exp CShort where
  type Plain CShort = CShort
  lift = Exp . Const

instance Lift Exp CUShort where
  type Plain CUShort = CUShort
  lift = Exp . Const

instance Lift Exp CInt where
  type Plain CInt = CInt
  lift = Exp . Const

instance Lift Exp CUInt where
  type Plain CUInt = CUInt
  lift = Exp . Const

instance Lift Exp CLong where
  type Plain CLong = CLong
  lift = Exp . Const

instance Lift Exp CULong where
  type Plain CULong = CULong
  lift = Exp . Const

instance Lift Exp CLLong where
  type Plain CLLong = CLLong
  lift = Exp . Const

instance Lift Exp CULLong where
  type Plain CULLong = CULLong
  lift = Exp . Const

instance Lift Exp Float where
  type Plain Float = Float
  lift = Exp . Const

instance Lift Exp Double where
  type Plain Double = Double
  lift = Exp . Const

instance Lift Exp CFloat where
  type Plain CFloat = CFloat
  lift = Exp . Const

instance Lift Exp CDouble where
  type Plain CDouble = CDouble
  lift = Exp . Const

instance Lift Exp Bool where
  type Plain Bool = Bool
  lift = Exp . Const

instance Lift Exp Char where
  type Plain Char = Char
  lift = Exp . Const

instance Lift Exp CChar where
  type Plain CChar = CChar
  lift = Exp . Const

instance Lift Exp CSChar where
  type Plain CSChar = CSChar
  lift = Exp . Const

instance Lift Exp CUChar where
  type Plain CUChar = CUChar
  lift = Exp . Const

-- Instances for tuples
--
-- Lifting a tuple lifts each component and packs the results with @tupN@.
-- Unlifting is provided only for tuples whose components are all 'Exp's, and
-- uses @untupN@.

instance (Lift Exp a, Lift Exp b, Elt (Plain a), Elt (Plain b)) => Lift Exp (a, b) where
  type Plain (a, b) = (Plain a, Plain b)
  lift (x, y) = tup2 (lift x, lift y)

instance (Elt a, Elt b) => Unlift Exp (Exp a, Exp b) where
  unlift = untup2

instance (Lift Exp a, Lift Exp b, Lift Exp c,
          Elt (Plain a), Elt (Plain b), Elt (Plain c))
      => Lift Exp (a, b, c) where
  type Plain (a, b, c) = (Plain a, Plain b, Plain c)
  lift (x, y, z) = tup3 (lift x, lift y, lift z)

instance (Elt a, Elt b, Elt c) => Unlift Exp (Exp a, Exp b, Exp c) where
  unlift = untup3

instance (Lift Exp a, Lift Exp b, Lift Exp c, Lift Exp d,
          Elt (Plain a), Elt (Plain b), Elt (Plain c), Elt (Plain d))
      => Lift Exp (a, b, c, d) where
  type Plain (a, b, c, d) = (Plain a, Plain b, Plain c, Plain d)
  lift (x, y, z, u) = tup4 (lift x, lift y, lift z, lift u)

instance (Elt a, Elt b, Elt c, Elt d) => Unlift Exp (Exp a, Exp b, Exp c, Exp d) where
  unlift = untup4

instance (Lift Exp a, Lift Exp b, Lift Exp c, Lift Exp d, Lift Exp e,
          Elt (Plain a), Elt (Plain b), Elt (Plain c), Elt (Plain d), Elt (Plain e))
      => Lift Exp (a, b, c, d, e) where
  type Plain (a, b, c, d, e) = (Plain a, Plain b, Plain c, Plain d, Plain e)
  lift (x, y, z, u, v) = tup5 (lift x, lift y, lift z, lift u, lift v)

instance (Elt a, Elt b, Elt c, Elt d, Elt e)
      => Unlift Exp (Exp a, Exp b, Exp c, Exp d, Exp e) where
  unlift = untup5

instance (Lift Exp a, Lift Exp b, Lift Exp c, Lift Exp d, Lift Exp e, Lift Exp f,
          Elt (Plain a), Elt (Plain b), Elt (Plain c), Elt (Plain d), Elt (Plain e),
          Elt (Plain f))
      => Lift Exp (a, b, c, d, e, f) where
  type Plain (a, b, c, d, e, f) = (Plain a, Plain b, Plain c, Plain d, Plain e, Plain f)
  lift (x, y, z, u, v, w) = tup6 (lift x, lift y, lift z, lift u, lift v, lift w)

instance (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f)
      => Unlift Exp (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f) where
  unlift = untup6

instance (Lift Exp a, Lift Exp b, Lift Exp c, Lift Exp d, Lift Exp e, Lift Exp f,
          Lift Exp g,
          Elt (Plain a), Elt (Plain b), Elt (Plain c), Elt (Plain d), Elt (Plain e),
          Elt (Plain f), Elt (Plain g))
      => Lift Exp (a, b, c, d, e, f, g) where
  type Plain (a, b, c, d, e, f, g) = (Plain a, Plain b, Plain c, Plain d, Plain e, Plain f, Plain g)
  lift (x, y, z, u, v, w, r) = tup7 (lift x, lift y, lift z, lift u, lift v, lift w, lift r)

instance (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g)
      => Unlift Exp (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g) where
  unlift = untup7

instance (Lift Exp a, Lift Exp b, Lift Exp c, Lift Exp d, Lift Exp e, Lift Exp f,
          Lift Exp g, Lift Exp h,
          Elt (Plain a), Elt (Plain b), Elt (Plain c), Elt (Plain d), Elt (Plain e),
          Elt (Plain f), Elt (Plain g), Elt (Plain h))
      => Lift Exp (a, b, c, d, e, f, g, h) where
  type Plain (a, b, c, d, e, f, g, h)
    = (Plain a, Plain b, Plain c, Plain d, Plain e, Plain f, Plain g, Plain h)
  lift (x, y, z, u, v, w, r, s)
    = tup8 (lift x, lift y, lift z, lift u, lift v, lift w, lift r, lift s)

instance (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h)
      => Unlift Exp (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h) where
  unlift = untup8

instance (Lift Exp a, Lift Exp b, Lift Exp c, Lift Exp d, Lift Exp e, Lift Exp f,
          Lift Exp g, Lift Exp h, Lift Exp i,
          Elt (Plain a), Elt (Plain b), Elt (Plain c), Elt (Plain d), Elt (Plain e),
          Elt (Plain f), Elt (Plain g), Elt (Plain h), Elt (Plain i))
      => Lift Exp (a, b, c, d, e, f, g, h, i) where
  type Plain (a, b, c, d, e, f, g, h, i)
    = (Plain a, Plain b, Plain c, Plain d, Plain e, Plain f, Plain g, Plain h, Plain i)
  lift (x, y, z, u, v, w, r, s, t)
    = tup9 (lift x, lift y, lift z, lift u, lift v, lift w, lift r, lift s, lift t)

instance (Elt a, Elt b, Elt c, Elt d, Elt e, Elt f, Elt g, Elt h, Elt i)
      => Unlift Exp (Exp a, Exp b, Exp c, Exp d, Exp e, Exp f, Exp g, Exp h, Exp i) where
  unlift = untup9

-- Instance for scalar Accelerate expressions
--
-- A value that is already an 'Exp' lifts to itself.
instance Lift Exp (Exp e) where
  type Plain (Exp e) = e
  lift = id

-- Instance for Accelerate array computations
--
-- A value that is already an 'Acc' lifts to itself.
instance Lift Acc (Acc a) where
  type Plain (Acc a) = a
  lift = id

-- Instances for Arrays class

--instance Lift Acc () where
--  type Plain () = ()
--  lift _ = Acc (Atuple NilAtup)

-- A concrete array is embedded into the computation with 'Use'.
instance (Shape sh, Elt e) => Lift Acc (Array sh e) where
  type Plain (Array sh e) = Array sh e
  lift = Acc . Use

-- Array tuples mirror the scalar tuple instances: components are lifted and
-- packed with @atupN@; unlifting (via @unatupN@) requires every component to
-- already be an 'Acc'.

instance (Lift Acc a, Lift Acc b, Arrays (Plain a), Arrays (Plain b)) => Lift Acc (a, b) where
  type Plain (a, b) = (Plain a, Plain b)
  lift (x, y) = atup2 (lift x, lift y)

instance (Arrays a, Arrays b) => Unlift Acc (Acc a, Acc b) where
  unlift = unatup2

instance (Lift Acc a, Lift Acc b, Lift Acc c,
          Arrays (Plain a), Arrays (Plain b), Arrays (Plain c))
      => Lift Acc (a, b, c) where
  type Plain (a, b, c) = (Plain a, Plain b, Plain c)
  lift (x, y, z) = atup3 (lift x, lift y, lift z)

instance (Arrays a, Arrays b, Arrays c) => Unlift Acc (Acc a, Acc b, Acc c) where
  unlift = unatup3

instance (Lift Acc a, Lift Acc b, Lift Acc c, Lift Acc d,
          Arrays (Plain a), Arrays (Plain b), Arrays (Plain c), Arrays (Plain d))
      => Lift Acc (a, b, c, d) where
  type Plain (a, b, c, d) = (Plain a, Plain b, Plain c, Plain d)
  lift (x, y, z, u) = atup4 (lift x, lift y, lift z, lift u)

instance (Arrays a, Arrays b, Arrays c, Arrays d) => Unlift Acc (Acc a, Acc b, Acc c, Acc d) where
  unlift = unatup4

instance (Lift Acc a, Lift Acc b, Lift Acc c, Lift Acc d, Lift Acc e,
          Arrays (Plain a), Arrays (Plain b), Arrays (Plain c), Arrays (Plain d),
          Arrays (Plain e))
      => Lift Acc (a, b, c, d, e) where
  type Plain (a, b, c, d, e) = (Plain a, Plain b, Plain c, Plain d, Plain e)
  lift (x, y, z, u, v) = atup5 (lift x, lift y, lift z, lift u, lift v)

instance (Arrays a, Arrays b, Arrays c, Arrays d, Arrays e)
      => Unlift Acc (Acc a, Acc b, Acc c, Acc d, Acc e) where
  unlift = unatup5

instance (Lift Acc a, Lift Acc b, Lift Acc c, Lift Acc d, Lift Acc e, Lift Acc f,
          Arrays (Plain a), Arrays (Plain b), Arrays (Plain c), Arrays (Plain d),
          Arrays (Plain e), Arrays (Plain f))
      => Lift Acc (a, b, c, d, e, f) where
  type Plain (a, b, c, d, e, f) = (Plain a, Plain b, Plain c, Plain d, Plain e, Plain f)
  lift (x, y, z, u, v, w) = atup6 (lift x, lift y, lift z, lift u, lift v, lift w)

instance (Arrays a, Arrays b, Arrays c, Arrays d, Arrays e, Arrays f)
      => Unlift Acc (Acc a, Acc b, Acc c, Acc d, Acc e, Acc f) where
  unlift = unatup6

instance (Lift Acc a, Lift Acc b, Lift Acc c, Lift Acc d, Lift Acc e, Lift Acc f,
          Lift Acc g,
          Arrays (Plain a), Arrays (Plain b), Arrays (Plain c), Arrays (Plain d),
          Arrays (Plain e), Arrays (Plain f), Arrays (Plain g))
      => Lift Acc (a, b, c, d, e, f, g) where
  type Plain (a, b, c, d, e, f, g) = (Plain a, Plain b, Plain c, Plain d, Plain e, Plain f, Plain g)
  lift (x, y, z, u, v, w, r) = atup7 (lift x, lift y, lift z, lift u, lift v, lift w, lift r)

instance (Arrays a, Arrays b, Arrays c, Arrays d, Arrays e, Arrays f, Arrays g)
      => Unlift Acc (Acc a, Acc b, Acc c, Acc d, Acc e, Acc f, Acc g) where
  unlift = unatup7

instance (Lift Acc a, Lift Acc b, Lift Acc c, Lift Acc d, Lift Acc e, Lift Acc f,
          Lift Acc g, Lift Acc h,
          Arrays (Plain a), Arrays (Plain b), Arrays (Plain c), Arrays (Plain d),
          Arrays (Plain e), Arrays (Plain f), Arrays (Plain g), Arrays (Plain h))
      => Lift Acc (a, b, c, d, e, f, g, h) where
  type Plain (a, b, c, d, e, f, g, h)
    = (Plain a, Plain b, Plain c, Plain d, Plain e, Plain f, Plain g, Plain h)
  lift (x, y, z, u, v, w, r, s)
    = atup8 (lift x, lift y, lift z, lift u, lift v, lift w, lift r, lift s)

instance (Arrays a, Arrays b, Arrays c, Arrays d, Arrays e, Arrays f, Arrays g, Arrays h)
      => Unlift Acc (Acc a, Acc b, Acc c, Acc d, Acc e, Acc f, Acc g, Acc h) where
  unlift = unatup8

instance (Lift Acc a, Lift Acc b, Lift Acc c, Lift Acc d, Lift Acc e, Lift Acc f,
          Lift Acc g, Lift Acc h, Lift Acc i,
          Arrays (Plain a), Arrays (Plain b), Arrays (Plain c), Arrays (Plain d),
          Arrays (Plain e), Arrays (Plain f), Arrays (Plain g), Arrays (Plain h),
          Arrays (Plain i))
      => Lift Acc (a, b, c, d, e, f, g, h, i) where
  type Plain (a, b, c, d, e, f, g, h, i)
    = (Plain a, Plain b, Plain c, Plain d, Plain e, Plain f, Plain g, Plain h, Plain i)
  lift (x, y, z, u, v, w, r, s, t)
    = atup9 (lift x, lift y, lift z, lift u, lift v, lift w, lift r, lift s, lift t)

instance (Arrays a, Arrays b, Arrays c, Arrays d, Arrays e, Arrays f, Arrays g, Arrays h, Arrays i)
      => Unlift Acc (Acc a, Acc b, Acc c, Acc d, Acc e, Acc f, Acc g, Acc h, Acc i) where
  unlift = unatup9

-- |Lift a unary function into 'Exp'.
--
-- The argument is unlifted, @f@ is applied at the Haskell level, and the
-- result is lifted back into 'Exp'.
--
lift1 :: (Unlift Exp e1, Lift Exp e2)
      => (e1 -> e2)
      -> Exp (Plain e1)
      -> Exp (Plain e2)
lift1 f = lift . f . unlift

-- |Lift a binary function into 'Exp'.
--
-- Both arguments are unlifted, @f@ is applied at the Haskell level, and the
-- result is lifted back into 'Exp'.
--
lift2 :: (Unlift Exp e1, Unlift Exp e2, Lift Exp e3)
      => (e1 -> e2 -> e3)
      -> Exp (Plain e1)
      -> Exp (Plain e2)
      -> Exp (Plain e3)
lift2 f x y = lift $ f (unlift x) (unlift y)

-- |Lift a unary function to a computation over rank-1 indices.
--
-- @f@ is applied to the single component of the index.
--
ilift1 :: (Exp Int -> Exp Int) -> Exp DIM1 -> Exp DIM1
ilift1 f = lift1 (\(Z:.i) -> Z :. f i)

-- |Lift a binary function to a computation over rank-1 indices.
--
-- @f@ is applied to the single components of the two indices.
--
ilift2 :: (Exp Int -> Exp Int -> Exp Int) -> Exp DIM1 -> Exp DIM1 -> Exp DIM1
ilift2 f = lift2 (\(Z:.i) (Z:.j) -> Z :. f i j)


-- Tuples
-- ------

-- |Extract the first component of a scalar pair.
--
-- The annotation on the discarded component fixes its type, so the pair's
-- 'Unlift' instance can be resolved.
--
fst :: forall a b. (Elt a, Elt b) => Exp (a, b) -> Exp a
fst e = let (x, _::Exp b) = unlift e in x

-- |Extract the first component of an array pair.
-- (The annotation on the discarded component plays the same role as in 'fst'.)
afst :: forall a b. (Arrays a, Arrays b) => Acc (a, b) -> Acc a
afst a = let (x, _::Acc b) = unlift a in x

-- |Extract the second component of a scalar pair.
--
-- The annotation on the discarded component fixes its type, so the pair's
-- 'Unlift' instance can be resolved.
--
snd :: forall a b. (Elt a, Elt b) => Exp (a, b) -> Exp b
snd e = let (_:: Exp a, y) = unlift e in y

-- | Extract the second component of an array pair
-- (The annotation on the discarded component plays the same role as in 'snd'.)
asnd :: forall a b. (Arrays a, Arrays b) => Acc (a, b) -> Acc b
asnd a = let (_::Acc a, y) = unlift a in y

-- |Converts an uncurried function to a curried function.
--
-- The two arguments are lifted into a single pair before @f@ is applied.
--
curry :: Lift f (f a, f b) => (f (Plain (f a), Plain (f b)) -> f c) -> f a -> f b -> f c
curry f x y = f (lift (x, y))

-- |Converts a curried function to a function on pairs.
--
-- The pair is unlifted once, and its two components are passed to @f@.
--
uncurry :: Unlift f (f a, f b) => (f a -> f b -> f c) -> f (Plain (f a), Plain (f b)) -> f c
uncurry f t = f x y
  where
    (x, y) = unlift t


-- Shapes and indices
-- ------------------

-- | The index of the single element of a rank-0 array: @Z@ lifted into 'Exp'.
--
index0 :: Exp Z
index0 = lift Z

-- | Turn a scalar expression into a rank-1 index. The component is usually an
-- 'Int' expression, but any 'Elt' type is accepted.
--
index1 :: Elt i => Exp i -> Exp (Z :. i)
index1 = lift . (Z :.)

-- | Turn a rank-1 index back into the expression for its single component.
--
unindex1 :: Elt i => Exp (Z :. i) -> Exp i
unindex1 sh = n
  where
    Z :. n = unlift sh

-- | Build a rank-2 index from two component expressions (typically @Exp Int@).
--
index2 :: (Elt i, Slice (Z :. i)) => Exp i -> Exp i -> Exp (Z :. i :. i)
index2 x y = lift (Z :. x :. y)

-- | Split a rank-2 index into an 'Exp' pair of its two components.
--
unindex2 :: forall i. (Elt i, Slice (Z :. i)) => Exp (Z :. i :. i) -> Exp (i, i)
unindex2 sh = lift (x, y)
  where
    Z :. x :. y = unlift sh :: Z :. Exp i :. Exp i


-- Array operations with a scalar result
-- -------------------------------------

-- | Read the one element stored in a singleton (rank-0) array.
--
the :: Elt e => Acc (Scalar e) -> Exp e
the arr = arr ! index0

-- | Test whether an array is empty, i.e. whether its 'size' is zero.
--
null :: (Shape ix, Elt e) => Acc (Array ix e) -> Exp Bool
null = (==* 0) . size

-- | The number of elements in a vector: the single component of its shape.
--
length :: Elt e => Acc (Vector e) -> Exp Int
length arr = unindex1 (shape arr)