{-# LANGUAGE BangPatterns, ExplicitForAll, TypeOperators, MagicHash #-}
{-# OPTIONS -fno-warn-orphans #-}
module Data.Array.Repa.Operators.Reduction
        ( foldS,        foldP
        , foldAllS,     foldAllP
        , sumS,         sumP
        , sumAllS,      sumAllP
        , equalsS,      equalsP)
where
import Data.Array.Repa.Base
import Data.Array.Repa.Index
import Data.Array.Repa.Eval
import Data.Array.Repa.Repr.Unboxed
import Data.Array.Repa.Operators.Mapping        as R
import Data.Array.Repa.Shape                    as S
import qualified Data.Vector.Unboxed            as V
import qualified Data.Vector.Unboxed.Mutable    as M
import Prelude                                  hiding (sum)
import qualified Data.Array.Repa.Eval.Reduction as E
import System.IO.Unsafe
import GHC.Exts

-- fold ----------------------------------------------------------------------
-- | Sequential reduction of the innermost dimension of an arbitrary rank array.
--
--   Combine this with `transpose` to fold any other dimension.
--
--   Elements are reduced in the order of their indices, from lowest to highest.
--   Applications of the operator are associated arbitrarily.
--
--   >>> let c 0 x = x; c x 0 = x; c x y = y
--   >>> let a = fromListUnboxed (Z :. 2 :. 2) [1,2,3,4] :: Array U (Z :. Int :. Int) Int
--   >>> foldS c 0 a
--   AUnboxed (Z :. 2) (fromList [2,4])
--
foldS   :: (Shape sh, Source r a, Unbox a)
        => (a -> a -> a)                -- ^ reduction operator
        -> a                            -- ^ starting element
        -> Array r (sh :. Int) a        -- ^ array to reduce
        -> Array U sh a                 -- ^ one result per innermost row

foldS f z arr
 = arr `deepSeqArray`
   let  sh@(sz :. n') = extent arr
        -- Unbox the innermost extent: E.foldS takes its length as an Int#.
        !(I# n)       = n'
   in   -- mvec is allocated here, written only by E.foldS, and frozen before
        -- the result escapes, so callers never observe it being mutated;
        -- that is what makes the unsafePerformIO acceptable.
        -- NOTE(review): assumes E.foldS fills all (S.size sz) slots of mvec;
        -- its definition lives in Data.Array.Repa.Eval.Reduction.
        unsafePerformIO
         $ do mvec   <- M.unsafeNew (S.size sz)
              E.foldS mvec (\ix -> arr `unsafeIndex` fromIndex sh (I# ix)) f z n
              !vec   <- V.unsafeFreeze mvec
              now $ fromUnboxed sz vec
{-# INLINE [1] foldS #-}


-- | Parallel reduction of the innermost dimension of an arbitrary rank array.
--
--   The first argument needs to be an associative sequential operator.
--   The starting element must be neutral with respect to the operator, for
--   example @0@ is neutral with respect to @(+)@ as @0 + a = a@.
--   These restrictions are required to support parallel evaluation, as the
--   starting element may be used multiple times depending on the number of threads.
--
--   Elements are reduced in the order of their indices, from lowest to highest.
--   Applications of the operator are associated arbitrarily.
--
--   >>> let c 0 x = x; c x 0 = x; c x y = y
--   >>> let a = fromListUnboxed (Z :. 2 :. 2) [1,2,3,4] :: Array U (Z :. Int :. Int) Int
--   >>> foldP c 0 a
--   AUnboxed (Z :. 2) (fromList [2,4])
--
foldP   :: (Shape sh, Source r a, Unbox a, Monad m)
        => (a -> a -> a)                -- ^ associative reduction operator
        -> a                            -- ^ neutral starting element
        -> Array r (sh :. Int) a        -- ^ array to reduce
        -> m (Array U sh a)             -- ^ one result per innermost row

foldP f z arr
 = arr `deepSeqArray`
   let  sh@(sz :. n) = extent arr
   in   case rank sh of
           -- Specialise rank-1 arrays, else one thread does all the work.
           -- We dispatch on the runtime rank because matching on the shape
           -- constructor would require the type equality (sz ~ Z), which
           -- does not hold for a polymorphic sh.
           1 -> do
                x       <- foldAllP f z arr
                now $ fromUnboxed sz $ V.singleton x

           -- mvec is allocated here, written only by E.foldP, and frozen
           -- before the result escapes, so no mutation is observable to the
           -- caller; that is what makes the unsafePerformIO acceptable.
           _ -> now
              $ unsafePerformIO
              $ do mvec   <- M.unsafeNew (S.size sz)
                   E.foldP mvec (\ix -> arr `unsafeIndex` fromIndex sh ix) f z n
                   !vec   <- V.unsafeFreeze mvec
                   now $ fromUnboxed sz vec
{-# INLINE [1] foldP #-}


-- foldAll --------------------------------------------------------------------
-- | Sequential reduction of an array of arbitrary rank to a single scalar value.
--
--   Elements are reduced in row-major order. Applications of the operator are
--   associated arbitrarily.
--
foldAllS :: (Shape sh, Source r a)
        => (a -> a -> a)        -- ^ reduction operator
        -> a                    -- ^ starting element
        -> Array r sh a         -- ^ array to reduce
        -> a

foldAllS f z arr
 = arr `deepSeqArray`
   let  !ex     = extent arr
        -- Unbox the total element count: E.foldAllS takes it as an Int#.
        !(I# n) = size ex
   in   E.foldAllS
                (\ix -> arr `unsafeIndex` fromIndex ex (I# ix))
                f z n
{-# INLINE [1] foldAllS #-}


-- | Parallel reduction of an array of arbitrary rank to a single scalar value.
--
--   The first argument needs to be an associative sequential operator.
--   The starting element must be neutral with respect to the operator,
--   for example @0@ is neutral with respect to @(+)@ as @0 + a = a@.
--   These restrictions are required to support parallel evaluation, as the
--   starting element may be used multiple times depending on the number of threads.
--
--   Elements are reduced in row-major order. Applications of the operator are
--   associated arbitrarily.
--
foldAllP
        :: (Shape sh, Source r a, Unbox a, Monad m)
        => (a -> a -> a)        -- ^ associative reduction operator
        -> a                    -- ^ neutral starting element
        -> Array r sh a         -- ^ array to reduce
        -> m a

foldAllP f z arr
 = arr `deepSeqArray`
   let  sh = extent arr
        n  = size sh
   in   -- The result is wrapped with 'return', which does not force its
        -- argument in the usual monads (IO, Identity), so the parallel fold
        -- itself runs when the returned value is first demanded.
        return
         $ unsafePerformIO
         $ E.foldAllP (\ix -> arr `unsafeIndex` fromIndex sh ix) f z n
{-# INLINE [1] foldAllP #-}


-- sum ------------------------------------------------------------------------
-- | Sequential sum of the innermost dimension of an array.
sumS    :: (Shape sh, Source r a, Num a, Unbox a)
        => Array r (sh :. Int) a
        -> Array U sh a
sumS = foldS (+) 0
{-# INLINE [3] sumS #-}


-- | Parallel sum of the innermost dimension of an array.
sumP    :: (Shape sh, Source r a, Num a, Unbox a, Monad m)
        => Array r (sh :. Int) a
        -> m (Array U sh a)
sumP = foldP (+) 0
{-# INLINE [3] sumP #-}


-- sumAll ---------------------------------------------------------------------
-- | Sequential sum of all the elements of an array.
sumAllS :: (Shape sh, Source r a, Num a)
        => Array r sh a
        -> a
sumAllS = foldAllS (+) 0
{-# INLINE [3] sumAllS #-}


-- | Parallel sum of all the elements of an array.
sumAllP :: (Shape sh, Source r a, Unbox a, Num a, Monad m)
        => Array r sh a
        -> m a
sumAllP = foldAllP (+) 0
{-# INLINE [3] sumAllP #-}


-- Equality ------------------------------------------------------------------
-- Structural equality: two arrays are equal when their extents match and every
-- pair of corresponding elements is equal. Because (&&) is lazy in its second
-- argument, the element-wise comparison is only evaluated when the extents
-- agree. This is an orphan instance (see -fno-warn-orphans at the top of the
-- module).
instance (Shape sh, Eq sh, Source r a, Eq a) => Eq (Array r sh a) where
 (==) arr1 arr2
        =   extent arr1 == extent arr2
        && (foldAllS (&&) True (R.zipWith (==) arr1 arr2))


-- | Check whether two arrays have the same shape and contain equal elements,
--   in parallel.
--
--   The extents are compared first; the parallel element-wise fold is only
--   set up when they agree.
equalsP :: (Shape sh, Source r1 a, Source r2 a, Eq a, Monad m)
        => Array r1 sh a
        -> Array r2 sh a
        -> m Bool
equalsP arr1 arr2
 | extent arr1 /= extent arr2
 = return False

 | otherwise
 = foldAllP (&&) True (R.zipWith (==) arr1 arr2)


-- | Check whether two arrays have the same shape and contain equal elements,
--   sequentially.
--
--   Because (&&) is lazy in its second argument, elements are only compared
--   when the extents agree.
equalsS :: (Shape sh, Source r1 a, Source r2 a, Eq a)
        => Array r1 sh a
        -> Array r2 sh a
        -> Bool
equalsS arr1 arr2
        =   extent arr1 == extent arr2
        && (foldAllS (&&) True (R.zipWith (==) arr1 arr2))