{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
module Data.Massiv.Array.Stencil.Convolution
( makeConvolutionStencil
, makeConvolutionStencilFromKernel
, makeCorrelationStencil
, makeCorrelationStencilFromKernel
) where
import Data.Massiv.Array.Ops.Fold (ifoldlS)
import Data.Massiv.Array.Stencil.Internal
import Data.Massiv.Core.Common
import GHC.Exts (inline)
-- | Construct a convolution stencil from an explicit description of a kernel.
--
-- The kernel is described by @relStencil@. It receives a step function
-- @f ixD kVal acc@ and threads an accumulator through it, one call per kernel
-- element. In each call, @ixD@ is that element's offset relative to the kernel
-- center and @kVal@ is its weight. For example, @\\f -> f (-1) 1 . f 0 (-2) . f 1 1@
-- is one way to write such a description.
--
-- The accumulator starts at @0@. Each step adds @getVal (ix - ixD) * kVal@,
-- so neighbours are read at the /negated/ offset. That flip is what makes
-- this a convolution rather than a correlation.
--
-- Because of the flip, the stencil's reported center is @(sz - 1) - sCenter@
-- in every dimension, not @sCenter@ itself.
makeConvolutionStencil
  :: (Index ix, Num e)
  => Sz ix -- ^ Size of the kernel
  -> ix -- ^ Center of the kernel, in kernel coordinates
  -> ((ix -> Value e -> Value e -> Value e) -> Value e -> Value e)
  -- ^ Kernel description: applies the step function once per kernel element
  -> Stencil ix e e
makeConvolutionStencil !sz !sCenter relStencil =
  -- The result is passed through 'validateStencil' with 0 as a dummy element.
  -- NOTE(review): presumably this guards against a user-supplied
  -- @relStencil@ that reads outside the declared size; confirm in
  -- Stencil.Internal.
  validateStencil 0 $ Stencil sz sInvertCenter stencil
  where
    -- Reads happen at ix - ixD, so the footprint is the kernel mirrored about
    -- its center. The mirrored center is (sz - 1) - sCenter.
    !sInvertCenter = liftIndex2 (-) (liftIndex (subtract 1) (unSz sz)) sCenter
    stencil getVal !ix =
      (inline relStencil $ \ !ixD !kVal !acc ->
         getVal (liftIndex2 (-) ix ixD) * kVal + acc) 0
    {-# INLINE stencil #-}
{-# INLINE makeConvolutionStencil #-}
-- | Construct a convolution stencil from a kernel stored in an array.
--
-- The kernel's center is the element at half the kernel size, rounded down,
-- in every dimension. For an output index @ix@, the stencil folds over every
-- kernel index @kIx@, starting from @0@. Each step adds the kernel value at
-- @kIx@ multiplied by the neighbour at @ix + center - kIx@. The kernel is
-- therefore applied flipped, as convolution requires.
--
-- Because of the flip, the stencil's reported center is
-- @(size - 1) - center@ in every dimension.
--
-- Unlike 'makeConvolutionStencil', no 'validateStencil' pass is applied. All
-- read offsets are derived from the kernel's own indices, so they lie in
-- @[center - (size - 1), center]@, which is exactly the footprint implied by
-- the reported size and center.
makeConvolutionStencilFromKernel
  :: (Manifest r ix e, Num e)
  => Array r ix e -- ^ Kernel
  -> Stencil ix e e
makeConvolutionStencilFromKernel kArr = Stencil sz sInvertCenter stencil
  where
    !sz@(Sz szi) = size kArr
    -- Kernel center in kernel coordinates: half the size, rounded down.
    -- (`div` matches makeCorrelationStencilFromKernel; sizes are non-negative,
    -- so this equals `quot`.)
    !sCenter = liftIndex (`div` 2) szi
    -- Mirror of sCenter, because convolution flips the kernel.
    !sInvertCenter = liftIndex2 (-) (liftIndex (subtract 1) szi) sCenter
    stencil getVal !ix = Value (ifoldlS accum 0 kArr)
      where
        -- Hoisted out of the fold: every kernel element is read at ixOff - kIx.
        !ixOff = liftIndex2 (+) ix sCenter
        accum !acc !kIx !kVal =
          unValue (getVal (liftIndex2 (-) ixOff kIx)) * kVal + acc
        {-# INLINE accum #-}
    {-# INLINE stencil #-}
{-# INLINE makeConvolutionStencilFromKernel #-}
-- | Construct a correlation stencil from an explicit description of a kernel.
--
-- The kernel is described by @relStencil@. It receives a step function
-- @f ixD kVal acc@ and threads an accumulator through it, one call per kernel
-- element. In each call, @ixD@ is that element's offset relative to the kernel
-- center and @kVal@ is its weight.
--
-- The accumulator starts at @0@. Each step adds @getVal (ix + ixD) * kVal@,
-- so neighbours are read at the same offsets the kernel is described with and
-- no flip occurs. The stencil's reported center is therefore @sCenter@
-- unchanged.
makeCorrelationStencil
  :: (Index ix, Num e)
  => Sz ix -- ^ Size of the kernel
  -> ix -- ^ Center of the kernel, in kernel coordinates
  -> ((ix -> Value e -> Value e -> Value e) -> Value e -> Value e)
  -- ^ Kernel description: applies the step function once per kernel element
  -> Stencil ix e e
makeCorrelationStencil !sSz !sCenter relStencil =
  -- The result is passed through 'validateStencil' with 0 as a dummy element.
  -- NOTE(review): presumably this guards against a user-supplied
  -- @relStencil@ that reads outside the declared size; confirm in
  -- Stencil.Internal.
  validateStencil 0 $ Stencil sSz sCenter stencil
  where
    stencil getVal !ix =
      (inline relStencil $ \ !ixD !kVal !acc ->
         getVal (liftIndex2 (+) ix ixD) * kVal + acc) 0
    {-# INLINE stencil #-}
{-# INLINE makeCorrelationStencil #-}
-- | Construct a correlation stencil from a kernel stored in an array.
--
-- The kernel's center is the element at half the kernel size, rounded down,
-- in every dimension. For an output index @ix@, the stencil folds over every
-- kernel index @kIx@, starting from @0@. Each step adds the kernel value at
-- @kIx@ multiplied by the neighbour at @ix - center + kIx@. The kernel is not
-- flipped, so the reported center is the kernel center itself.
--
-- Unlike 'makeCorrelationStencil', no 'validateStencil' pass is applied. All
-- read offsets are derived from the kernel's own indices, so they lie in
-- @[-center, size - 1 - center]@, which is exactly the footprint implied by
-- the reported size and center.
makeCorrelationStencilFromKernel
  :: (Manifest r ix e, Num e)
  => Array r ix e -- ^ Kernel
  -> Stencil ix e e
makeCorrelationStencilFromKernel kArr = Stencil sz sCenter stencil
  where
    !sz = size kArr
    -- Kernel center in kernel coordinates: half the size, rounded down.
    !sCenter = liftIndex (`div` 2) (unSz sz)
    stencil getVal !ix = Value (ifoldlS accum 0 kArr)
      where
        -- Hoisted out of the fold: every kernel element is read at ixOff + kIx.
        !ixOff = liftIndex2 (-) ix sCenter
        accum !acc !kIx !kVal =
          unValue (getVal (liftIndex2 (+) ixOff kIx)) * kVal + acc
        {-# INLINE accum #-}
    {-# INLINE stencil #-}
{-# INLINE makeCorrelationStencilFromKernel #-}