{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
-- |
-- Module      : Data.Massiv.Array.Stencil.Convolution
-- Copyright   : (c) Alexey Kuleshevich 2018-2019
-- License     : BSD3
-- Maintainer  : Alexey Kuleshevich <lehins@yandex.ru>
-- Stability   : experimental
-- Portability : non-portable
--
module Data.Massiv.Array.Stencil.Convolution
  ( makeConvolutionStencil
  , makeConvolutionStencilFromKernel
  , makeCorrelationStencil
  , makeCorrelationStencilFromKernel
  ) where

import Data.Massiv.Array.Ops.Fold (ifoldlS)
import Data.Massiv.Array.Stencil.Internal
import Data.Massiv.Core.Common
import GHC.Exts (inline)

-- | Create a convolution stencil by specifying its size, its center and an
-- accumulator function that describes the kernel.
--
-- The accumulator function receives a callback @f@. Every application
-- @f ixD kVal acc@ reads the element located at @ix - ixD@ (the kernel is
-- mirrored, as convolution requires), multiplies it by @kVal@ and adds it to
-- @acc@. Composing such applications with @(.)@ describes the whole kernel;
-- accumulation starts at @0@.
--
-- ==== __Examples__
--
-- Here is how to create a 2D horizontal Sobel Stencil:
--
-- > sobelX :: Num e => Stencil Ix2 e e
-- > sobelX = makeConvolutionStencil (Sz2 3 3) (1 :. 1) $
-- >            \f -> f (-1 :. -1) (-1) . f (-1 :. 1) 1 .
-- >                  f ( 0 :. -1) (-2) . f ( 0 :. 1) 2 .
-- >                  f ( 1 :. -1) (-1) . f ( 1 :. 1) 1
-- > {-# INLINE sobelX #-}
--
-- @since 0.1.0
makeConvolutionStencil
  :: (Index ix, Num e)
  => Sz ix
  -- ^ Size of the stencil
  -> ix
  -- ^ Center of the stencil
  -> ((ix -> Value e -> Value e -> Value e) -> Value e -> Value e)
  -- ^ Accumulator function describing the kernel (see above)
  -> Stencil ix e e
makeConvolutionStencil !sz !sCenter relStencil =
  validateStencil 0 $ Stencil sz sInvertCenter stencil
  where
    -- Elements are read at @ix - ixD@, i.e. the kernel is mirrored, so the
    -- center of the footprint is measured from the opposite corner:
    -- @(size - 1) - sCenter@.
    !sInvertCenter = liftIndex2 (-) (liftIndex (subtract 1) (unSz sz)) sCenter
    stencil getVal !ix =
      inline relStencil
        (\ !ixD !kVal !acc -> getVal (liftIndex2 (-) ix ixD) * kVal + acc)
        0
    {-# INLINE stencil #-}
{-# INLINE makeConvolutionStencil #-}


-- | Make a stencil out of a Kernel Array. This `Stencil` will be slower than if
-- `makeConvolutionStencil` is used, but sometimes we just really don't know the
-- kernel at compile time.
--
-- The center of the kernel is taken to be half of its size along every
-- dimension (rounded towards zero). Every element of the kernel is visited with
-- an indexed left fold for each output element.
--
-- @since 0.1.0
makeConvolutionStencilFromKernel
  :: (Manifest r ix e, Num e)
  => Array r ix e
  -- ^ Kernel
  -> Stencil ix e e
makeConvolutionStencilFromKernel kArr = Stencil sz sInvertCenter stencil
  where
    !sz@(Sz szi) = size kArr
    -- Center of the kernel: half of its size in every dimension.
    !sCenter = liftIndex (`quot` 2) szi
    -- Convolution reads the input mirrored around the center, so the center of
    -- the footprint is measured from the opposite corner of the kernel.
    !sInvertCenter = liftIndex2 (-) (liftIndex (subtract 1) szi) sCenter
    stencil getVal !ix = Value (ifoldlS accum 0 kArr)
      where
        -- Shifted origin, so that kernel index @kIx@ maps onto the input index
        -- @ix + sCenter - kIx@.
        !ixOff = liftIndex2 (+) ix sCenter
        accum !acc !kIx !kVal =
          unValue (getVal (liftIndex2 (-) ixOff kIx)) * kVal + acc
        {-# INLINE accum #-}
    {-# INLINE stencil #-}
{-# INLINE makeConvolutionStencilFromKernel #-}


-- | Make a <https://en.wikipedia.org/wiki/Cross-correlation cross-correlation> stencil
-- by specifying its size, its center and an accumulator function that describes
-- the kernel.
--
-- The accumulator function receives a callback @f@. Every application
-- @f ixD kVal acc@ reads the element located at @ix + ixD@ (the kernel is not
-- mirrored, unlike in `makeConvolutionStencil`), multiplies it by @kVal@ and
-- adds it to @acc@. Accumulation starts at @0@.
--
-- @since 0.1.5
makeCorrelationStencil
  :: (Index ix, Num e)
  => Sz ix
  -- ^ Size of the stencil
  -> ix
  -- ^ Center of the stencil
  -> ((ix -> Value e -> Value e -> Value e) -> Value e -> Value e)
  -- ^ Accumulator function describing the kernel (see above)
  -> Stencil ix e e
makeCorrelationStencil !sSz !sCenter relStencil =
  validateStencil 0 $ Stencil sSz sCenter stencil
  where
    stencil getVal !ix =
      inline relStencil
        (\ !ixD !kVal !acc -> getVal (liftIndex2 (+) ix ixD) * kVal + acc)
        0
    {-# INLINE stencil #-}
{-# INLINE makeCorrelationStencil #-}

-- | Make a <https://en.wikipedia.org/wiki/Cross-correlation cross-correlation> stencil out of a
-- Kernel Array. This `Stencil` will be slower than if `makeCorrelationStencil` is used, but
-- sometimes we just really don't know the kernel at compile time.
--
-- The center of the kernel is taken to be half of its size along every
-- dimension (rounded towards zero). Every element of the kernel is visited with
-- an indexed left fold for each output element.
--
-- @since 0.1.5
makeCorrelationStencilFromKernel
  :: (Manifest r ix e, Num e)
  => Array r ix e
  -- ^ Kernel
  -> Stencil ix e e
makeCorrelationStencilFromKernel kArr = Stencil sz sCenter stencil
  where
    !sz = size kArr
    -- Center of the kernel: half of its size in every dimension. Computed the
    -- same way as in the convolution counterpart.
    !sCenter = liftIndex (`quot` 2) (unSz sz)
    stencil getVal !ix = Value (ifoldlS accum 0 kArr)
      where
        -- Shifted origin, so that kernel index @kIx@ maps onto the input index
        -- @ix - sCenter + kIx@.
        !ixOff = liftIndex2 (-) ix sCenter
        accum !acc !kIx !kVal =
          unValue (getVal (liftIndex2 (+) ixOff kIx)) * kVal + acc
        {-# INLINE accum #-}
    {-# INLINE stencil #-}
{-# INLINE makeCorrelationStencilFromKernel #-}