{-# LANGUAGE BangPatterns, ExplicitForAll, TypeOperators, MagicHash #-}
{-# OPTIONS -fno-warn-orphans #-}
module Data.Array.Repa.Operators.Reduction
        ( foldS,        foldP
        , foldAllS,     foldAllP
        , sumS,         sumP
        , sumAllS,      sumAllP
        , equalsS,      equalsP)
where
import Data.Array.Repa.Base
import Data.Array.Repa.Index
import Data.Array.Repa.Eval
import Data.Array.Repa.Repr.Unboxed
import Data.Array.Repa.Operators.Mapping        as R
import Data.Array.Repa.Shape                    as S
import qualified Data.Vector.Unboxed            as V
import qualified Data.Vector.Unboxed.Mutable    as M
import Prelude                                  hiding (sum)
import qualified Data.Array.Repa.Eval.Reduction as E
import System.IO.Unsafe
import GHC.Exts

-- fold ----------------------------------------------------------------------
-- | Reduce the innermost dimension of an array of any rank, sequentially.
--
--   To reduce along a different dimension, `transpose` the array first.
--
--   Elements are combined in order of increasing index.
--   Applications of the operator are associated arbitrarily.
--
--   >>> let c 0 x = x; c x 0 = x; c x y = y
--   >>> let a = fromListUnboxed (Z :. 2 :. 2) [1,2,3,4] :: Array U (Z :. Int :. Int) Int
--   >>> foldS c 0 a
--   AUnboxed (Z :. 2) (fromList [2,4])
--
foldS   :: (Shape sh, Source r a, Unbox a)
        => (a -> a -> a)
        -> a
        -> Array r (sh :. Int) a
        -> Array U sh a

foldS f z arr
 = arr `deepSeqArray`
   case extent arr of
    -- Matching here forces the extent and unboxes the inner length
    -- before any work is started.
    full@(outer :. I# len)
     -> let -- Look up the source element at a flat index.
            get ix  = arr `unsafeIndex` fromIndex full (I# ix)
        in  -- The buffer is allocated inside this action and is only
            -- handed out after it has been frozen.
            unsafePerformIO
             $ do buf     <- M.unsafeNew (S.size outer)
                  E.foldS buf get f z len
                  !frozen <- V.unsafeFreeze buf
                  now (fromUnboxed outer frozen)
{-# INLINE [1] foldS #-}


-- | Reduce the innermost dimension of an array of any rank, in parallel.
--
--   The operator must be associative, and the starting element must be
--   neutral with respect to it: for example @0@ is neutral with respect
--   to @(+)@ as @0 + a = a@. Parallel evaluation relies on both
--   properties, because the starting element may be used multiple times
--   depending on the number of threads.
--
--   Elements are combined in order of increasing index.
--   Applications of the operator are associated arbitrarily.
--
--   >>> let c 0 x = x; c x 0 = x; c x y = y
--   >>> let a = fromListUnboxed (Z :. 2 :. 2) [1,2,3,4] :: Array U (Z :. Int :. Int) Int
--   >>> foldP c 0 a
--   AUnboxed (Z :. 2) (fromList [2,4])
--
foldP   :: (Shape sh, Source r a, Unbox a, Monad m)
        => (a -> a -> a)
        -> a
        -> Array r (sh :. Int) a
        -> m (Array U sh a)

foldP f z arr
 = arr `deepSeqArray` reduced
 where
        sh@(sz :. n)    = extent arr

        reduced
         -- Rank-1 arrays get their own path, otherwise one thread would
         -- do all the work. The rank is tested numerically because
         -- matching on the shape constructor is a type error (sz ~ Z).
         | rank sh == 1
         = foldAllP f z arr >>= \x ->
           now (fromUnboxed sz (V.singleton x))

         -- Higher ranks: reduce into a buffer that is allocated inside
         -- the action and only handed out after it has been frozen.
         | otherwise
         = now
         $ unsafePerformIO
         $ do   buf     <- M.unsafeNew (S.size sz)
                E.foldP buf (\ix -> arr `unsafeIndex` fromIndex sh ix) f z n
                !frozen <- V.unsafeFreeze buf
                now (fromUnboxed sz frozen)
{-# INLINE [1] foldP #-}


-- foldAll --------------------------------------------------------------------
-- | Reduce an array of any rank to a single scalar value, sequentially.
--
--   Elements are visited in row-major order. Applications of the operator
--   are associated arbitrarily.
--
foldAllS :: (Shape sh, Source r a)
        => (a -> a -> a)
        -> a
        -> Array r sh a
        -> a

foldAllS f z arr
 = arr `deepSeqArray`
   case extent arr of
    -- Force the extent first, then its size, before the fold starts.
    !ex -> case size ex of
            I# n -> E.foldAllS
                        (\ix -> arr `unsafeIndex` fromIndex ex (I# ix))
                        f z n
{-# INLINE [1] foldAllS #-}


-- | Reduce an array of any rank to a single scalar value, in parallel.
--
--   The operator must be associative, and the starting element must be
--   neutral with respect to it: for example @0@ is neutral with respect
--   to @(+)@ as @0 + a = a@. Parallel evaluation relies on both
--   properties, because the starting element may be used multiple times
--   depending on the number of threads.
--
--   Elements are visited in row-major order. Applications of the operator
--   are associated arbitrarily.
--
foldAllP
        :: (Shape sh, Source r a, Unbox a, Monad m)
        => (a -> a -> a)
        -> a
        -> Array r sh a
        -> m a

foldAllP f z arr
 = arr `deepSeqArray` return total
 where
        sh      = extent arr

        -- Handed to `return` unevaluated, exactly as a thunk.
        total   = unsafePerformIO
                $ E.foldAllP
                        (\ix -> arr `unsafeIndex` fromIndex sh ix)
                        f z (size sh)
{-# INLINE [1] foldAllP #-}


-- sum ------------------------------------------------------------------------
-- | Sequential sum of the innermost dimension of an array.
--
--   This is @foldS@ with @(+)@ as the operator and @0@ as the starting
--   element, so the result has one dimension fewer than the argument.
sumS    :: (Shape sh, Source r a, Num a, Unbox a)
        => Array r (sh :. Int) a
        -> Array U sh a
sumS = foldS (+) 0
-- Inlining is enabled from phase 3, which comes before the phase 1 used
-- for foldS (GHC phases count down) -- presumably so the fold is exposed
-- at the call site first.
{-# INLINE [3] sumS #-}


-- | Parallel sum of the innermost dimension of an array.
--
--   This is @foldP@ with @(+)@ as the operator and @0@ as the starting
--   element, so the result has one dimension fewer than the argument and
--   is delivered in the caller's monad.
sumP    :: (Shape sh, Source r a, Num a, Unbox a, Monad m)
        => Array r (sh :. Int) a
        -> m (Array U sh a)
sumP = foldP (+) 0 
-- Inlining is enabled from phase 3, which comes before the phase 1 used
-- for foldP (GHC phases count down) -- presumably so the fold is exposed
-- at the call site first.
{-# INLINE [3] sumP #-}


-- sumAll ---------------------------------------------------------------------
-- | Sequential sum of all the elements of an array.
--
--   This is @foldAllS@ with @(+)@ as the operator and @0@ as the starting
--   element.
sumAllS :: (Shape sh, Source r a, Num a)
        => Array r sh a
        -> a
sumAllS = foldAllS (+) 0
-- Inlining is enabled from phase 3, which comes before the phase 1 used
-- for foldAllS (GHC phases count down) -- presumably so the fold is
-- exposed at the call site first.
{-# INLINE [3] sumAllS #-}


-- | Parallel sum of all the elements of an array.
--
--   This is @foldAllP@ with @(+)@ as the operator and @0@ as the starting
--   element; the result is delivered in the caller's monad.
sumAllP :: (Shape sh, Source r a, Unbox a, Num a, Monad m)
        => Array r sh a
        -> m a
sumAllP = foldAllP (+) 0
-- Inlining is enabled from phase 3, which comes before the phase 1 used
-- for foldAllP (GHC phases count down) -- presumably so the fold is
-- exposed at the call site first.
{-# INLINE [3] sumAllP #-}


-- Equality ------------------------------------------------------------------
-- | Two arrays are equal when their extents are equal and every pair of
--   corresponding elements is equal. The elements are compared
--   sequentially, and only if the extents already match.
instance (Shape sh, Eq sh, Source r a, Eq a) => Eq (Array r sh a) where
 arr1 == arr2
  | extent arr1 == extent arr2
  = foldAllS (&&) True (R.zipWith (==) arr1 arr2)

  | otherwise
  = False


-- | Check, in parallel, whether two arrays have the same shape and
--   contain equal elements.
--
--   The two arrays may use different representations. The shape test is
--   the left operand of the final @(&&)@, so the element-wise result is
--   not demanded when the shapes differ.
equalsP :: (Shape sh, Source r1 a, Source r2 a, Eq a, Monad m) 
        => Array r1 sh a 
        -> Array r2 sh a
        -> m Bool
equalsP arr1 arr2
 = do   let sameShape   = extent arr1 == extent arr2
        sameElems       <- foldAllP (&&) True (R.zipWith (==) arr1 arr2)
        return (sameShape && sameElems)


-- | Check, sequentially, whether two arrays have the same shape and
--   contain equal elements.
--
--   The two arrays may use different representations. The elements are
--   only compared when the shapes are equal.
equalsS :: (Shape sh, Source r1 a, Source r2 a, Eq a) 
        => Array r1 sh a 
        -> Array r2 sh a
        -> Bool
equalsS arr1 arr2
 = sameShape && sameElems
 where
        sameShape       = extent arr1 == extent arr2
        sameElems       = foldAllS (&&) True (R.zipWith (==) arr1 arr2)