{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
-- |
-- Module      : Data.Massiv.Array.Stencil.Convolution
-- Copyright   : (c) Alexey Kuleshevich 2018-2021
-- License     : BSD3
-- Maintainer  : Alexey Kuleshevich <lehins@yandex.ru>
-- Stability   : experimental
-- Portability : non-portable
--
module Data.Massiv.Array.Stencil.Convolution
  ( makeConvolutionStencil
  , makeConvolutionStencilFromKernel
  , makeCorrelationStencil
  , makeCorrelationStencilFromKernel
  ) where

import Data.Massiv.Array.Ops.Fold (ifoldlS)
import Data.Massiv.Array.Stencil.Internal
import Data.Massiv.Core.Common
import GHC.Exts (inline)

-- | Create a convolution stencil from its size, its center and a function that
-- describes the kernel as a chain of accumulation steps.
--
-- The last argument is handed an accumulator function @f@. Each application
-- @f ixD kVal acc@ reads the source element at @ix - ixD@ (the sign flip that
-- distinguishes convolution from correlation), multiplies it by the kernel
-- value @kVal@ and adds the product to @acc@. The whole chain is started with
-- an accumulator of @0@.
--
-- /Note/ - Using `Data.Massiv.Array.Stencil.Unsafe.makeUnsafeConvolutionStencil` will be
-- much faster, therefore it is recommended to switch from this function, after manual
-- verification that the created stencil behaves as expected.
--
-- ==== __Examples__
--
-- Here is how to create a 2D horizontal Sobel Stencil:
--
-- > sobelX :: Num e => Stencil Ix2 e e
-- > sobelX = makeConvolutionStencil (Sz2 3 3) (1 :. 1) $
-- >            \f -> f (-1 :. -1) (-1) . f (-1 :. 1) 1 .
-- >                  f ( 0 :. -1) (-2) . f ( 0 :. 1) 2 .
-- >                  f ( 1 :. -1) (-1) . f ( 1 :. 1) 1
-- > {-# INLINE sobelX #-}
--
-- @since 0.1.0
makeConvolutionStencil
  :: (Index ix, Num e)
  => Sz ix
  -> ix
  -> ((ix -> e -> e -> e) -> e -> e)
  -> Stencil ix e e
makeConvolutionStencil !sz !sCenter relStencil =
  Stencil sz sInvertCenter stencil
  where
    -- Reads happen at @ix - ixD@, i.e. the kernel is mirrored. Assuming the
    -- offsets passed to @f@ lie in the box described by @sz@ and @sCenter@,
    -- the mirrored box has its center at @(sz - 1) - sCenter@.
    !sInvertCenter = liftIndex2 (-) (liftIndex (subtract 1) (unSz sz)) sCenter
    -- Only the second getter supplied by 'Stencil' is used; the offsets given
    -- by the user are not validated against @sz@ here.
    -- NOTE(review): presumably the first getter is the unchecked one used by
    -- the faster unsafe variant - confirm against the 'Stencil' definition.
    stencil _ getVal !ix =
      (inline relStencil $ \ !ixD !kVal !acc ->
         getVal (liftIndex2 (-) ix ixD) * kVal + acc) 0
    {-# INLINE stencil #-}
{-# INLINE makeConvolutionStencil #-}


-- | Make a stencil out of a Kernel Array. This `Stencil` will be slower than if
-- `makeConvolutionStencil` is used, but sometimes we just really don't know the
-- kernel at compile time.
--
-- The kernel center is taken to be @size \`quot\` 2@ along every dimension.
-- For each kernel index @kIx@ with value @kVal@, the result at @ix@
-- accumulates @source (ix + center - kIx) * kVal@, starting from @0@.
--
-- @since 0.1.0
makeConvolutionStencilFromKernel
  :: (Manifest r e, Index ix, Num e)
  => Array r ix e
  -> Stencil ix e e
makeConvolutionStencilFromKernel kArr = Stencil sz sInvertCenter stencil
  where
    !sz@(Sz szi) = size kArr
    !szi1 = liftIndex (subtract 1) szi
    !sCenter = liftIndex (`quot` 2) szi
    -- Reads at @ixOff - kIx@ span offsets from @sCenter - (size - 1)@ up to
    -- @sCenter@ around @ix@, which is exactly the box whose center is the
    -- mirrored @(size - 1) - sCenter@.
    !sInvertCenter = liftIndex2 (-) szi1 sCenter
    -- Every index read is derived from a valid kernel index, so it stays within
    -- the stencil's declared size; that is why the first getter is used here.
    -- NOTE(review): presumably that getter skips bounds checks - confirm
    -- against the 'Stencil' definition.
    stencil uget _ !ix = ifoldlS accum 0 kArr
      where
        !ixOff = liftIndex2 (+) ix sCenter
        accum !acc !kIx !kVal = uget (liftIndex2 (-) ixOff kIx) * kVal + acc
        {-# INLINE accum #-}
    {-# INLINE stencil #-}
{-# INLINE makeConvolutionStencilFromKernel #-}


-- | Make a <https://en.wikipedia.org/wiki/Cross-correlation cross-correlation> stencil.
--
-- Takes the stencil size, its center and a function that is handed an
-- accumulator @f@. Each application @f ixD kVal acc@ reads the source element
-- at @ix + ixD@ (no sign flip, unlike `makeConvolutionStencil`), multiplies it
-- by the kernel value @kVal@ and adds the product to @acc@. The whole chain is
-- started with an accumulator of @0@.
--
-- /Note/ - Using `Data.Massiv.Array.Stencil.Unsafe.makeUnsafeCorrelationStencil` will be
-- much faster, therefore it is recommended to switch from this function, after manual
-- verification that the created stencil behaves as expected.
--
-- @since 0.1.5
makeCorrelationStencil
  :: (Index ix, Num e)
  => Sz ix
  -> ix
  -> ((ix -> e -> e -> e) -> e -> e)
  -> Stencil ix e e
makeCorrelationStencil !sSz !sCenter relStencil = Stencil sSz sCenter stencil
  where
    -- Reads are not mirrored, so the user-supplied center is used as is.
    -- Only the second getter supplied by 'Stencil' is used; the offsets given
    -- by the user are not validated against @sSz@ here.
    stencil _ getVal !ix =
      (inline relStencil $ \ !ixD !kVal !acc ->
         getVal (liftIndex2 (+) ix ixD) * kVal + acc) 0
    {-# INLINE stencil #-}
{-# INLINE makeCorrelationStencil #-}

-- | Make a <https://en.wikipedia.org/wiki/Cross-correlation cross-correlation> stencil out of a
-- Kernel Array. This `Stencil` will be slower than if `makeCorrelationStencil` is used, but
-- sometimes we just really don't know the kernel at compile time.
--
-- The kernel center is taken to be @size \`quot\` 2@ along every dimension.
-- For each kernel index @kIx@ with value @kVal@, the result at @ix@
-- accumulates @source (ix - center + kIx) * kVal@, starting from @0@.
--
-- @since 0.1.5
makeCorrelationStencilFromKernel
  :: (Manifest r e, Index ix, Num e)
  => Array r ix e
  -> Stencil ix e e
makeCorrelationStencilFromKernel kArr = Stencil sz sCenter stencil
  where
    !sz = size kArr
    -- Sizes are non-negative, so `quot` and `div` agree here.
    !sCenter = liftIndex (`quot` 2) (unSz sz)
    -- Reads at @ixOff + kIx@ span offsets from @-sCenter@ up to
    -- @(size - 1) - sCenter@ around @ix@, i.e. exactly the declared stencil
    -- box, which is why the first getter is used here.
    -- NOTE(review): presumably that getter skips bounds checks - confirm
    -- against the 'Stencil' definition.
    stencil uget _ !ix = ifoldlS accum 0 kArr
      where
        !ixOff = liftIndex2 (-) ix sCenter
        accum !acc !kIx !kVal = uget (liftIndex2 (+) ixOff kIx) * kVal + acc
        {-# INLINE accum #-}
    {-# INLINE stencil #-}
{-# INLINE makeCorrelationStencilFromKernel #-}