{-# LANGUAGE BangPatterns, PackageImports #-}
{-# OPTIONS -Wall -fno-warn-missing-signatures -fno-warn-incomplete-patterns #-}

-- | Generic stencil based convolutions. 
-- 
--   If your stencil fits within a 7x7 tile and is known at compile-time then using
--   then using the built-in stencil support provided by the main Repa package will
--   be 5-10x faster. 
-- 
--   If you have a larger stencil, the coefficients are not statically known, 
--   or need more complex boundary handling than provided by the built-in functions,
--   then use this version instead.
--
module Data.Array.Repa.Algorithms.Convolve
        ( -- * Arbitrary boundary handling
          convolveP

          -- * Specialised boundary handling
        , GetOut
        , outAs
        , outClamp
        , convolveOutP )
where
import Data.Array.Repa                                  as R
import Data.Array.Repa.Unsafe                           as R
import Data.Array.Repa.Repr.Unboxed                     as R
import qualified Data.Vector.Unboxed                    as V
import qualified Data.Array.Repa.Shape                  as S
import Prelude                                          as P


-- Plain Convolve -------------------------------------------------------------
-- | Image-kernel convolution,
--   which takes a function specifying what value to return when the
--   kernel doesn't apply.
convolveP
        :: (Num a, Unbox a, Monad m)
        => (DIM2 -> a)          -- ^ Function to get border elements when 
                                --   the stencil does not apply.
        -> Array U DIM2 a       -- ^ Stencil to use in the convolution.
        -> Array U DIM2 a       -- ^ Input image.
        -> m (Array U DIM2 a)

convolveP makeOut kernel image
 = kernel `deepSeqArray` image `deepSeqArray` 
   computeP $ unsafeTraverse image id update
 where  
        (Z :. krnHeight :. krnWidth)        = extent kernel
        krnVec          = toUnboxed kernel
        
        imgSh@(Z :. imgHeight :. imgWidth)  = extent image
        imgVec          = toUnboxed image

        !krnHeight2     = krnHeight `div` 2
        !krnWidth2      = krnWidth  `div` 2

        -- If we're too close to the edge of the input image then
        -- we can't apply the stencil because we don't have enough data.
        !borderLeft     = krnWidth2
        !borderRight    = imgWidth   - krnWidth2  - 1
        !borderUp       = krnHeight2
        !borderDown     = imgHeight  - krnHeight2 - 1

        {-# INLINE update #-}
        update _ ix@(_ :. j :. i)
         | i < borderLeft       = makeOut ix
         | i > borderRight      = makeOut ix
         | j < borderUp         = makeOut ix
         | j > borderDown       = makeOut ix
         | otherwise            = stencil j i

        -- The actual stencil function.
        {-# INLINE stencil #-}
        stencil j i
         = let  imgStart = S.toIndex imgSh (Z :. j - krnHeight2 :. i - krnWidth2)
           in   integrate 0 0 0 imgStart 0

        {-# INLINE integrate #-}
        integrate !acc !x !y !imgCur !krnCur  
         | y >= krnHeight
         = acc

         | x >= krnWidth
         = integrate acc 0 (y + 1) (imgCur + imgWidth - krnWidth) krnCur 
        
         | otherwise
         = let  imgZ    = imgVec `V.unsafeIndex` imgCur 
                krnZ    = krnVec `V.unsafeIndex` krnCur 
                here    = imgZ * krnZ 
           in   integrate (acc + here) (x + 1) y (imgCur + 1) (krnCur + 1)
{-# INLINE convolveP #-}


-- Convolve Out -----------------------------------------------------------------------------------
-- | A function that gets out of range elements from an image.
type GetOut a
        = (DIM2 -> a)   -- ^ The original get function.
        -> DIM2         -- ^ The shape of the image.
        -> DIM2         -- ^ Index of element we were trying to get.
        -> a


-- | Use the provided value for every out-of-range element.
outAs :: a -> GetOut a
{-# INLINE outAs #-}
outAs x _ _ _ = x


-- | If the requested element is out of range use
--   the closest one from the real image.
outClamp :: GetOut a
{-# INLINE outClamp #-}
outClamp get (_ :. yLen :. xLen) (sh :. j :. i)
 = clampX j i
 where  {-# INLINE clampX #-}
        clampX !y !x
          | x < 0       = clampY y 0
          | x >= xLen   = clampY y (xLen - 1)
          | otherwise   = clampY y x
                
        {-# INLINE clampY #-}
        clampY !y !x
          | y < 0       = get (sh :. 0          :. x)
          | y >= yLen   = get (sh :. (yLen - 1) :. x)
          | otherwise   = get (sh :. y          :. x)


-- | Image-kernel convolution, 
--   which takes a function specifying what value to use for out-of-range elements.
convolveOutP
        :: (Num a, Unbox a, Monad m)
        => GetOut a             -- ^ How to handle out-of-range elements.
        -> Array U DIM2 a       -- ^ Stencil to use in the convolution.
        -> Array U DIM2 a       -- ^ Input image.
        -> m (Array U DIM2 a)

convolveOutP getOut kernel image
 = kernel `deepSeqArray` image `deepSeqArray` 
   computeP $ unsafeTraverse image id stencil
 where  
        krnSh@(Z :. krnHeight :. krnWidth)  = extent kernel        
        imgSh@(Z :. imgHeight :. imgWidth)  = extent image

        !krnHeight2     = krnHeight `div` 2
        !krnWidth2      = krnWidth  `div` 2
        !krnSize        = S.size krnSh

        -- If we're too close to the edge of the input image then
        -- we can't apply the stencil because we don't have enough data.
        !borderLeft     = krnWidth2
        !borderRight    = imgWidth   - krnWidth2  - 1
        !borderUp       = krnHeight2
        !borderDown     = imgHeight  - krnHeight2 - 1

        -- The actual stencil function.
        {-# INLINE stencil #-}
        stencil get (_ :. j :. i)
         = let
                {-# INLINE get' #-}
                get' ix@(_ :. y :. x)
                 | x < borderLeft       = getOut get imgSh ix
                 | x > borderRight      = getOut get imgSh ix
                 | y < borderUp         = getOut get imgSh ix
                 | y > borderDown       = getOut get imgSh ix
                 | otherwise            = get ix

                !ikrnWidth'     = i - krnWidth2
                !jkrnHeight'    = j - krnHeight2

                {-# INLINE integrate #-}
                integrate !count !acc
                 | count == krnSize             = acc
                 | otherwise
                 = let  !ix@(sh :. y :. x)      = S.fromIndex krnSh count
                        !ix'                    = sh :. y + jkrnHeight' :. x + ikrnWidth'
                        !here                   = kernel `unsafeIndex` ix * (get' ix')
                   in   integrate (count + 1) (acc + here)

           in   integrate 0 0
{-# INLINE convolveOutP #-}