{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{- |
List-like functions on the next-to-innermost dimension.

Array arguments have shape @sh :. n :. m@.
The @n@ axis plays the role of the list
(each index along it is one \"element\", a slice of shape @sh :. m@),
while the outer shape @sh@ and the innermost extent @m@ are carried along.
-}
module Data.Array.Accelerate.Utility.Sliced1 where

import qualified Data.Array.Accelerate.Utility.Lift.Exp as Exp
import Data.Array.Accelerate.Utility.Lift.Exp (expr)

import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate
          (Exp, Acc, Array, Elt, (:.)((:.)), Slice, Shape, (!), (?), )


{- |
Extent of the next-to-innermost dimension,
i.e. @n@ for an array of shape @sh :. n :. m@.
-}
length ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array (sh:.Int:.Int) a) -> Exp Int
length = A.indexHead . A.indexTail . A.shape

{- |
The slice at index 0 of the next-to-innermost dimension.
Like 'Prelude.head', the dimension is not checked for emptiness.
-}
head ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array (sh:.Int:.Int) a) -> Acc (Array (sh:.Int) a)
head xs = A.slice xs (A.constant $ A.Any :. (0::Int) :. A.All)

{- |
Remove the first slice of the next-to-innermost dimension.
The dimension is assumed to be non-empty;
otherwise the resulting extent becomes negative.
-}
tail ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array (sh:.Int:.Int) a) -> Acc (Array (sh:.Int:.Int) a)
tail xs =
   A.backpermute
      (Exp.modify (expr:.expr:.expr)
         (\(sh :. n :. m) -> sh :. n-1 :. m)
         (A.shape xs))
      (Exp.modify (expr:.expr:.expr) $ \(ix:.k:.j) -> ix :. k+1 :. j)
      xs

{- |
Prepend the slice @x@ in front of the next-to-innermost dimension of @xs@.
The result shape is @sh :. n+1 :. m@, derived from the shape of @xs@.
@x@ is read at index @ix :. j@, so it should have shape @sh :. m@.
-}
cons ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array (sh:.Int) a) ->
   Acc (Array (sh:.Int:.Int) a) ->
   Acc (Array (sh:.Int:.Int) a)
cons x xs =
   A.generate
      (Exp.modify (expr:.expr:.expr)
         (\(sh :. n :. m) -> sh :. n+1 :. m)
         (A.shape xs))
      (Exp.modify (expr:.expr:.expr) $ \(ix:.k:.j) ->
         k A.== 0
            ? (x ! A.lift (ix:.j),
               -- shift back by one to skip the prepended slice
               xs ! A.lift (ix :. k-1 :. j)))

{- |
Concatenate two arrays along the next-to-innermost dimension.

The outer and innermost dimensions must match.
Otherwise you may or may not get out-of-bound errors:
the result takes its outer shape from @x@
and the larger of the two innermost extents.
-}
append ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array (sh:.Int:.Int) a) ->
   Acc (Array (sh:.Int:.Int) a) ->
   Acc (Array (sh:.Int:.Int) a)
append x y =
   let (shx :. nx :. lenx) = Exp.unlift (expr:.expr:.expr) $ A.shape x
       (_shy :. ny :. leny) = Exp.unlift (expr:.expr:.expr) $ A.shape y
   in  A.generate (A.lift $ shx :. nx+ny :. max lenx leny) $
       Exp.modify (expr:.expr:.expr) $ \(ix:.k:.j) ->
          nx A.> k
             ? (x ! A.lift (ix:.k:.j),
                y ! A.lift (ix :. k-nx :. j))

{- |
Concatenate three arrays along the next-to-innermost dimension.
The result shape @sh :. 3*n :. m@ is computed from the shape of @x@ alone,
so @y@ and @z@ are expected to have the same shape as @x@.
-}
append3 ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array (sh:.Int:.Int) a) ->
   Acc (Array (sh:.Int:.Int) a) ->
   Acc (Array (sh:.Int:.Int) a) ->
   Acc (Array (sh:.Int:.Int) a)
append3 x y z =
   let (sh :. n :. m) = Exp.unlift (expr :. expr :. expr) $ A.shape x
   in  A.reshape (A.lift $ sh :. 3*n :. m) $ stack3 x y z

{- |
Stack three arrays along a new dimension of extent 3,
inserted in front of the next-to-innermost dimension.
The result shape is @sh :. 3 :. n :. m@, where @sh :. n :. m@ is taken from @x@.
@y@ and @z@ are read at the same @sh :. n :. m@ indices as @x@,
so they should have the same shape as @x@.
-}
stack3 ::
   (Shape sh, Slice sh, Elt a) =>
   Acc (Array (sh:.Int:.Int) a) ->
   Acc (Array (sh:.Int:.Int) a) ->
   Acc (Array (sh:.Int:.Int) a) ->
   Acc (Array (sh:.Int:.Int:.Int) a)
stack3 x y z =
   A.generate
      (Exp.modify (expr :. expr :. expr)
         (\(sh :. n :. m) -> sh :. (3::Int) :. n :. m)
         (A.shape x))
      (Exp.modify (expr :. expr :. expr :. expr) $
       \(globalIx :. k :. j :. i) ->
          -- k selects the source array; any k other than 1 or 2
          -- (i.e. k == 0) falls through to the default, x
          let ix = A.lift $ globalIx :. j :. i
          in  flip (A.caseof k) (x ! ix) $
                 ((A.== 1), y ! ix) :
                 ((A.== 2), z ! ix) :
                 [])

{- |
@take c@ keeps the first @c@ slices of the next-to-innermost dimension;
@drop c@ removes them.

Like their "Data.List" counterparts, the count is clamped
to the range from 0 to the extent of that dimension,
so an out-of-range count never causes out-of-bound accesses
or negative extents.
-}
take, drop ::
   (Shape sh, Slice sh, Elt a) =>
   Exp Int -> Acc (Array (sh:.Int:.Int) a) -> Acc (Array (sh:.Int:.Int) a)
take n arr =
   A.backpermute
      (Exp.modify (expr:.expr:.expr)
         (\(sh :. len :. m) -> sh :. max 0 (min n len) :. m)
         (A.shape arr))
      id arr

drop d arr =
   let (sh :. n :. m) = Exp.unlift (expr:.expr:.expr) $ A.shape arr
       -- clamp like 'Data.List.drop': negative drops nothing,
       -- too large drops everything
       d' = max 0 (min d n)
   in  A.backpermute
          (A.lift $ sh :. n-d' :. m)
          (Exp.modify (expr:.expr:.expr) $ \(ix:.k:.j) -> ix :. k+d' :. j)
          arr

{- |
@sieve fac start@ keeps the slices at indices
@start, start+fac, start+2*fac, ...@ of the next-to-innermost dimension.

The result extent is @div n fac@.
A trailing group of fewer than @fac@ slices is therefore ignored.
All accessed indices are in bounds as long as @0 <= start < fac@.
-}
sieve ::
   (Shape sh, Slice sh, Elt a) =>
   Exp Int -> Exp Int ->
   Acc (Array (sh:.Int:.Int) a) -> Acc (Array (sh:.Int:.Int) a)
sieve fac start arr =
   let (sh :. n :. m) = Exp.unlift (expr:.expr:.expr) $ A.shape arr
   in  A.backpermute
          (A.lift $ sh :. div n fac :. m)
          (Exp.modify (expr :. expr :. expr) $ \(ix :. k :. j) ->
             ix :. fac*k + start :. j)
          arr