{-# LANGUAGE BangPatterns, ExplicitForAll, TypeOperators, MagicHash #-}
{-# OPTIONS -fno-warn-orphans #-}
module Data.Array.Repa.Operators.Reduction
( foldS, foldP
, foldAllS, foldAllP
, sumS, sumP
, sumAllS, sumAllP
, equalsS, equalsP)
where
import Data.Array.Repa.Base
import Data.Array.Repa.Index
import Data.Array.Repa.Eval
import Data.Array.Repa.Repr.Unboxed
import Data.Array.Repa.Operators.Mapping as R
import Data.Array.Repa.Shape as S
import qualified Data.Vector.Unboxed as V
import qualified Data.Vector.Unboxed.Mutable as M
import Prelude hiding (sum)
import qualified Data.Array.Repa.Eval.Reduction as E
import System.IO.Unsafe
import GHC.Exts
-- fold ----------------------------------------------------------------------
-- | Sequential reduction of the innermost dimension of an array of arbitrary
--   rank.
--
--   Compose with 'transpose' to reduce along any other dimension.
--
--   Within each row, elements are combined in index order, lowest first, but
--   the grouping (association) of the operator applications is unspecified.
--
-- >>> let c 0 x = x; c x 0 = x; c x y = y
-- >>> let a = fromListUnboxed (Z :. 2 :. 2) [1,2,3,4] :: Array U (Z :. Int :. Int) Int
-- >>> foldS c 0 a
-- AUnboxed (Z :. 2) (fromList [2,4])
--
foldS   :: (Shape sh, Source r a, Unbox a)
        => (a -> a -> a)          -- ^ combining operator
        -> a                      -- ^ starting element
        -> Array r (sh :. Int) a  -- ^ source array
        -> Array U sh a
foldS f z arr
 = arr `deepSeqArray`
   let  shape@(outer :. inner) = extent arr
        -- The low-level worker counts in unboxed ints.
        !(I# len)              = inner
        -- Element of the source array at linear (row-major) index ix.
        get ix                 = arr `unsafeIndex` fromIndex shape (I# ix)
   in   unsafePerformIO $ do
          -- One output slot per row; the worker fills every slot.
          out     <- M.unsafeNew (S.size outer)
          E.foldS out get f z len
          !frozen <- V.unsafeFreeze out
          now (fromUnboxed outer frozen)
{-# INLINE [1] foldS #-}
-- | Parallel reduction of the innermost dimension of an array of arbitrary
--   rank.
--
--   The operator must be associative, and the starting element must be
--   neutral with respect to it; for example @0@ is neutral with respect to
--   @(+)@ because @0 + a = a@. Parallel evaluation needs both restrictions,
--   because the starting element may be used several times depending on the
--   number of threads.
--
--   Within each row, elements are combined in index order, lowest first, but
--   the grouping (association) of the operator applications is unspecified.
--
-- >>> let c 0 x = x; c x 0 = x; c x y = y
-- >>> let a = fromListUnboxed (Z :. 2 :. 2) [1,2,3,4] :: Array U (Z :. Int :. Int) Int
-- >>> foldP c 0 a
-- AUnboxed (Z :. 2) (fromList [2,4])
--
foldP   :: (Shape sh, Source r a, Unbox a, Monad m)
        => (a -> a -> a)          -- ^ associative combining operator
        -> a                      -- ^ neutral starting element
        -> Array r (sh :. Int) a  -- ^ source array
        -> m (Array U sh a)
foldP f z arr
 = arr `deepSeqArray`
   let  shape@(outer :. inner) = extent arr
   in   if rank shape == 1
         -- A rank-1 array has exactly one row, so a row-parallel fold would
         -- leave all the work to a single thread; reduce the whole array in
         -- parallel instead. The test uses 'rank' because matching the shape
         -- against 'Z' would require the type equality (sz ~ Z).
         then foldAllP f z arr >>= now . fromUnboxed outer . V.singleton
         else now $ unsafePerformIO $ do
                -- One output slot per row; the worker fills every slot.
                out     <- M.unsafeNew (S.size outer)
                E.foldP out (\ix -> arr `unsafeIndex` fromIndex shape ix) f z inner
                !frozen <- V.unsafeFreeze out
                now (fromUnboxed outer frozen)
{-# INLINE [1] foldP #-}
-- foldAll --------------------------------------------------------------------
-- | Sequential reduction of an array of arbitrary rank to a single scalar
--   value.
--
--   Elements are visited in row-major order; the grouping (association) of
--   the operator applications is unspecified.
--
foldAllS :: (Shape sh, Source r a)
         => (a -> a -> a)   -- ^ combining operator
         -> a               -- ^ starting element
         -> Array r sh a    -- ^ source array
         -> a
foldAllS f z arr
 = arr `deepSeqArray`
   let  !shape    = extent arr
        -- Total element count, unboxed for the low-level worker.
        !(I# len) = size shape
        -- Element at linear (row-major) index ix.
        get ix    = arr `unsafeIndex` fromIndex shape (I# ix)
   in   E.foldAllS get f z len
{-# INLINE [1] foldAllS #-}
-- | Parallel reduction of an array of arbitrary rank to a single scalar value.
--
--   The operator must be associative, and the starting element must be
--   neutral with respect to it; for example @0@ is neutral with respect to
--   @(+)@ because @0 + a = a@. Parallel evaluation needs both restrictions,
--   because the starting element may be used several times depending on the
--   number of threads.
--
--   Elements are visited in row-major order; the grouping (association) of
--   the operator applications is unspecified.
--
foldAllP
        :: (Shape sh, Source r a, Unbox a, Monad m)
        => (a -> a -> a)   -- ^ associative combining operator
        -> a               -- ^ neutral starting element
        -> Array r sh a    -- ^ source array
        -> m a
foldAllP f z arr
 = arr `deepSeqArray` return (unsafePerformIO (E.foldAllP get f z total))
 where  shape  = extent arr
        -- Number of elements to reduce.
        total  = size shape
        -- Element at linear (row-major) index ix.
        get ix = arr `unsafeIndex` fromIndex shape ix
{-# INLINE [1] foldAllP #-}
-- sum ------------------------------------------------------------------------
-- | Sequential sum of the innermost dimension of an array.
--
--   Defined as @'foldS' (+) 0@: each element of the result is the sum of one
--   innermost row of the input.
sumS    :: (Shape sh, Source r a, Num a, Unbox a)
        => Array r (sh :. Int) a
        -> Array U sh a
sumS = foldS (+) 0
{-# INLINE [3] sumS #-}
-- | Parallel sum of the innermost dimension of an array.
--
--   Defined as @'foldP' (+) 0@. @0@ is neutral for @(+)@, which satisfies the
--   starting-element requirement of 'foldP'.
sumP    :: (Shape sh, Source r a, Num a, Unbox a, Monad m)
        => Array r (sh :. Int) a
        -> m (Array U sh a)
sumP = foldP (+) 0
{-# INLINE [3] sumP #-}
-- sumAll ---------------------------------------------------------------------
-- | Sequential sum of all the elements of an array.
--
--   Defined as @'foldAllS' (+) 0@.
sumAllS :: (Shape sh, Source r a, Num a)
        => Array r sh a
        -> a
sumAllS = foldAllS (+) 0
{-# INLINE [3] sumAllS #-}
-- | Parallel sum of all the elements of an array.
--
--   Defined as @'foldAllP' (+) 0@. @0@ is neutral for @(+)@, which satisfies
--   the starting-element requirement of 'foldAllP'.
sumAllP :: (Shape sh, Source r a, Unbox a, Num a, Monad m)
        => Array r sh a
        -> m a
sumAllP = foldAllP (+) 0
{-# INLINE [3] sumAllP #-}
-- Equality ------------------------------------------------------------------
-- | Two arrays are equal when they have the same extent and equal elements at
--   every index. The comparison is sequential; use 'equalsP' for a parallel
--   check.
instance (Shape sh, Eq sh, Source r a, Eq a) => Eq (Array r sh a) where
        -- Delegate to 'equalsS' rather than duplicating its definition, so
        -- the two can never drift apart.
        (==) = equalsS
        {-# INLINE (==) #-}
-- | Check whether two arrays have the same shape and contain equal elements,
--   in parallel.
--
--   If the extents differ the result is 'False', and no element-wise
--   comparison is set up.
equalsP :: (Shape sh, Source r1 a, Source r2 a, Eq a, Monad m)
        => Array r1 sh a
        -> Array r2 sh a
        -> m Bool
equalsP arr1 arr2
        -- Compare shapes explicitly first; only arrays of equal extent need
        -- the element-wise parallel fold.
        | extent arr1 /= extent arr2 = return False
        | otherwise                  = foldAllP (&&) True (R.zipWith (==) arr1 arr2)
-- Inline like the other reductions in this module, so callers can
-- specialise the element type and fuse the zipWith into the fold.
{-# INLINE [3] equalsP #-}
-- | Check whether two arrays have the same shape and contain equal elements,
--   sequentially.
--
--   The extents are compared first; because '(&&)' is lazy in its second
--   argument, the element-wise fold is only evaluated when they match.
equalsS :: (Shape sh, Source r1 a, Source r2 a, Eq a)
        => Array r1 sh a
        -> Array r2 sh a
        -> Bool
equalsS arr1 arr2
        =  extent arr1 == extent arr2
        && foldAllS (&&) True (R.zipWith (==) arr1 arr2)
-- Inline like the other reductions in this module, so callers can
-- specialise the element type and fuse the zipWith into the fold.
{-# INLINE [3] equalsS #-}