{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE BangPatterns     #-}
module Data.Massiv.Array.Stencil.Convolution
  ( makeConvolutionStencil
  , makeConvolutionStencilFromKernel
  , makeCorrelationStencil
  , makeCorrelationStencilFromKernel
  ) where
import           Data.Massiv.Core.Common
import           Data.Massiv.Array.Ops.Fold         (ifoldlS)
import           Data.Massiv.Array.Stencil.Internal
import           GHC.Exts                           (inline)
-- | Build a convolution 'Stencil' from its size, its center and a function
-- that folds over the kernel as @(relative offset, coefficient)@ pairs.
--
-- At output index @ix@ the stencil evaluates
-- @relStencil step 0@, where @step d k acc = readAt (ix - d) * k + acc@.
-- In other words, it sums @value (ix - d) * k@ over every pair that
-- @relStencil@ feeds to @step@.
--
-- The stencil is wrapped in 'validateStencil' with @0@ as the probe value.
-- NOTE(review): presumably this rejects offsets that fall outside the
-- declared size/center window — confirm against "Data.Massiv.Array.Stencil.Internal".
makeConvolutionStencil
  :: (Index ix, Num e)
  => ix
  -- ^ Size of the stencil window
  -> ix
  -- ^ Center of the stencil window
  -> ((ix -> Value e -> Value e -> Value e) -> Value e -> Value e)
  -- ^ Fold that applies the given step to each (offset, coefficient) pair
  -> Stencil ix e e
makeConvolutionStencil !sSz !sCenter relStencil =
  validateStencil 0 (Stencil sSz sCenter convolveAt)
  where
    -- Evaluate the convolution sum at a single output position.
    convolveAt readAt !ix = inline relStencil step 0
      where
        -- Convolution reads the input mirrored around @ix@, hence subtraction.
        step !offset !coeff !total =
          readAt (liftIndex2 (-) ix offset) * coeff + total
        {-# INLINE step #-}
    {-# INLINE convolveAt #-}
{-# INLINE makeConvolutionStencil #-}
-- | Make a convolution 'Stencil' out of a kernel array known only at runtime.
--
-- The kernel element at index @sz `div` 2@ (per dimension) is aligned with
-- the output index. At output index @ix@ the stencil sums
-- @value (ix + sCenter - kIx) * k@ over every kernel element @k@ at @kIx@.
makeConvolutionStencilFromKernel
  :: (Manifest r ix e, Num e)
  => Array r ix e
  -- ^ Kernel
  -> Stencil ix e e
makeConvolutionStencilFromKernel kArr = Stencil sz sWindowCenter stencil
  where
    !sz = size kArr
    -- Kernel element that lines up with the output index.
    !sCenter = liftIndex (`div` 2) sz
    -- Convolution reads offsets (sCenter - kIx), i.e. the range
    -- [sCenter - (sz - 1), sCenter]. A stencil declared with center c covers
    -- offsets [-c, sz - 1 - c], so the declared center must be
    -- (sz - 1 - sCenter) for every read to stay inside the window.
    --
    -- For odd sizes this equals sCenter, so nothing changes. For even sizes
    -- the previous declaration (sCenter) let reads run one element past the
    -- window, and therefore past the array in the unchecked interior.
    --
    -- `max 0` keeps a zero-length dimension at 0, as it was before.
    -- NOTE(review): relies on massiv's "center = position inside the window"
    -- convention for 'Stencil' — confirm.
    !sWindowCenter = liftIndex2 (\n c -> max 0 (n - 1 - c)) sz sCenter
    stencil getVal !ix = Value (ifoldlS accum 0 kArr) where
      -- Loop-invariant part of (ix + sCenter - kIx), hoisted out of the fold.
      !ixOff = liftIndex2 (+) ix sCenter
      accum !acc !kIx !kVal =
        unValue (getVal (liftIndex2 (-) ixOff kIx)) * kVal + acc
      {-# INLINE accum #-}
    {-# INLINE stencil #-}
{-# INLINE makeConvolutionStencilFromKernel #-}
-- | Build a correlation 'Stencil' from its size, its center and a function
-- that folds over the kernel as @(relative offset, coefficient)@ pairs.
--
-- At output index @ix@ the stencil evaluates
-- @relStencil step 0@, where @step d k acc = readAt (ix + d) * k + acc@.
-- In other words, it sums @value (ix + d) * k@ over every pair that
-- @relStencil@ feeds to @step@.
--
-- The stencil is wrapped in 'validateStencil' with @0@ as the probe value.
-- NOTE(review): presumably this rejects offsets that fall outside the
-- declared size/center window — confirm against "Data.Massiv.Array.Stencil.Internal".
makeCorrelationStencil
  :: (Index ix, Num e)
  => ix
  -- ^ Size of the stencil window
  -> ix
  -- ^ Center of the stencil window
  -> ((ix -> Value e -> Value e -> Value e) -> Value e -> Value e)
  -- ^ Fold that applies the given step to each (offset, coefficient) pair
  -> Stencil ix e e
makeCorrelationStencil !sSz !sCenter relStencil =
  validateStencil 0 (Stencil sSz sCenter correlateAt)
  where
    -- Evaluate the correlation sum at a single output position.
    correlateAt readAt !ix = inline relStencil step 0
      where
        -- Correlation reads the input in the same orientation as the kernel.
        step !offset !coeff !total =
          readAt (liftIndex2 (+) ix offset) * coeff + total
        {-# INLINE step #-}
    {-# INLINE correlateAt #-}
{-# INLINE makeCorrelationStencil #-}
-- | Make a correlation 'Stencil' out of a kernel array known only at runtime.
--
-- The kernel element at index @sz `div` 2@ (per dimension) is aligned with
-- the output index. At output index @ix@ the stencil sums
-- @value (ix + kIx - sCenter) * k@ over every kernel element @k@ at @kIx@.
makeCorrelationStencilFromKernel
  :: (Manifest r ix e, Num e)
  => Array r ix e
  -- ^ Kernel
  -> Stencil ix e e
makeCorrelationStencilFromKernel kArr = Stencil sz sCenter stencil
  where
    !sz = size kArr
    -- Kernel element that lines up with the output index.
    !sCenter = liftIndex (`div` 2) sz
    stencil getVal !ix = Value (ifoldlS accum 0 kArr) where
      -- Kernel index kIx corresponds to relative offset (kIx - sCenter).
      -- The fold therefore starts at the window's corner (ix - sCenter) and
      -- adds kIx. The previous code added sCenter instead of subtracting it,
      -- which shifted every tap by 2 * sCenter.
      !ixOff = liftIndex2 (-) ix sCenter
      accum !acc !kIx !kVal =
        unValue (getVal (liftIndex2 (+) ixOff kIx)) * kVal + acc
      {-# INLINE accum #-}
    {-# INLINE stencil #-}
{-# INLINE makeCorrelationStencilFromKernel #-}