{-# LANGUAGE TypeOperators, ScopedTypeVariables #-}
-- |
-- Module      : Data.Array.Accelerate.Prelude
-- Copyright   : [2010..2011] Manuel M T Chakravarty, Gabriele Keller, Ben Lever
--               [2009..2012] Manuel M T Chakravarty, Gabriele Keller, Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Manuel M T Chakravarty
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- Standard functions that are not part of the core set (directly represented in the AST), but are
-- instead implemented in terms of the core set.
--

module Data.Array.Accelerate.Prelude (

  -- * Zipping
  zipWith3, zipWith4, zip, zip3, zip4,

  -- * Unzipping
  unzip, unzip3, unzip4,

  -- * Reductions
  foldAll, fold1All,

  -- ** Specialised folds
  all, any, and, or, sum, product, minimum, maximum,

  -- * Scans
  prescanl, postscanl, prescanr, postscanr,

  -- ** Segmented scans
  scanlSeg, scanl'Seg, scanl1Seg, prescanlSeg, postscanlSeg,
  scanrSeg, scanr'Seg, scanr1Seg, prescanrSeg, postscanrSeg,

  -- * Shape manipulation
  flatten,

  -- * Enumeration and filling
  fill, enumFromN, enumFromStepN,

  -- * Working with predicates
  -- ** Filtering
  filter,

  -- ** Scatter / Gather
  scatter, scatterIf, gather, gatherIf,

  -- * Permutations
  reverse, transpose,

  -- * Extracting sub-vectors
  init, tail, take, drop, slit

) where

-- avoid clashes with Prelude functions
--
import Data.Bits
import Data.Bool
import Prelude                                  ((.), ($), (+), (-), (*), const, subtract, id)
import qualified Prelude                        as P

-- friends
import Data.Array.Accelerate.Array.Sugar        hiding ((!), ignore, shape, size)
import Data.Array.Accelerate.Language
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Type


-- Map-like composites
-- -------------------

-- | Combine three arrays element-wise with the given ternary function. The
-- arrays are first zipped into an array of triples ('zip3'), which is then
-- mapped over.
--
zipWith3 :: (Shape sh, Elt a, Elt b, Elt c, Elt d)
         => (Exp a -> Exp b -> Exp c -> Exp d)
         -> Acc (Array sh a)
         -> Acc (Array sh b)
         -> Acc (Array sh c)
         -> Acc (Array sh d)
zipWith3 f as bs cs = map applyTriple (zip3 as bs cs)
  where
    -- Split the packed triple into its components and hand them to 'f'.
    applyTriple x = let (a, b, c) = unlift x in f a b c
-- | Combine four arrays element-wise with the given function. The arrays are
-- first zipped into an array of quadruples ('zip4'), which is then mapped over.
--
zipWith4 :: (Shape sh, Elt a, Elt b, Elt c, Elt d, Elt e)
         => (Exp a -> Exp b -> Exp c -> Exp d -> Exp e)
         -> Acc (Array sh a)
         -> Acc (Array sh b)
         -> Acc (Array sh c)
         -> Acc (Array sh d)
         -> Acc (Array sh e)
zipWith4 f as bs cs ds = map applyQuad (zip4 as bs cs ds)
  where
    -- Split the packed quadruple into its components and hand them to 'f'.
    applyQuad x = let (a, b, c, d) = unlift x in f a b c d

-- | Combine the elements of two arrays pairwise. The shape of the result is
-- the intersection of the two argument shapes.
--
zip :: (Shape sh, Elt a, Elt b)
    => Acc (Array sh a)
    -> Acc (Array sh b)
    -> Acc (Array sh (a, b))
zip as bs = zipWith (\a b -> lift (a, b)) as bs

-- | Take three arrays and return an array of triples, analogous to 'zip'.
--
zip3 :: forall sh. forall a. forall b. forall c. (Shape sh, Elt a, Elt b, Elt c)
     => Acc (Array sh a)
     -> Acc (Array sh b)
     -> Acc (Array sh c)
     -> Acc (Array sh (a, b, c))
zip3 as bs cs = zipWith combine as (zip bs cs)
  where
    -- Flatten the nested pair (a, (b, c)) into a triple (a, b, c). The
    -- annotation on 'unlift' fixes the component types it must produce.
    combine :: Exp a -> Exp (b, c) -> Exp (a, b, c)
    combine a bc = let (b, c) = unlift bc :: (Exp b, Exp c)
                   in  lift (a, b, c)

-- | Take four arrays and return an array of quadruples, analogous to 'zip'.
--
zip4 :: forall sh. forall a. forall b. forall c. forall d. (Shape sh, Elt a, Elt b, Elt c, Elt d)
     => Acc (Array sh a)
     -> Acc (Array sh b)
     -> Acc (Array sh c)
     -> Acc (Array sh d)
     -> Acc (Array sh (a, b, c, d))
zip4 as bs cs ds = zipWith combine as (zip3 bs cs ds)
  where
    -- Flatten (a, (b, c, d)) into (a, b, c, d); the annotation on 'unlift'
    -- fixes the component types it must produce.
    combine :: Exp a -> Exp (b, c, d) -> Exp (a, b, c, d)
    combine a bcd = let (b, c, d) = unlift bcd :: (Exp b, Exp c, Exp d)
                    in  lift (a, b, c, d)

-- | The converse of 'zip'. Both results have the same shape as the argument.
--
unzip :: (Shape sh, Elt a, Elt b)
      => Acc (Array sh (a, b))
      -> (Acc (Array sh a), Acc (Array sh b))
unzip arr = (firsts, seconds)
  where
    firsts  = map fst arr
    seconds = map snd arr

-- | Take an array of triples and return three arrays, analogous to 'unzip'.
--
unzip3 :: (Shape sh, Elt a, Elt b, Elt c)
       => Acc (Array sh (a, b, c))
       -> (Acc (Array sh a), Acc (Array sh b), Acc (Array sh c))
unzip3 xs = (map sel1 xs, map sel2 xs, map sel3 xs)
  where
    -- Component selectors; the pattern annotations pin the types of the
    -- discarded components so that 'unlift' is unambiguous.
    sel1 :: forall a b c. (Elt a, Elt b, Elt c) => Exp (a, b, c) -> Exp a
    sel1 t = let (a, _ :: Exp b, _ :: Exp c) = unlift t in a

    sel2 :: forall a b c. (Elt a, Elt b, Elt c) => Exp (a, b, c) -> Exp b
    sel2 t = let (_ :: Exp a, b, _ :: Exp c) = unlift t in b

    sel3 :: forall a b c. (Elt a, Elt b, Elt c) => Exp (a, b, c) -> Exp c
    sel3 t = let (_ :: Exp a, _ :: Exp b, c) = unlift t in c

-- | Take an array of quadruples and return four arrays, analogous to 'unzip'.
--
unzip4 :: (Shape sh, Elt a, Elt b, Elt c, Elt d)
       => Acc (Array sh (a, b, c, d))
       -> (Acc (Array sh a), Acc (Array sh b), Acc (Array sh c), Acc (Array sh d))
unzip4 xs = (map sel1 xs, map sel2 xs, map sel3 xs, map sel4 xs)
  where
    -- Component selectors; the pattern annotations pin the types of the
    -- discarded components so that 'unlift' is unambiguous.
    sel1 :: forall a b c d. (Elt a, Elt b, Elt c, Elt d) => Exp (a, b, c, d) -> Exp a
    sel1 t = let (a, _ :: Exp b, _ :: Exp c, _ :: Exp d) = unlift t in a

    sel2 :: forall a b c d. (Elt a, Elt b, Elt c, Elt d) => Exp (a, b, c, d) -> Exp b
    sel2 t = let (_ :: Exp a, b, _ :: Exp c, _ :: Exp d) = unlift t in b

    sel3 :: forall a b c d. (Elt a, Elt b, Elt c, Elt d) => Exp (a, b, c, d) -> Exp c
    sel3 t = let (_ :: Exp a, _ :: Exp b, c, _ :: Exp d) = unlift t in c

    sel4 :: forall a b c d. (Elt a, Elt b, Elt c, Elt d) => Exp (a, b, c, d) -> Exp d
    sel4 t = let (_ :: Exp a, _ :: Exp b, _ :: Exp c, d) = unlift t in d


-- Reductions
-- ----------

-- | Reduce an array of arbitrary rank to a single scalar value. The array is
-- flattened to a vector first and then folded with the given seed.
--
foldAll :: (Shape sh, Elt a)
        => (Exp a -> Exp a -> Exp a)
        -> Exp a
        -> Acc (Array sh a)
        -> Acc (Scalar a)
foldAll f e = fold f e . flatten

-- | Variant of 'foldAll' that requires the reduced array to be non-empty and
-- does not need a default value.
--
fold1All :: (Shape sh, Elt a)
         => (Exp a -> Exp a -> Exp a)
         -> Acc (Array sh a)
         -> Acc (Scalar a)
fold1All f = fold1 f . flatten


-- Specialised reductions
-- ----------------------
--
-- Leave the results of these as scalar arrays to make it clear that these are
-- array computations, and thus can not be nested.
-- | Check whether every element satisfies the predicate.
--
all :: (Shape sh, Elt e)
    => (Exp e -> Exp Bool)
    -> Acc (Array sh e)
    -> Acc (Scalar Bool)
all p xs = and (map p xs)

-- | Check whether at least one element satisfies the predicate.
--
any :: (Shape sh, Elt e)
    => (Exp e -> Exp Bool)
    -> Acc (Array sh e)
    -> Acc (Scalar Bool)
any p xs = or (map p xs)

-- | Check whether every element is 'True'. An empty array yields 'True'.
--
and :: Shape sh
    => Acc (Array sh Bool)
    -> Acc (Scalar Bool)
and xs = foldAll (&&*) (constant True) xs

-- | Check whether any element is 'True'. An empty array yields 'False'.
--
or :: Shape sh
   => Acc (Array sh Bool)
   -> Acc (Scalar Bool)
or xs = foldAll (||*) (constant False) xs

-- | Compute the sum of the elements (0 for an empty array).
--
sum :: (Shape sh, Elt e, IsNum e)
    => Acc (Array sh e)
    -> Acc (Scalar e)
sum xs = foldAll (+) 0 xs

-- | Compute the product of the elements (1 for an empty array).
--
product :: (Shape sh, Elt e, IsNum e)
        => Acc (Array sh e)
        -> Acc (Scalar e)
product xs = foldAll (*) 1 xs

-- | Yield the minimum element of an array. The array must not be empty.
--
minimum :: (Shape sh, Elt e, IsScalar e)
        => Acc (Array sh e)
        -> Acc (Scalar e)
minimum xs = fold1All min xs

-- | Yield the maximum element of an array. The array must not be empty.
--
maximum :: (Shape sh, Elt e, IsScalar e)
        => Acc (Array sh e)
        -> Acc (Scalar e)
maximum xs = fold1All max xs


-- Composite scans
-- ---------------

-- | Left-to-right prescan (aka exclusive scan). As for 'scan', the first argument must be an
-- /associative/ function. Denotationally, we have
--
-- > prescanl f e = Prelude.fst . scanl' f e
--
prescanl :: Elt a
         => (Exp a -> Exp a -> Exp a)
         -> Exp a
         -> Acc (Vector a)
         -> Acc (Vector a)
prescanl f e xs = P.fst (scanl' f e xs)

-- | Left-to-right postscan, a variant of 'scanl1' with an initial value. The
-- initial value is combined on the /left/ of every scanned element.
-- Denotationally, we have
--
-- > postscanl f e = map (e `f`) . scanl1 f
--
postscanl :: Elt a
          => (Exp a -> Exp a -> Exp a)
          -> Exp a
          -> Acc (Vector a)
          -> Acc (Vector a)
postscanl f e xs = map (f e) (scanl1 f xs)

-- | Right-to-left prescan (aka exclusive scan). As for 'scan', the first argument must be an
-- /associative/ function.
-- Denotationally, we have
--
-- > prescanr f e = Prelude.fst . scanr' f e
--
prescanr :: Elt a
         => (Exp a -> Exp a -> Exp a)
         -> Exp a
         -> Acc (Vector a)
         -> Acc (Vector a)
prescanr f e = P.fst . scanr' f e

-- | Right-to-left postscan, a variant of 'scanr1' with an initial value. Since
-- the scan runs right-to-left, the initial value is combined on the /right/ of
-- every scanned element. Denotationally, we have
--
-- > postscanr f e = map (`f` e) . scanr1 f
--
postscanr :: Elt a
          => (Exp a -> Exp a -> Exp a)
          -> Exp a
          -> Acc (Vector a)
          -> Acc (Vector a)
postscanr f e = map (`f` e) . scanr1 f


-- Segmented scans
-- ---------------

-- | Segmented version of 'scanl'.
--
scanlSeg :: (Elt a, Elt i, IsIntegral i)
         => (Exp a -> Exp a -> Exp a)
         -> Exp a
         -> Acc (Vector a)
         -> Acc (Segments i)
         -> Acc (Vector a)
scanlSeg f z vec seg = scanl1Seg f vec' seg'
  where
    -- Segmented exclusive scan is implemented by first injecting the seed
    -- element at the head of each segment, and then performing a segmented
    -- inclusive scan.
    --
    -- This is done by creating a vector entirely of the seed element, and
    -- overlaying the input data in all places other than at the start of a
    -- segment.
    --
    seg'        = map (+1) seg
    vec'        = permute const
                          (fill (index1 $ size vec + size seg) z)
                          (\ix -> index1' $ unindex1' ix + inc ! ix)
                          vec

    -- Each element in the segments must be shifted to the right one additional
    -- place for each successive segment, to make room for the seed element.
    -- Here, we make use of the fact that the vector returned by 'mkHeadFlags'
    -- contains non-unit entries, which indicate zero length segments.
    --
    flags       = mkHeadFlags seg
    inc         = scanl1 (+) flags


-- | Segmented version of 'scanl''.
--
-- The first element of the resulting tuple is a vector of scanned values. The
-- second element is a vector of segment scan totals and has the same size as
-- the segment vector.
--
scanl'Seg :: forall a i. (Elt a, Elt i, IsIntegral i)
          => (Exp a -> Exp a -> Exp a)
          -> Exp a
          -> Acc (Vector a)
          -> Acc (Segments i)
          -> Acc (Vector a, Vector a)
scanl'Seg f z vec seg = result
  where
    -- Return the result combined, so that the sub-calculations are shared
    -- should the user require both results.
    --
    result      = lift (body, sums)

    -- Segmented scan' is implemented by deconstructing a segmented exclusive
    -- scan, to separate the final value and scan body.
    --
    -- TLM: Segmented scans, and this version in particular, expend a lot of
    --      effort scanning flag arrays. On inspection it appears that several
    --      of these operations are duplicated, but this will not be picked up
    --      by sharing _observation_. Perhaps a global CSE-style pass would be
    --      beneficial.
    --
    vec'        = scanlSeg f z vec seg

    -- Extract the final reduction value for each segment, which is at the last
    -- index of each segment.
    --
    seg'        = map (+1) seg
    tails       = zipWith (+) seg . P.fst $ scanl' (+) 0 seg'
    sums        = backpermute (shape seg) (\ix -> index1' $ tails ! ix) vec'

    -- Slice out the body of each segment.
    --
    -- Build a head-flags representation based on the original segment
    -- descriptor. This contains the target length of each of the body segments,
    -- which is one fewer element than the actual bodies stored in vec'. Thus,
    -- the flags align with the last element of each body section, and when
    -- scanned, this element will be incremented over.
    --
    offset      = scanl1 (+) seg
    inc         = scanl1 (+)
                $ permute (+) (fill (index1 $ size vec + 1) 0)
                              (\ix -> index1' $ offset ! ix)
                              (fill (shape seg) (1 :: Exp i))

    body        = backpermute (shape vec) (\ix -> index1' $ unindex1' ix + inc ! ix) vec'


-- | Segmented version of 'scanl1'. The head flags of each segment stop the
-- combining function from carrying values across segment boundaries.
--
scanl1Seg :: (Elt a, Elt i, IsIntegral i)
          => (Exp a -> Exp a -> Exp a)
          -> Acc (Vector a)
          -> Acc (Segments i)
          -> Acc (Vector a)
scanl1Seg f vec seg
  = P.snd
  . unzip
  . scanl1 (segmented f)
  $ zip (mkHeadFlags seg) vec

-- | Segmented version of 'prescanl'.
--
prescanlSeg :: (Elt a, Elt i, IsIntegral i)
            => (Exp a -> Exp a -> Exp a)
            -> Exp a
            -> Acc (Vector a)
            -> Acc (Segments i)
            -> Acc (Vector a)
prescanlSeg f e vec seg = P.fst . unatup2 $ scanl'Seg f e vec seg

-- | Segmented version of 'postscanl'. The initial value is combined on the
-- /left/ of every scanned element.
--
postscanlSeg :: (Elt a, Elt i, IsIntegral i)
             => (Exp a -> Exp a -> Exp a)
             -> Exp a
             -> Acc (Vector a)
             -> Acc (Segments i)
             -> Acc (Vector a)
postscanlSeg f e vec seg = map (f e) $ scanl1Seg f vec seg

-- | Segmented version of 'scanr'.
--
scanrSeg :: (Elt a, Elt i, IsIntegral i)
         => (Exp a -> Exp a -> Exp a)
         -> Exp a
         -> Acc (Vector a)
         -> Acc (Segments i)
         -> Acc (Vector a)
scanrSeg f z vec seg = scanr1Seg f vec' seg'
  where
    -- Using the technique described for 'scanlSeg': start from a vector filled
    -- with the seed, widen every segment by one element, and overlay the input
    -- data. Unlike the left scan, each element is shifted one place less
    -- (the '- 1' below), so the slot left holding the seed is the /last/ one
    -- of each widened segment. An inclusive segmented right scan then
    -- combines the seed into every position of its segment.
    --
    inc         = scanl1 (+) (mkHeadFlags seg)
    seg'        = map (+1) seg
    vec'        = permute const
                          (fill (index1 $ size vec + size seg) z)
                          (\ix -> index1' $ unindex1' ix + inc ! ix - 1)
                          vec


-- | Segmented version of 'scanr''.
--
-- The first element of the resulting tuple is a vector of scanned values. The
-- second element is a vector of segment scan totals and has the same size as
-- the segment vector.
--
scanr'Seg :: forall a i. (Elt a, Elt i, IsIntegral i)
          => (Exp a -> Exp a -> Exp a)
          -> Exp a
          -> Acc (Vector a)
          -> Acc (Segments i)
          -> Acc (Vector a, Vector a)
scanr'Seg f z vec seg = result
  where
    -- Using the technique described for 'scanl'Seg'.
    --
    result      = lift (body, sums)
    vec'        = scanrSeg f z vec seg

    -- Reduction values: for a right scan the total of each segment sits at the
    -- first index of that segment in vec'.
    --
    seg'        = map (+1) seg
    heads       = P.fst $ scanl' (+) 0 seg'
    sums        = backpermute (shape seg) (\ix -> index1' $ heads ! ix) vec'

    -- Body segments: skip the leading total of each widened segment.
    --
    inc         = scanl1 (+) $ mkHeadFlags seg
    body        = backpermute (shape vec) (\ix -> index1' $ unindex1' ix + inc ! ix) vec'


-- | Segmented version of 'scanr1'. The tail flags of each segment stop the
-- combining function from carrying values across segment boundaries.
--
scanr1Seg :: (Elt a, Elt i, IsIntegral i)
          => (Exp a -> Exp a -> Exp a)
          -> Acc (Vector a)
          -> Acc (Segments i)
          -> Acc (Vector a)
scanr1Seg f vec seg
  = P.snd
  . unzip
  . scanr1 (segmented f)
  $ zip (mkTailFlags seg) vec

-- | Segmented version of 'prescanr'.
--
prescanrSeg :: (Elt a, Elt i, IsIntegral i)
            => (Exp a -> Exp a -> Exp a)
            -> Exp a
            -> Acc (Vector a)
            -> Acc (Segments i)
            -> Acc (Vector a)
prescanrSeg f e vec seg = P.fst . unatup2 $ scanr'Seg f e vec seg

-- | Segmented version of 'postscanr'.
--
-- As for 'postscanr', the scan runs right-to-left, so the initial value is
-- combined on the /right/ of every scanned element. This matters whenever the
-- combining function is not commutative.
--
postscanrSeg :: (Elt a, Elt i, IsIntegral i)
             => (Exp a -> Exp a -> Exp a)
             -> Exp a
             -> Acc (Vector a)
             -> Acc (Segments i)
             -> Acc (Vector a)
postscanrSeg f e vec seg = map (`f` e) $ scanr1Seg f vec seg


-- Segmented scan helpers
-- ----------------------

-- | Compute head flags vector from segment vector for left-scans.
--
-- The vector will be full of zeros in the body of a segment, and non-zero
-- otherwise. The "flag" value, if greater than one, indicates that several
-- empty segments are represented by this single flag entry. This additional
-- data is used by the exclusive segmented scan.
--
mkHeadFlags :: (Elt i, IsIntegral i) => Acc (Segments i) -> Acc (Segments i)
mkHeadFlags seg
  = init
  $ permute (+) zeros (\ix -> index1' (offset ! ix)) ones
  where
    -- offset!k is the start index of segment k. Trailing empty segments start
    -- at index len, which lands in the extra slot that 'init' drops.
    (offset, len)       = scanl' (+) 0 seg
    zeros               = fill (index1' $ the len + 1) 0
    ones                = fill (index1  $ size offset) 1

-- | Compute tail flags vector from segment vector for right-scans. That is, the
-- flag is placed at the last place in each segment.
--
mkTailFlags :: (Elt i, IsIntegral i) => Acc (Segments i) -> Acc (Segments i)
mkTailFlags seg
  = tail
  $ permute (+) zeros (\ix -> index1' (the len - offset ! ix)) ones
  where
    -- offset!k is the number of elements after segment k, so segment k ends at
    -- index (len - 1 - offset!k). Each flag is written one place further right
    -- into an array of length len+1 and shifted back with 'tail'. A leading
    -- empty segment (whose offset equals len) would otherwise target index -1,
    -- which is out of bounds. With the shift it lands in slot 0 and is dropped.
    (offset, len)       = scanr' (+) 0 seg
    zeros               = fill (index1' $ the len + 1) 0
    ones                = fill (index1  $ size offset) 1

-- | Construct a segmented version of a function from a non-segmented version.
-- The segmented apply operates on a head-flag value tuple, and follows the
-- procedure of Sengupta et al.
--
segmented :: (Elt e, Elt i, IsIntegral i)
          => (Exp e -> Exp e -> Exp e)
          -> Exp (i, e) -> Exp (i, e) -> Exp (i, e)
segmented f a b =
  let (aF, aV) = unlift a
      (bF, bV) = unlift b
  in
  lift (aF .|. bF, bF /=* 0 ? (bV, f aV bV))

-- | Index construction and destruction generalised to integral types.
--
-- We generalise the segment descriptor to integral types because some
-- architectures, such as GPUs, have poor performance for 64-bit types. So,
-- there is a tension between performance and requiring 64-bit indices for some
-- applications, and we would not like to restrict ourselves to either one.
--
-- As we don't yet support non-Int dimensions in shapes, we will need to convert
-- back to concrete Int. However, don't put these generalised forms into the
-- base library, because it results in too many ambiguity errors.
--
index1' :: (Elt i, IsIntegral i) => Exp i -> Exp DIM1
index1' i = lift (Z :. fromIntegral i)

unindex1' :: (Elt i, IsIntegral i) => Exp DIM1 -> Exp i
unindex1' ix = let Z :. i = unlift ix in fromIntegral i


-- Reshaping of arrays
-- -------------------

-- | Flattens a given array of arbitrary dimension.
--
flatten :: (Shape ix, Elt a) => Acc (Array ix a) -> Acc (Vector a)
flatten a = reshape (index1 $ size a) a


-- Enumeration and filling
-- -----------------------

-- | Create an array where all elements are the same value.
--
fill :: (Shape sh, Elt e) => Exp sh -> Exp e -> Acc (Array sh e)
fill sh c = generate sh (const c)

-- | Create an array of the given shape containing the values x, x+1, etc (in
-- row-major order).
--
enumFromN :: (Shape sh, Elt e, IsNum e) => Exp sh -> Exp e -> Acc (Array sh e)
enumFromN sh x = enumFromStepN sh x 1

-- | Create an array of the given shape containing the values @x@, @x+y@,
-- @x+y+y@ etc. (in row-major order).
--
enumFromStepN :: (Shape sh, Elt e, IsNum e)
              => Exp sh
              -> Exp e    -- ^ x: start
              -> Exp e    -- ^ y: step
              -> Acc (Array sh e)
enumFromStepN sh x y = reshape sh (generate (index1 (shapeSize sh)) valueAt)
  where
    -- Generate a flat vector of @shapeSize sh@ elements, where the element at
    -- linear index i is i*y + x; 'reshape' then gives it the requested shape.
    valueAt ix = (fromIntegral (unindex1 ix :: Exp Int) * y) + x


-- Filtering
-- ---------

-- | Keep only the elements that satisfy the predicate, preserving their order.
--
filter :: Elt a
       => (Exp a -> Exp Bool)
       -> Acc (Vector a)
       -> Acc (Vector a)
filter p arr = permute const target destination arr
  where
    -- 1 for each element to keep, 0 for each element to drop
    flags            = map (boolToInt . p) arr
    -- exclusive prefix sum: the output index of every kept element, plus the
    -- total number of kept elements
    (targetIdx, len) = scanl' (+) 0 flags
    -- a vector of the output length to permute into; every slot is overwritten
    target           = backpermute (index1 $ the len) id arr
    destination ix   = flags ! ix ==* 0 ? (ignore, index1 $ targetIdx ! ix)

  -- FIXME: This is abusing 'permute' in that the first two arguments are
  --        only justified because we know the permutation function will
  --        write to each location in the target exactly once.
  --        Instead, we should have a primitive that directly encodes the
  --        compaction pattern of the permutation function.

{-# RULES
  "ACC filter/filter" forall f g arr.
    filter f (filter g arr) = filter (\x -> g x &&* f x) arr
 #-}


-- Gather operations
-- -----------------

-- | Copy elements from source array to destination array according to a map. This
-- is a backpermute operation where a 'map' vector encodes the output to input
-- index mapping. The result has the same length as the map.
--
-- For example:
--
-- >  input  = [1, 9, 6, 4, 4, 2, 0, 1, 2]
-- >  map    = [1, 3, 7, 2, 5, 3]
-- >
-- >  output = [9, 4, 1, 6, 2, 4]
--
gather :: (Elt e)
       => Acc (Vector Int)      -- ^map
       -> Acc (Vector e)        -- ^input
       -> Acc (Vector e)        -- ^output
gather mapV inputV = backpermute (shape mapV) (\ix -> index1 (mapV ! ix)) inputV

-- | Conditionally copy elements from source array to destination array according
-- to a map. This is a backpermute operation where a 'map' vector encodes the
-- output to input index mapping. In addition, there is a 'mask' vector, and an
-- associated predication function, that specifies whether an element will be
-- copied. If not copied, the output array assumes the default vector's value.
--
-- For example:
--
-- >  default = [6, 6, 6, 6, 6, 6]
-- >  map     = [1, 3, 7, 2, 5, 3]
-- >  mask    = [3, 4, 9, 2, 7, 5]
-- >  pred    = (> 4)
-- >  input   = [1, 9, 6, 4, 4, 2, 0, 1, 2]
-- >
-- >  output  = [6, 6, 1, 6, 2, 4]
--
gatherIf :: (Elt e, Elt e')
         => Acc (Vector Int)    -- ^map
         -> Acc (Vector e)      -- ^mask
         -> (Exp e -> Exp Bool) -- ^predicate
         -> Acc (Vector e')     -- ^default
         -> Acc (Vector e')     -- ^input
         -> Acc (Vector e')     -- ^output
gatherIf mapV maskV pred defaultV inputV = zipWith zwF predV gatheredV
  where
    -- pick the gathered value where the predicate holds, else the default
    zwF p g   = p ? (unlift g)
    gatheredV = zip (gather mapV inputV) defaultV
    predV     = map pred maskV


-- Scatter operations
-- ------------------

-- | Copy elements from source array to destination array according to a map. This
-- is a forward-permute operation where a 'map' vector encodes an input to output
-- index mapping. Output elements for indices that are not mapped assume the
-- default vector's value.
--
-- For example:
--
-- >  default = [0, 0, 0, 0, 0, 0, 0, 0, 0]
-- >  map     = [1, 3, 7, 2, 5, 8]
-- >  input   = [1, 9, 6, 4, 4, 2, 5]
-- >
-- >  output  = [0, 1, 4, 9, 0, 4, 0, 6, 2]
--
-- Note that if the same index appears in the map more than once, the result is
-- undefined. If the map and input vectors differ in length, only the first
-- @min (length map) (length input)@ input elements are scattered (as in the
-- example above, where the last input element has no map entry and is ignored).
--
scatter :: (Elt e)
        => Acc (Vector Int)      -- ^map
        -> Acc (Vector e)        -- ^default
        -> Acc (Vector e)        -- ^input
        -> Acc (Vector e)        -- ^output
scatter mapV defaultV inputV = permute (const) defaultV pF inputV'
  where
    pF ix   = lift (Z :. (mapV ! ix))
    -- 'permute' visits every index of the vector being permuted and 'pF' reads
    -- 'mapV' at that index, so restrict the input to the prefix that actually
    -- has a map entry; otherwise a shorter map is read out of bounds.
    inputV' = backpermute (index1 (min (size mapV) (size inputV))) id inputV

-- | Conditionally copy elements from source array to destination array according
-- to a map. This is a forward-permute operation where a 'map' vector encodes an
-- input to output index mapping. In addition, there is a 'mask' vector, and an
-- associated predicate function, that specifies whether an element will be
-- copied. If not copied, the output array assumes the default vector's value.
--
-- For example:
--
-- >  default = [0, 0, 0, 0, 0, 0, 0, 0, 0]
-- >  map     = [1, 3, 7, 2, 5, 8]
-- >  mask    = [3, 4, 9, 2, 7, 5]
-- >  pred    = (> 4)
-- >  input   = [1, 9, 6, 4, 4, 2]
-- >
-- >  output  = [0, 0, 0, 0, 0, 4, 0, 6, 2]
--
-- Note that if the same index appears in the map more than once, the result is
-- undefined. The map, mask and input vectors must be of the same length.
--
scatterIf :: (Elt e, Elt e')
          => Acc (Vector Int)      -- ^map
          -> Acc (Vector e)        -- ^mask
          -> (Exp e -> Exp Bool)   -- ^predicate
          -> Acc (Vector e')       -- ^default
          -> Acc (Vector e')       -- ^input
          -> Acc (Vector e')       -- ^output
scatterIf mapV maskV pred defaultV inputV = permute const defaultV pF inputV
  where
    -- elements whose mask fails the predicate are not written ('ignore')
    pF ix = (pred (maskV ! ix)) ? (lift (Z :. (mapV ! ix)), ignore)


-- Permutations
-- ------------

-- | Reverse the elements of a vector.
--
reverse :: Elt e => Acc (Vector e) -> Acc (Vector e)
reverse xs =
  let len  = unindex1 (shape xs)
      pf i = len - i - 1
  in  backpermute (shape xs) (ilift1 pf) xs

-- | Transpose the rows and columns of a matrix.
--
transpose :: Elt e => Acc (Array DIM2 e) -> Acc (Array DIM2 e)
transpose mat =
  let swap = lift1 $ \(Z:.x:.y) -> Z:.y:.x :: Z:.Exp Int:.Exp Int
  in
  backpermute (swap $ shape mat) swap mat


-- Extracting sub-vectors
-- ----------------------

-- | Yield the first @n@ elements of the input vector. The vector must contain
-- no fewer than @n@ elements.
--
take :: Elt e => Exp Int -> Acc (Vector e) -> Acc (Vector e)
take n =
  -- NOTE(review): routing n through 'unit'/'the' presumably makes it a shared
  -- scalar array rather than an expression inlined into the index function;
  -- confirm before simplifying.
  let n' = the (unit n)
  in  backpermute (index1 n') id

-- | Yield all but the first @n@ elements of the input vector. The vector must
-- contain no fewer than @n@ elements.
--
drop :: Elt e => Exp Int -> Acc (Vector e) -> Acc (Vector e)
drop n arr =
  let n' = the (unit n)
  in  backpermute (ilift1 (subtract n') (shape arr)) (ilift1 (+ n')) arr

-- | Yield all but the last element of the input vector. The vector must not be
-- empty.
--
init :: Elt e => Acc (Vector e) -> Acc (Vector e)
init arr = take (len - 1) arr
  where
    len = unindex1 (shape arr)

-- | Yield all but the first element of the input vector. The vector must not be
-- empty.
--
tail :: Elt e => Acc (Vector e) -> Acc (Vector e)
tail arr = backpermute shorter shiftRight arr
  where
    -- one element shorter than the input ...
    shorter       = ilift1 (subtract 1) (shape arr)
    -- ... where output index i reads input index i+1
    shiftRight ix = ilift1 (+1) ix

-- | Yield a slit (slice) from the vector. The vector must contain at least
-- @i + n@ elements. Denotationally, we have:
--
-- > slit i n = take n . drop i
--
slit :: Elt e => Exp Int -> Exp Int -> Acc (Vector e) -> Acc (Vector e)
slit i n = backpermute (index1 n') (ilift1 (+ i'))
  where
    i' = the (unit i)
    n' = the (unit n)