{-# OPTIONS -fno-warn-orphans #-}
{-# LANGUAGE CPP #-}

#include "fusion-phases.h"

module Data.Array.Parallel.Lifted.Scalar
where
import Data.Array.Parallel.Lifted.PArray
import Data.Array.Parallel.PArray.PReprInstances
import Data.Array.Parallel.PArray.PDataInstances
import qualified Data.Array.Parallel.Unlifted as U
import Data.Array.Parallel.Base (fromBool, toBool)
import GHC.Exts (Int(..))


-- Pretend Bools are scalars --------------------------------------------------
-- A flat array of Bools is wrapped in 'PBool' around a two-way selector built
-- from the tags that 'fromBool' produces; converting back reads the tags out
-- of the selector and maps them through 'toBool'.
instance Scalar Bool where
  {-# INLINE toScalarPData #-}
  toScalarPData flags
    = PBool $ U.tagsToSel2 $ U.map fromBool flags

  {-# INLINE fromScalarPData #-}
  fromScalarPData (PBool sel)
    = U.map toBool $ U.tagsSel2 sel


-- Projections ----------------------------------------------------------------
-- | Length of an array of scalars, returned as a boxed 'Int'.
prim_lengthPA :: Scalar a => PArray a -> Int
{-# INLINE prim_lengthPA #-}
prim_lengthPA arr = I# (lengthPA# arr)


-- Conversion -----------------------------------------------------------------
-- | Create a PArray out of a scalar U.Array,
--   the first argument is the array length.
--   The length is stored as given; it is not checked against the U.Array.
--
--   TODO: ditch this version, just use fromUArrPA'
--
fromUArrPA :: Scalar a => Int -> U.Array a -> PArray a
{-# INLINE fromUArrPA #-}
fromUArrPA (I# len#) elems
  = PArray len# $ toScalarPData elems


-- | Create a PArray out of a scalar U.Array,
--   reading the length directly from the U.Array.
fromUArrPA' :: Scalar a => U.Array a -> PArray a
{-# INLINE fromUArrPA' #-}
fromUArrPA' elems
  = fromUArrPA (U.length elems) elems


-- | Convert a PArray back to a plain U.Array.
--   The length stored in the PArray is ignored.
toUArrPA :: Scalar a => PArray a -> U.Array a
{-# INLINE toUArrPA #-}
toUArrPA (PArray _ payload)
  = fromScalarPData payload


-- Tuple Conversions ----------------------------------------------------------
-- | Convert a U.Array of pairs to a PArray.
--   The first argument is the array length; it is stored as given.
fromUArrPA_2 :: (Scalar a, Scalar b) => Int -> U.Array (a,b) -> PArray (a,b)
{-# INLINE fromUArrPA_2 #-}
fromUArrPA_2 (I# len#) pairs
  = let (lefts, rights) = U.unzip pairs
    in  PArray len# $ P_2 (toScalarPData lefts) (toScalarPData rights)


-- | Convert a U.Array of pairs to a PArray,
--   reading the length directly from the U.Array.
fromUArrPA_2' :: (Scalar a, Scalar b) => U.Array (a,b) -> PArray (a, b)
{-# INLINE fromUArrPA_2' #-}
fromUArrPA_2' pairs
  = fromUArrPA_2 (U.length pairs) pairs


-- | Convert a U.Array of left-nested pairs @((a,b),c)@ to a PArray of triples.
--   The first argument is the array length; it is stored as given.
fromUArrPA_3 :: (Scalar a, Scalar b, Scalar c)
             => Int -> U.Array ((a,b),c) -> PArray (a,b,c)
{-# INLINE fromUArrPA_3 #-}
fromUArrPA_3 (I# len#) nested
  = let -- Peel off the third components first, then split the inner pairs.
        (inner, thirds)  = U.unzip nested
        (firsts, seconds) = U.unzip inner
    in  PArray len#
          $ P_3 (toScalarPData firsts)
                (toScalarPData seconds)
                (toScalarPData thirds)


-- | Convert a U.Array of left-nested pairs @((a,b),c)@ to a PArray of triples,
--   reading the length directly from the U.Array.
fromUArrPA_3' :: (Scalar a, Scalar b, Scalar c)
              => U.Array ((a,b),c) -> PArray (a, b, c)
{-# INLINE fromUArrPA_3' #-}
fromUArrPA_3' nested
  = fromUArrPA_3 (U.length nested) nested


-- Nesting arrays -------------------------------------------------------------
-- | O(1). Create a nested array.
--   The length stored in the source array is discarded; only its element
--   data is reused.
nestUSegdPA
        :: Int                  -- ^ length to record for the resulting nested array
        -> U.Segd               -- ^ segment descriptor
        -> PArray a             -- ^ array of data elements
        -> PArray (PArray a)
{-# INLINE nestUSegdPA #-}
nestUSegdPA (I# len#) segd (PArray _ elems)
  = PArray len# $ PNested segd elems


-- | O(1). Create a nested array,
--   taking its length from the segment descriptor via 'U.lengthSegd'.
nestUSegdPA'
        :: U.Segd               -- ^ segment descriptor
        -> PArray a             -- ^ array of data elements
        -> PArray (PArray a)
{-# INLINE nestUSegdPA' #-}
nestUSegdPA' segd elems
  = nestUSegdPA (U.lengthSegd segd) segd elems


-- Scalar Operators -----------------------------------------------------------
-- These work on PArrays of scalar elements.
-- TODO: Why do we need these versions as well as the standard ones?

-- | Apply a worker function to every element of an array, yielding a new array.
scalar_map :: (Scalar a, Scalar b) => (a -> b) -> PArray a -> PArray b
{-# INLINE_PA scalar_map #-}
scalar_map f xs
  = fromUArrPA (prim_lengthPA xs)
  . U.map f
  $ toUArrPA xs


-- | Zip two arrays, yielding a new array.
--   The length of the result is taken from the first array.
scalar_zipWith
        :: (Scalar a, Scalar b, Scalar c)
        => (a -> b -> c) -> PArray a -> PArray b -> PArray c
{-# INLINE_PA scalar_zipWith #-}
scalar_zipWith f xs ys
  = fromUArrPA (prim_lengthPA xs)
  $ U.zipWith f (toUArrPA xs) (toUArrPA ys)


-- | Zip three arrays, yielding a new array.
--   The length of the result is taken from the first array.
scalar_zipWith3
        :: (Scalar a, Scalar b, Scalar c, Scalar d)
        => (a -> b -> c -> d) -> PArray a -> PArray b -> PArray c -> PArray d
{-# INLINE_PA scalar_zipWith3 #-}
scalar_zipWith3 f xs ys zs
  = fromUArrPA (prim_lengthPA xs)
  $ U.zipWith3 f (toUArrPA xs) (toUArrPA ys) (toUArrPA zs)


-- | Left fold over an array.
scalar_fold :: Scalar a => (a -> a -> a) -> a -> PArray a -> a
{-# INLINE_PA scalar_fold #-}
scalar_fold f z
  = U.fold f z . toUArrPA


-- | Left fold over an array, using the first element to initialise the state.
scalar_fold1 :: Scalar a => (a -> a -> a) -> PArray a -> a
{-# INLINE_PA scalar_fold1 #-}
scalar_fold1 f
  = U.fold1 f . toUArrPA


-- | Segmented fold of an array of arrays.
--   Each segment is folded individually, yielding an array of the fold results.
--
--   The result holds one value per segment, so its length is taken from the
--   segment descriptor and not from the flattened element data.
scalar_folds :: Scalar a => (a -> a -> a) -> a -> PArray (PArray a) -> PArray a
{-# INLINE_PA scalar_folds #-}
scalar_folds f z xss
  = fromUArrPA (U.lengthSegd segd)
  . U.fold_s f z segd
  . toUArrPA
  $ concatPA# xss
  where segd = segdPA# xss


-- | Segmented fold of an array of arrays, using the first element of each
--   segment to initialise the state for that segment.
--   Each segment is folded individually, yielding an array of all the fold results.
--
--   The result holds one value per segment, so its length is taken from the
--   segment descriptor and not from the flattened element data.
scalar_fold1s :: Scalar a => (a -> a -> a) -> PArray (PArray a) -> PArray a
{-# INLINE_PA scalar_fold1s #-}
scalar_fold1s f xss
  = fromUArrPA (U.lengthSegd segd)
  . U.fold1_s f segd
  . toUArrPA
  $ concatPA# xss
  where segd = segdPA# xss


-- | Left fold over an array, also passing the index of each element
--   to the parameter function.
--   Only the index component of the final (index, value) pair is returned.
scalar_fold1Index
        :: Scalar a
        => ((Int, a) -> (Int, a) -> (Int, a)) -> PArray a -> Int
{-# INLINE_PA scalar_fold1Index #-}
scalar_fold1Index f
  = fst . U.fold1 f . U.indexed . toUArrPA


-- | Segmented fold over an array, also passing the index of each
--   element to the parameter function.
--   Elements are paired with the indices produced by 'U.indices_s' for the
--   segment descriptor, and only the index component of each segment's
--   result is kept. The result reuses the outer array's length.
scalar_fold1sIndex
        :: Scalar a
        => ((Int, a) -> (Int, a) -> (Int, a))
        -> PArray (PArray a) -> PArray Int
{-# INLINE_PA scalar_fold1sIndex #-}
scalar_fold1sIndex f (PArray m# (PNested segd xs))
  = PArray m#
  $ toScalarPData
  $ U.fsts
  $ U.fold1_s f segd
  $ U.zip (U.indices_s segd)
  $ fromScalarPData xs