module Data.Repa.Eval.Generic.Seq.Reduction
        ( foldAll
        , foldRange
        , foldInner)
where
import GHC.Exts


-- | Sequential reduction of all the elements in an array.
--
--   Reduces the elements at indices @0@ up to (but not including) @len@
--   by delegating to 'foldRange', so elements are combined in exactly the
--   order 'foldRange' uses. When @len@ is not positive the neutral value
--   is returned.
foldAll :: (Int# -> a)          -- ^ Function to get an element from the source.
        -> (a -> a -> a)        -- ^ Binary associative combining function.
        -> a                    -- ^ Neutral starting value.
        -> Int#                 -- ^ Number of elements.
        -> a

foldAll get c !r !len
 = foldRange get c r 0# len
{-# INLINE [1] foldAll #-}


-- | Sequential reduction of a multidimensional array along the innermost dimension.
--
--   For each result index @sh@, the source segment
--   @[sh * n, (sh + 1) * n)@ is reduced with 'foldRange' and the result
--   is stored with the supplied write function at index @sh@.
foldInner
        :: (Int# -> a -> IO ()) -- ^ Function to write into the result buffer.
        -> (Int# -> a)          -- ^ Function to get an element from the source.
        -> (a -> a -> a)        -- ^ Binary associative combination function.
        -> a                    -- ^ Neutral starting value.
        -> Int#                 -- ^ Loop bound on the result index: one result
                                --   is written for every index below it.
        -> Int#                 -- ^ Inner dimension (length to fold over).
        -> IO ()

-- NOTE(review): the bound above is compared against the result index 'sh',
-- not a source index, so it acts as the number of results to produce. The
-- original haddock called it the "total length of source"; a caller that
-- passes the source length with n > 1 would make the loop read past the
-- end of the source. Behaviour is unchanged here -- confirm against callers.
foldInner write get c !r !end !n
 = iter 0# 0#
 where
        -- 'sh' is the index of the result being produced and 'sz' is the
        -- source index where its segment starts (always sh * n).
        iter :: Int# -> Int# -> IO ()
        iter !sh !sz
         | 1# <- sh >=# end
         = return ()

         | otherwise
         = do   let !next = sz +# n
                write sh (foldRange get c r sz next)
                iter (sh +# 1#) next
        {-# INLINE iter #-}
{-# INLINE [1] foldInner #-}


-- Reduce ---------------------------------------------------------------------
-- | Sequentially reduce values between the given indices.
--
--   We use manual specialisations and rewrite rules to avoid the result
--   being boxed up in the final iteration.
--
--   The accumulator starts at the neutral value and each element is
--   combined on its left, i.e. the result is
--   @c (f (end - 1)) (... (c (f start) r))@. If @start >= end@ the neutral
--   value is returned unchanged.
--
--   NOTE(review): for an associative but non-commutative combining function
--   this visits elements in reverse index order; confirm callers only pass
--   commutative operators.
foldRange
        :: (Int# -> a)          -- ^ Function to get an element from the source.
        -> (a -> a -> a)        -- ^ Binary associative combining function.
        -> a                    -- ^ Neutral starting value.
        -> Int#                 -- ^ Starting index.
        -> Int#                 -- ^ Ending index (exclusive).
        -> a

foldRange f c !r !start !end
 = iter start r
 where  iter !i !z
         | 1# <- i >=# end  = z
         | otherwise        = iter (i +# 1#) (f i `c` z)
        {-# INLINE iter #-}
{-# INLINE [0] foldRange #-}


-- | Specialisation of 'foldRange' to an unboxed 'Int#' element and
--   accumulator. The "foldRangeInt" rewrite rule targets it so the
--   accumulator is never boxed while looping.
foldRangeInt
        :: (Int# -> Int#)               -- ^ Function to get an element from the source.
        -> (Int# -> Int# -> Int#)       -- ^ Combining function.
        -> Int#                         -- ^ Neutral starting value.
        -> Int#                         -- ^ Starting index.
        -> Int#                         -- ^ Ending index (exclusive).
        -> Int#

foldRangeInt f c !r !start !end
 = iter start r
 where  iter :: Int# -> Int# -> Int#
        iter !i !z
         | 1# <- i >=# end  = z
         | otherwise        = iter (i +# 1#) (f i `c` z)
        {-# INLINE iter #-}
{-# INLINE [0] foldRangeInt #-}


-- | Specialisation of 'foldRange' to an unboxed 'Float#' element and
--   accumulator. The "foldRangeFloat" rewrite rule targets it so the
--   accumulator is never boxed while looping.
foldRangeFloat
        :: (Int# -> Float#)             -- ^ Function to get an element from the source.
        -> (Float# -> Float# -> Float#) -- ^ Combining function.
        -> Float#                       -- ^ Neutral starting value.
        -> Int#                         -- ^ Starting index.
        -> Int#                         -- ^ Ending index (exclusive).
        -> Float#

foldRangeFloat f c !r !start !end
 = iter start r
 where  iter :: Int# -> Float# -> Float#
        iter !i !z
         | 1# <- i >=# end  = z
         | otherwise        = iter (i +# 1#) (f i `c` z)
        {-# INLINE iter #-}
{-# INLINE [0] foldRangeFloat #-}


-- | Specialisation of 'foldRange' to an unboxed 'Double#' element and
--   accumulator. The "foldRangeDouble" rewrite rule targets it so the
--   accumulator is never boxed while looping.
foldRangeDouble
        :: (Int# -> Double#)                    -- ^ Function to get an element from the source.
        -> (Double# -> Double# -> Double#)      -- ^ Combining function.
        -> Double#                              -- ^ Neutral starting value.
        -> Int#                                 -- ^ Starting index.
        -> Int#                                 -- ^ Ending index (exclusive).
        -> Double#

foldRangeDouble f c !r !start !end
 = iter start r
 where  iter :: Int# -> Double# -> Double#
        iter !i !z
         | 1# <- i >=# end  = z
         | otherwise        = iter (i +# 1#) (f i `c` z)
        {-# INLINE iter #-}
{-# INLINE [0] foldRangeDouble #-}


-- | Extract the raw 'Int#' from a boxed 'Int'.
unboxInt :: Int -> Int#
unboxInt (I# i) = i
{-# INLINE unboxInt #-}


-- | Extract the raw 'Float#' from a boxed 'Float'.
unboxFloat :: Float -> Float#
unboxFloat (F# f) = f
{-# INLINE unboxFloat #-}


-- | Extract the raw 'Double#' from a boxed 'Double'.
unboxDouble :: Double -> Double#
unboxDouble (D# d) = d
{-# INLINE unboxDouble #-}


-- When 'foldRange' is used at element type 'Int', run the unboxed worker
-- 'foldRangeInt' instead and box only the final result.
{-# RULES "foldRangeInt"
    forall (get :: Int# -> Int) comb z0 lo hi .
      foldRange get comb z0 lo hi
        = I# (foldRangeInt (\ix -> unboxInt (get ix))
                           (\x y -> unboxInt (comb (I# x) (I# y)))
                           (unboxInt z0)
                           lo hi)
  #-}


-- When 'foldRange' is used at element type 'Float', run the unboxed worker
-- 'foldRangeFloat' instead and box only the final result.
{-# RULES "foldRangeFloat"
    forall (get :: Int# -> Float) comb z0 lo hi .
      foldRange get comb z0 lo hi
        = F# (foldRangeFloat (\ix -> unboxFloat (get ix))
                             (\x y -> unboxFloat (comb (F# x) (F# y)))
                             (unboxFloat z0)
                             lo hi)
  #-}


-- When 'foldRange' is used at element type 'Double', run the unboxed worker
-- 'foldRangeDouble' instead and box only the final result.
{-# RULES "foldRangeDouble"
    forall (get :: Int# -> Double) comb z0 lo hi .
      foldRange get comb z0 lo hi
        = D# (foldRangeDouble (\ix -> unboxDouble (get ix))
                              (\x y -> unboxDouble (comb (D# x) (D# y)))
                              (unboxDouble z0)
                              lo hi)
  #-}