{-# LANGUAGE BangPatterns, MagicHash #-}
module Data.Array.Repa.Eval.Reduction
        ( foldS,    foldP
        , foldAllS, foldAllP)
where
import Data.Array.Repa.Eval.Gang
import qualified Data.Vector.Unboxed            as V
import qualified Data.Vector.Unboxed.Mutable    as M
import GHC.Base                                 ( quotInt, divInt )
import GHC.Exts


-- | Sequential reduction of a multidimensional array along the innermost
--   dimension.
--
--   Slot @k@ of the output vector receives the fold of the source indices
--   @[k * n, (k + 1) * n)@, seeded with the starting value. One slot is
--   written for every element of the output vector.
foldS :: V.Unbox a
      => M.IOVector a   -- ^ vector to write elements into
      -> (Int# -> a)    -- ^ function to get an element from the given index
      -> (a -> a -> a)  -- ^ binary associative combination function
      -> a              -- ^ starting value (typically an identity)
      -> Int#           -- ^ inner dimension (length to fold over)
      -> IO ()
{-# INLINE [1] foldS #-}
foldS vec get c !r !n
  = go 0# 0#
  where
    !(I# slots) = M.length vec

    -- 'slot' is the output index being written and 'lo' is the first source
    -- index of its segment; every segment is 'n' elements long.
    {-# INLINE go #-}
    go !slot !lo
      | isTrue# (slot >=# slots) = return ()
      | otherwise                = do
          let !hi = lo +# n
          M.unsafeWrite vec (I# slot) (reduceAny get c r lo hi)
          go (slot +# 1#) hi


-- | Parallel reduction of a multidimensional array along the innermost
--   dimension.
--
--   Each output value is computed by a single thread, with the output values
--   distributed evenly amongst the available threads: thread @t@ fills the
--   contiguous output slots @[t * step, (t + 1) * step)@, clamped to the
--   length of the output vector, where @step@ is the vector length divided
--   by the gang size, rounded up.
foldP :: V.Unbox a
      => M.IOVector a   -- ^ vector to write elements into
      -> (Int -> a)     -- ^ function to get an element from the given index
      -> (a -> a -> a)  -- ^ binary associative combination operator
      -> a              -- ^ starting value. Must be neutral with respect
                        --   to the operator, e.g. @0 + a = a@, because it
                        --   seeds the fold of every output slot.
      -> Int            -- ^ inner dimension (length to fold over)
      -> IO ()
{-# INLINE [1] foldP #-}
foldP vec f c !r (I# n)
  = gangIO theGang
  $ \(I# tid) -> fill (split tid) (split (tid +# 1#))
  where
    !(I# threads) = gangSize theGang
    !(I# len)     = M.length vec

    -- Output slots per thread, rounded up so the gang covers every slot.
    !step         = (len +# threads -# 1#) `quotInt#` threads

    -- First output slot owned by thread 'ix', i.e. min len (ix * step).
    {-# INLINE split #-}
    split !ix
     = let !ix' = ix *# step
       in  case len <# ix' of
             0# -> ix'
             _  -> len

    -- Write output slots [start, end). Slot 'sh' folds the source indices
    -- [sh * n, (sh + 1) * n); 'sz' tracks the segment start incrementally.
    {-# INLINE fill #-}
    fill !start !end
     = iter start (start *# n)
     where
        {-# INLINE iter #-}
        iter !sh !sz
         | 1# <- sh >=# end
         = return ()

         | otherwise
         = do   let !next = sz +# n
                M.unsafeWrite vec (I# sh) (reduce f c r (I# sz) (I# next))
                iter (sh +# 1#) next


-- | Sequential reduction of all the elements in an array.
--
--   Combines the elements at indices @0@ up to (but not including) the
--   element count, from left to right, starting from the given seed. A
--   non-positive count yields the seed unchanged.
foldAllS :: (Int# -> a)         -- ^ function to get an element from the given index
         -> (a -> a -> a)       -- ^ binary associative combining function
         -> a                   -- ^ starting value
         -> Int#                -- ^ number of elements
         -> a
{-# INLINE [1] foldAllS #-}
foldAllS get combine !seed !count
 = reduceAny get combine seed 0# count



-- | Parallel tree reduction of an array to a single value. Each thread takes an
--   equally sized chunk of the data and computes a partial sum. The main thread
--   then reduces the array of partial sums to the final result.
--
--   We don't require that the initial value be a neutral element, so each thread
--   computes a fold1 on its chunk of the data, and the seed element is only
--   applied in the final reduction step.
--
--   A non-positive element count yields the starting value unchanged, the
--   same result 'foldAllS' gives for such a count.
foldAllP :: V.Unbox a
         => (Int -> a)          -- ^ function to get an element from the given index
         -> (a -> a -> a)       -- ^ binary associative combining function
         -> a                   -- ^ starting value
         -> Int                 -- ^ number of elements
         -> IO a
{-# INLINE [1] foldAllP #-}
foldAllP f c !r !len
  -- A negative count used to fall through to the parallel path, where
  -- 'step' could be zero (so 'divInt' divided by zero) or a partial-result
  -- slot could be read without ever being written. Treat it as empty.
  | len <= 0    = return r
  | otherwise   = do
      mvec <- M.unsafeNew chunks
      gangIO theGang $ \tid -> fill mvec tid (split tid) (split (tid + 1))
      vec  <- V.unsafeFreeze mvec
      return $! V.foldl' c r vec
  where
    !threads    = gangSize theGang

    -- Elements per chunk, rounded up so that 'threads' chunks cover 'len'.
    !step       = (len + threads - 1) `quotInt` threads

    -- Number of non-empty chunks, never more than the number of threads.
    chunks      = ((len + step - 1) `divInt` step) `min` threads

    -- First element index of chunk 'ix', clamped to 'len'.
    {-# INLINE split #-}
    split !ix   = len `min` (ix * step)

    -- Fold1 the chunk [start, end) into slot 'tid'; threads whose chunk is
    -- empty write nothing.
    {-# INLINE fill #-}
    fill !mvec !tid !start !end
      | start >= end = return ()
      | otherwise    = M.unsafeWrite mvec tid (reduce f c (f start) (start + 1) end)



-- Reduce ---------------------------------------------------------------------
-- | This is the primitive reduction function.
--   We use manual specialisations and rewrite rules to avoid the result
--   being boxed up in the final iteration.
--
--   Folds the elements at indices @[start, end)@ from left to right, seeded
--   with the starting value; the ending index is exclusive.
{-# INLINE [0] reduce #-}
reduce  :: (Int -> a)           -- ^ Get data from the array.
        -> (a -> a -> a)        -- ^ Function to combine elements.
        -> a                    -- ^ Starting value.
        -> Int                  -- ^ Starting index in array.
        -> Int                  -- ^ Ending index in array.
        -> a                    -- ^ Result.
reduce get combine !seed from to
 = case from of
     I# lo -> case to of
       I# hi -> reduceAny (\ix -> get (I# ix)) combine seed lo hi


-- | Sequentially reduce values between the given indices.
--
--   A strict left fold over the half-open range @[start, end)@: the result
--   is the seed combined, in order, with the element at each index. When
--   @start >= end@ the seed is returned unchanged.
{-# INLINE [0] reduceAny #-}
reduceAny :: (Int# -> a) -> (a -> a -> a) -> a -> Int# -> Int# -> a
reduceAny get combine !seed !from !to
 = go from seed
 where
   {-# INLINE go #-}
   go !ix !acc
     | isTrue# (ix >=# to) = acc
     | otherwise           = go (ix +# 1#) (acc `combine` get ix)


-- | 'Int#' specialisation of 'reduceAny': a strict left fold of the element
--   function over @[start, end)@ with the combining function, seeded with the
--   starting value. Every argument and the result are unboxed.
{-# INLINE [0] reduceInt #-}
reduceInt
        :: (Int# -> Int#)
        -> (Int# -> Int# -> Int#)
        -> Int#
        -> Int# -> Int#
        -> Int#
reduceInt get combine !seed !from !to
 = go from seed
 where
   {-# INLINE go #-}
   go :: Int# -> Int# -> Int#
   go !ix !acc
     | isTrue# (ix >=# to) = acc
     | otherwise           = go (ix +# 1#) (acc `combine` get ix)


-- | 'Float#' specialisation of 'reduceAny': a strict left fold of the element
--   function over @[start, end)@ with the combining function, seeded with the
--   starting value. Every argument and the result are unboxed.
{-# INLINE [0] reduceFloat #-}
reduceFloat
        :: (Int# -> Float#)
        -> (Float# -> Float# -> Float#)
        -> Float#
        -> Int# -> Int#
        -> Float#
reduceFloat get combine !seed !from !to
 = go from seed
 where
   {-# INLINE go #-}
   go :: Int# -> Float# -> Float#
   go !ix !acc
     | isTrue# (ix >=# to) = acc
     | otherwise           = go (ix +# 1#) (acc `combine` get ix)


-- | 'Double#' specialisation of 'reduceAny': a strict left fold of the
--   element function over @[start, end)@ with the combining function, seeded
--   with the starting value. Every argument and the result are unboxed.
{-# INLINE [0] reduceDouble #-}
reduceDouble
        :: (Int# -> Double#)
        -> (Double# -> Double# -> Double#)
        -> Double#
        -> Int# -> Int#
        -> Double#
reduceDouble get combine !seed !from !to
 = go from seed
 where
   {-# INLINE go #-}
   go :: Int# -> Double# -> Double#
   go !ix !acc
     | isTrue# (ix >=# to) = acc
     | otherwise           = go (ix +# 1#) (acc `combine` get ix)


-- | Extract the raw 'Int#' payload of a boxed 'Int'.
{-# INLINE unboxInt #-}
unboxInt :: Int -> Int#
unboxInt boxed = case boxed of I# raw -> raw


-- | Extract the raw 'Float#' payload of a boxed 'Float'.
{-# INLINE unboxFloat #-}
unboxFloat :: Float -> Float#
unboxFloat boxed = case boxed of F# raw -> raw


-- | Extract the raw 'Double#' payload of a boxed 'Double'.
{-# INLINE unboxDouble #-}
unboxDouble :: Double -> Double#
unboxDouble boxed = case boxed of D# raw -> raw


-- Rewrite rules replacing calls to 'reduceAny' at element type 'Int',
-- 'Float' or 'Double' with the hand-specialised loops 'reduceInt',
-- 'reduceFloat' and 'reduceDouble'. Those loops work on unboxed values, so
-- the result is boxed (I# / F# / D#) only once, after the loop finishes.
-- 'reduceAny' is marked INLINE [0], so it is not inlined before the final
-- simplifier phase; presumably that keeps its call sites intact long enough
-- for these rules to match.
--
-- Int: the element function and seed are unboxed with 'unboxInt', and the
-- combining function is wrapped to re-box its arguments with 'I#'.
{-# RULES "reduceInt" 
    forall (get :: Int# -> Int) f r start end
    . reduceAny get f r start end 
    = I# (reduceInt 
                (\i     -> unboxInt (get i))
                (\d1 d2 -> unboxInt (f (I# d1) (I# d2)))
                (unboxInt r)
                start
                end)
 #-}


-- Float: same shape as the Int rule, using 'unboxFloat' / 'F#'.
{-# RULES "reduceFloat" 
    forall (get :: Int# -> Float) f r start end
    . reduceAny get f r start end 
    = F# (reduceFloat
                (\i     -> unboxFloat (get i))
                (\d1 d2 -> unboxFloat (f (F# d1) (F# d2)))
                (unboxFloat r)
                start
                end)
 #-}


-- Double: same shape as the Int rule, using 'unboxDouble' / 'D#'.
{-# RULES "reduceDouble" 
    forall (get :: Int# -> Double) f r start end
    . reduceAny get f r start end 
    = D# (reduceDouble 
                (\i     -> unboxDouble (get i))
                (\d1 d2 -> unboxDouble (f (D# d1) (D# d2)))
                (unboxDouble r)
                start
                end)
 #-}