{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
-- |
-- Module      : Data.Massiv.Array.Stencil
-- Copyright   : (c) Alexey Kuleshevich 2018-2021
-- License     : BSD3
-- Maintainer  : Alexey Kuleshevich <lehins@yandex.ru>
-- Stability   : experimental
-- Portability : non-portable
--
module Data.Massiv.Array.Stencil
  ( -- * Stencil
    Stencil
  , makeStencil
  , getStencilSize
  , getStencilCenter
  -- ** Padding
  , Padding(..)
  , noPadding
  , samePadding
  -- ** Application
  , mapStencil
  , applyStencil
  -- ** Common stencils
  , idStencil
  , sumStencil
  , productStencil
  , avgStencil
  , maxStencil
  , minStencil
  , foldlStencil
  , foldrStencil
  , foldStencil
  -- ** Profunctor
  , dimapStencil
  , lmapStencil
  , rmapStencil
  -- * Convolution
  , module Data.Massiv.Array.Stencil.Convolution
  ) where

import Data.Coerce
import Data.Massiv.Array.Delayed.Windowed
import Data.Massiv.Array.Manifest
import Data.Massiv.Array.Stencil.Convolution
import Data.Massiv.Array.Stencil.Internal
import Data.Massiv.Array.Stencil.Unsafe
import Data.Massiv.Core.Common
import Data.Semigroup
import GHC.Exts (inline)

-- | Get the size of the stencil's bounding box, i.e. the 'Sz' it was constructed with.
--
-- @since 0.4.3
getStencilSize :: Stencil ix e a -> Sz ix
getStencilSize (Stencil sSz _ _) = sSz

-- | Get the index of the stencil's center, measured from the origin of its bounding box.
--
-- @since 0.4.3
getStencilCenter :: Stencil ix e a -> ix
getStencilCenter (Stencil _ sCenter _) = sCenter

-- | Map a constructed stencil over an array, producing an array of the same size as
-- the source. This is `applyStencil` with `samePadding`, so cells of the source array
-- that the stencil reaches past its edges are resolved with the supplied 'Border'.
-- The resulting array must be `Data.Massiv.Array.compute`d in order to be useful.
--
-- @since 0.1.0
mapStencil ::
     (Index ix, Manifest r e)
  => Border e -- ^ Border resolution technique
  -> Stencil ix e a -- ^ Stencil to map over the array
  -> Array r ix e -- ^ Source array
  -> Array DW ix a
mapStencil border stencil = applyStencil padding stencil
  where
    -- Pad just enough around the source so the result keeps the source's size
    padding = samePadding stencil border
{-# INLINE mapStencil #-}


-- | Padding of the source array before stencil application.
--
-- ==== __Examples__
--
-- In order to see the effect of padding we can simply apply an identity stencil to an
-- array:
--
-- >>> import Data.Massiv.Array as A
-- >>> a = computeAs P $ resize' (Sz2 2 3) (Ix1 1 ... 6)
-- >>> applyStencil noPadding idStencil a
-- Array DW Seq (Sz (2 :. 3))
--   [ [ 1, 2, 3 ]
--   , [ 4, 5, 6 ]
--   ]
-- >>> applyStencil (Padding (Sz2 1 2) (Sz2 3 4) (Fill 0)) idStencil a
-- Array DW Seq (Sz (6 :. 9))
--   [ [ 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
--   , [ 0, 0, 1, 2, 3, 0, 0, 0, 0 ]
--   , [ 0, 0, 4, 5, 6, 0, 0, 0, 0 ]
--   , [ 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
--   , [ 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
--   , [ 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
--   ]
--
-- It is also a nice technique to see the border resolution strategy in action:
--
-- >>> applyStencil (Padding (Sz2 2 3) (Sz2 2 3) Wrap) idStencil a
-- Array DW Seq (Sz (6 :. 9))
--   [ [ 1, 2, 3, 1, 2, 3, 1, 2, 3 ]
--   , [ 4, 5, 6, 4, 5, 6, 4, 5, 6 ]
--   , [ 1, 2, 3, 1, 2, 3, 1, 2, 3 ]
--   , [ 4, 5, 6, 4, 5, 6, 4, 5, 6 ]
--   , [ 1, 2, 3, 1, 2, 3, 1, 2, 3 ]
--   , [ 4, 5, 6, 4, 5, 6, 4, 5, 6 ]
--   ]
-- >>> applyStencil (Padding (Sz2 2 3) (Sz2 2 3) Edge) idStencil a
-- Array DW Seq (Sz (6 :. 9))
--   [ [ 1, 1, 1, 1, 2, 3, 3, 3, 3 ]
--   , [ 1, 1, 1, 1, 2, 3, 3, 3, 3 ]
--   , [ 1, 1, 1, 1, 2, 3, 3, 3, 3 ]
--   , [ 4, 4, 4, 4, 5, 6, 6, 6, 6 ]
--   , [ 4, 4, 4, 4, 5, 6, 6, 6, 6 ]
--   , [ 4, 4, 4, 4, 5, 6, 6, 6, 6 ]
--   ]
-- >>> applyStencil (Padding (Sz2 2 3) (Sz2 2 3) Reflect) idStencil a
-- Array DW Seq (Sz (6 :. 9))
--   [ [ 6, 5, 4, 4, 5, 6, 6, 5, 4 ]
--   , [ 3, 2, 1, 1, 2, 3, 3, 2, 1 ]
--   , [ 3, 2, 1, 1, 2, 3, 3, 2, 1 ]
--   , [ 6, 5, 4, 4, 5, 6, 6, 5, 4 ]
--   , [ 6, 5, 4, 4, 5, 6, 6, 5, 4 ]
--   , [ 3, 2, 1, 1, 2, 3, 3, 2, 1 ]
--   ]
-- >>> applyStencil (Padding (Sz2 2 3) (Sz2 2 3) Continue) idStencil a
-- Array DW Seq (Sz (6 :. 9))
--   [ [ 1, 3, 2, 1, 2, 3, 2, 1, 3 ]
--   , [ 4, 6, 5, 4, 5, 6, 5, 4, 6 ]
--   , [ 1, 3, 2, 1, 2, 3, 2, 1, 3 ]
--   , [ 4, 6, 5, 4, 5, 6, 5, 4, 6 ]
--   , [ 1, 3, 2, 1, 2, 3, 2, 1, 3 ]
--   , [ 4, 6, 5, 4, 5, 6, 5, 4, 6 ]
--   ]
--
-- @since 0.4.3
data Padding ix e = Padding
  { paddingFromOrigin  :: !(Sz ix)
  -- ^ Number of cells added in front of the source array along each dimension
  -- (top rows and left columns for a 2D array)
  , paddingFromBottom  :: !(Sz ix)
  -- ^ Number of cells added after the end of the source array along each dimension
  -- (bottom rows and right columns for a 2D array)
  , paddingWithElement :: !(Border e)
  -- ^ Border resolution strategy that determines the values of the padded cells,
  -- e.g. @'Fill' 0@, 'Wrap', 'Edge', 'Reflect' or 'Continue'
  } deriving (Show, Eq)

-- | Also known as \"valid\" padding: nothing is added on either side of the source
-- array. When a stencil is applied with it, the array shrinks by the stencil size minus
-- one along each dimension, so it keeps its size only for a stencil of size 1. The
-- border resolution strategy is set to 'Edge'.
--
-- @since 0.4.3
noPadding :: Index ix => Padding ix e
noPadding =
  Padding
    { paddingFromOrigin = zeroSz
    , paddingFromBottom = zeroSz
    , paddingWithElement = Edge
    }

-- | Padding that matches the size of the stencil, which is known as \"same\" padding.
-- The stencil's center index becomes the padding in front of the array, and the
-- remaining @size - center - 1@ cells become the padding after it. As a result,
-- applying the stencil with this padding produces an array of the same size as the
-- source array. This is exactly the behavior of `mapStencil`.
--
-- @since 0.4.3
samePadding :: Index ix => Stencil ix e a -> Border e -> Padding ix e
samePadding (Stencil (Sz sSz) sCenter _) border = Padding (Sz sCenter) (Sz afterCenter) border
  where
    -- Cells of the stencil box that lie past the center along each dimension
    afterCenter = liftIndex2 (\s c -> s - c - 1) sSz sCenter


-- | Apply a constructed stencil over an array. The resulting array must be
-- `Data.Massiv.Array.compute`d in order to be useful. Unlike `mapStencil`, the size of
-- the resulting array will not necessarily be the same as the source array. Instead,
-- it is determined by the padding: along every dimension it is
-- @paddingFromOrigin + paddingFromBottom + sourceSize - (stencilSize - 1)@.
--
-- Result cells whose stencil lies entirely inside the source array form the window of
-- the result and read the source through 'unsafeIndex' / 'index''. All other cells
-- read through 'borderIndex' with the padding's 'Border'.
--
-- @since 0.4.3
applyStencil ::
     (Index ix, Manifest r e)
  => Padding ix e
  -- ^ Padding to be applied to the source array. This will dictate the resulting size of
  -- the array. No padding will cause it to shrink by the stencil size minus one along
  -- each dimension
  -> Stencil ix e a -- ^ Stencil to apply to the array
  -> Array r ix e -- ^ Source array
  -> Array DW ix a
applyStencil (Padding (Sz padBefore) (Sz padAfter) border) (Stencil sSz sCenter stencilF) !arr =
  insertWindow borderedArr interior
  where
    -- Offset that turns a result index into the source index of the stencil's center
    !toCenter = liftIndex2 (-) sCenter padBefore
    srcSz = size arr
    -- Every dimension loses (stencil size - 1) cells, not accounting for padding
    !shrinkSz = Sz (liftIndex (subtract 1) (unSz sSz))
    paddedSz = SafeSz (liftIndex2 (+) padBefore (liftIndex2 (+) padAfter (unSz srcSz)))
    !resultSz = liftSz2 (-) paddedSz shrinkSz
    -- Number of result cells whose stencil fits entirely inside the source array
    !interiorSz = liftSz2 (-) srcSz shrinkSz
    -- Covers every cell of the result; reads are resolved with the border strategy
    !borderedArr =
      DArray (getComp arr) resultSz $
        stencilF (borderIndex border arr) (borderIndex border arr) . liftIndex2 (+) toCenter
    -- Region of the result that never reaches outside of the source array
    !interior =
      Window
        { windowStart = padBefore
        , windowSize = interiorSz
        , windowIndex = stencilF (unsafeIndex arr) (index' arr) . liftIndex2 (+) toCenter
        , windowUnrollIx2 = unSz . fst <$> pullOutSzM sSz 2
        }
{-# INLINE applyStencil #-}


-- | Construct a stencil from a function that describes how to calculate the value at a
-- point. The function accesses neighboring elements through a getter that accepts
-- indices relative to the center of the stencil.
--
-- The relative getter is built on the second (checked) getter that the applying
-- function supplies. In 'applyStencil' that is 'index'' inside the stencil window and
-- 'borderIndex' outside of it. Offsets are not validated against the stencil box when
-- the stencil is created, so staying within the declared size is the responsibility of
-- the stencil function.
--
-- /Note/ - Once correctness of stencil is verified then switching to `makeUnsafeStencil`
-- is recommended in order to get the most performance out of the `Stencil`
--
-- ==== __Example__
--
-- Below is an example of creating a `Stencil`, which, when mapped over a
-- 2-dimensional array, will compute an average of all elements in a 3x3 square
-- for each element in that array.
--
-- /Note/ - Make sure to add an @INLINE@ pragma, otherwise performance will be terrible.
--
-- > average3x3Stencil :: Fractional a => Stencil Ix2 a a
-- > average3x3Stencil = makeStencil (Sz (3 :. 3)) (1 :. 1) $ \ get ->
-- >   (  get (-1 :. -1) + get (-1 :. 0) + get (-1 :. 1) +
-- >      get ( 0 :. -1) + get ( 0 :. 0) + get ( 0 :. 1) +
-- >      get ( 1 :. -1) + get ( 1 :. 0) + get ( 1 :. 1)   ) / 9
-- > {-# INLINE average3x3Stencil #-}
--
-- @since 0.1.0
makeStencil
  :: Index ix
  => Sz ix -- ^ Size of the stencil
  -> ix -- ^ Center of the stencil
  -> ((ix -> e) -> a)
  -- ^ Stencil function that receives a @get@ function as its argument. That function
  -- retrieves values of cells in the source array relative to the center of the
  -- stencil. The stencil function must return the value that will be assigned to the
  -- cell in the result array. Offsets supplied to @get@ are expected to stay within
  -- the boundaries of the stencil; this is not checked upon stencil creation.
  -> Stencil ix e a
makeStencil !sSz !sCenter relStencil = Stencil sSz sCenter applyAt
  where
    -- The first (unchecked) getter is deliberately ignored. All reads go through the
    -- checked one, after translating center-relative offsets into absolute indices.
    applyAt _uncheckedGet checkedGet !center =
      let getNeighbor !offset = checkedGet (liftIndex2 (+) center offset)
       in inline relStencil getNeighbor
    {-# INLINE applyAt #-}
{-# INLINE makeStencil #-}

-- | Identity stencil that does not change the elements of the source array. It is a
-- single-cell stencil centered at zero that returns the element under its center.
--
-- @since 0.4.3
idStencil :: Index ix => Stencil ix e e
idStencil = makeUnsafeStencil oneSz zeroIndex readCenter
  where
    -- Ignore the absolute index and read the element at offset zero
    readCenter _ get = get zeroIndex
    {-# INLINE readCenter #-}
{-# INLINE idStencil #-}


-- | Stencil that does a left fold in a row-major order. Regardless of the supplied size
-- resulting stencil will be centered at zero, although by using `Padding` it is possible
-- to overcome this limitation.
--
-- ==== __Examples__
--
-- >>> import Data.Massiv.Array as A
-- >>> a = computeAs P $ iterateN (Sz2 3 4) (+1) (10 :: Int)
-- >>> a
-- Array P Seq (Sz (3 :. 4))
--   [ [ 11, 12, 13, 14 ]
--   , [ 15, 16, 17, 18 ]
--   , [ 19, 20, 21, 22 ]
--   ]
-- >>> applyStencil noPadding (foldlStencil (flip (:)) [] (Sz2 3 2)) a
-- Array DW Seq (Sz (1 :. 3))
--   [ [ [20,19,16,15,12,11], [21,20,17,16,13,12], [22,21,18,17,14,13] ]
--   ]
-- >>> applyStencil (Padding (Sz2 1 0) 0 (Fill 10)) (foldlStencil (flip (:)) [] (Sz2 3 2)) a
-- Array DW Seq (Sz (2 :. 3))
--   [ [ [16,15,12,11,10,10], [17,16,13,12,10,10], [18,17,14,13,10,10] ]
--   , [ [20,19,16,15,12,11], [21,20,17,16,13,12], [22,21,18,17,14,13] ]
--   ]
--
-- @since 0.4.3
foldlStencil :: Index ix => (a -> e -> a) -> a -> Sz ix -> Stencil ix e a
foldlStencil f acc0 sz = makeUnsafeStencil sz zeroIndex foldBox
  where
    -- Walk the whole box in row-major order from offset zero up to (but excluding)
    -- the size, feeding each element into the accumulator from the left.
    foldBox _ get = iter zeroIndex (unSz sz) oneIndex (<) acc0 (\ix acc -> f acc (get ix))
    {-# INLINE foldBox #-}
{-# INLINE foldlStencil #-}

-- | Stencil that does a right fold in a row-major order. Regardless of the supplied size
-- resulting stencil will be centered at zero, although by using `Padding` it is possible
-- to overcome this limitation.
--
-- ==== __Examples__
--
-- >>> import Data.Massiv.Array as A
-- >>> a = computeAs P $ iterateN (Sz2 3 4) (+1) (10 :: Int)
-- >>> a
-- Array P Seq (Sz (3 :. 4))
--   [ [ 11, 12, 13, 14 ]
--   , [ 15, 16, 17, 18 ]
--   , [ 19, 20, 21, 22 ]
--   ]
-- >>> applyStencil noPadding (foldrStencil (:) [] (Sz2 2 3)) a
-- Array DW Seq (Sz (2 :. 2))
--   [ [ [11,12,13,15,16,17], [12,13,14,16,17,18] ]
--   , [ [15,16,17,19,20,21], [16,17,18,20,21,22] ]
--   ]
--
-- @since 0.4.3
foldrStencil :: Index ix => (e -> a -> a) -> a -> Sz ix -> Stencil ix e a
foldrStencil f acc0 sz = makeUnsafeStencil sz zeroIndex foldBox
  where
    -- Last offset inside the box; iteration runs backwards from it down to zero
    lastIx = liftIndex2 (-) (unSz sz) oneIndex
    -- Visiting elements in reverse row-major order while prepending with `f` yields a
    -- right fold over the row-major sequence.
    foldBox _ get = iter lastIx zeroIndex (pureIndex (-1)) (>=) acc0 (f . get)
    {-# INLINE foldBox #-}
{-# INLINE foldrStencil #-}


-- | Create a stencil of the supplied size, centered at zero, that combines all elements
-- in its region with `mappend`. The combination starts from `mempty` and proceeds in
-- row-major order.
--
-- @since 0.4.3
foldStencil :: (Monoid e, Index ix) => Sz ix -> Stencil ix e e
foldStencil sz = foldlStencil mappend mempty sz
{-# INLINE foldStencil #-}

-- | Create a stencil centered at 0 that will extract the maximum value in the region of
-- supplied size.
--
-- ==== __Example__
--
-- Here is a sample implementation of max pooling.
--
-- >>> import Data.Massiv.Array as A
-- >>> a <- computeAs P <$> resizeM (Sz2 9 9) (Ix1 10 ..: 91)
-- >>> a
-- Array P Seq (Sz (9 :. 9))
--   [ [ 10, 11, 12, 13, 14, 15, 16, 17, 18 ]
--   , [ 19, 20, 21, 22, 23, 24, 25, 26, 27 ]
--   , [ 28, 29, 30, 31, 32, 33, 34, 35, 36 ]
--   , [ 37, 38, 39, 40, 41, 42, 43, 44, 45 ]
--   , [ 46, 47, 48, 49, 50, 51, 52, 53, 54 ]
--   , [ 55, 56, 57, 58, 59, 60, 61, 62, 63 ]
--   , [ 64, 65, 66, 67, 68, 69, 70, 71, 72 ]
--   , [ 73, 74, 75, 76, 77, 78, 79, 80, 81 ]
--   , [ 82, 83, 84, 85, 86, 87, 88, 89, 90 ]
--   ]
-- >>> computeWithStrideAs P (Stride 3) $ mapStencil Edge (maxStencil (Sz 3)) a
-- Array P Seq (Sz (3 :. 3))
--   [ [ 30, 33, 36 ]
--   , [ 57, 60, 63 ]
--   , [ 84, 87, 90 ]
--   ]
--
-- @since 0.4.3
maxStencil :: (Bounded e, Ord e, Index ix) => Sz ix -> Stencil ix e e
maxStencil sz =
  -- Wrap every element in 'Max', fold the region monoidally, then unwrap the result
  dimapStencil coerce getMax (foldStencil sz)
{-# INLINE maxStencil #-}


-- | Create a stencil centered at 0 that will extract the minimum value in the region of
-- supplied size.
--
-- ==== __Example__
--
-- Here is a sample implementation of min pooling.
--
-- >>> import Data.Massiv.Array as A
-- >>> a <- computeAs P <$> resizeM (Sz2 9 9) (Ix1 10 ..: 91)
-- >>> a
-- Array P Seq (Sz (9 :. 9))
--   [ [ 10, 11, 12, 13, 14, 15, 16, 17, 18 ]
--   , [ 19, 20, 21, 22, 23, 24, 25, 26, 27 ]
--   , [ 28, 29, 30, 31, 32, 33, 34, 35, 36 ]
--   , [ 37, 38, 39, 40, 41, 42, 43, 44, 45 ]
--   , [ 46, 47, 48, 49, 50, 51, 52, 53, 54 ]
--   , [ 55, 56, 57, 58, 59, 60, 61, 62, 63 ]
--   , [ 64, 65, 66, 67, 68, 69, 70, 71, 72 ]
--   , [ 73, 74, 75, 76, 77, 78, 79, 80, 81 ]
--   , [ 82, 83, 84, 85, 86, 87, 88, 89, 90 ]
--   ]
-- >>> computeWithStrideAs P (Stride 3) $ mapStencil Edge (minStencil (Sz 3)) a
-- Array P Seq (Sz (3 :. 3))
--   [ [ 10, 13, 16 ]
--   , [ 37, 40, 43 ]
--   , [ 64, 67, 70 ]
--   ]
--
-- @since 0.4.3
minStencil :: (Bounded e, Ord e, Index ix) => Sz ix -> Stencil ix e e
minStencil sz =
  -- Wrap every element in 'Min', fold the region monoidally, then unwrap the result
  dimapStencil coerce getMin (foldStencil sz)
{-# INLINE minStencil #-}

-- | Sum all elements in the stencil region
--
-- ==== __Examples__
--
-- >>> import Data.Massiv.Array as A
-- >>> a = computeAs P $ iterateN (Sz2 2 5) (* 2) (1 :: Int)
-- >>> a
-- Array P Seq (Sz (2 :. 5))
--   [ [ 2, 4, 8, 16, 32 ]
--   , [ 64, 128, 256, 512, 1024 ]
--   ]
-- >>> applyStencil noPadding (sumStencil (Sz2 1 2)) a
-- Array DW Seq (Sz (2 :. 4))
--   [ [ 6, 12, 24, 48 ]
--   , [ 192, 384, 768, 1536 ]
--   ]
-- >>> [2 + 4, 4 + 8, 8 + 16, 16 + 32] :: [Int]
-- [6,12,24,48]
--
-- @since 0.4.3
sumStencil :: (Num e, Index ix) => Sz ix -> Stencil ix e e
sumStencil sz =
  -- Wrap every element in 'Sum', fold the region monoidally, then unwrap the total
  dimapStencil coerce getSum (foldStencil sz)
{-# INLINE sumStencil #-}

-- | Multiply all elements in the stencil region
--
-- ==== __Examples__
--
-- >>> import Data.Massiv.Array as A
-- >>> a = computeAs P $ iterateN (Sz2 2 2) (+1) (0 :: Int)
-- >>> a
-- Array P Seq (Sz (2 :. 2))
--   [ [ 1, 2 ]
--   , [ 3, 4 ]
--   ]
-- >>> applyStencil (Padding 0 2 (Fill 0)) (productStencil 2) a
-- Array DW Seq (Sz (3 :. 3))
--   [ [ 24, 0, 0 ]
--   , [ 0, 0, 0 ]
--   , [ 0, 0, 0 ]
--   ]
-- >>> applyStencil (Padding 0 2 Reflect) (productStencil 2) a
-- Array DW Seq (Sz (3 :. 3))
--   [ [ 24, 64, 24 ]
--   , [ 144, 256, 144 ]
--   , [ 24, 64, 24 ]
--   ]
--
-- @since 0.4.3
productStencil :: (Num e, Index ix) => Sz ix -> Stencil ix e e
productStencil sz =
  -- Wrap every element in 'Product', fold the region monoidally, then unwrap the result
  dimapStencil coerce getProduct (foldStencil sz)
{-# INLINE productStencil #-}

-- | Find the average value of all elements in the stencil region
--
-- ==== __Example__
--
-- >>> import Data.Massiv.Array as A
-- >>> a = computeAs P $ iterateN (Sz2 3 4) (+1) (10 :: Double)
-- >>> a
-- Array P Seq (Sz (3 :. 4))
--   [ [ 11.0, 12.0, 13.0, 14.0 ]
--   , [ 15.0, 16.0, 17.0, 18.0 ]
--   , [ 19.0, 20.0, 21.0, 22.0 ]
--   ]
-- >>> applyStencil noPadding (avgStencil (Sz2 2 3)) a
-- Array DW Seq (Sz (2 :. 2))
--   [ [ 14.0, 15.0 ]
--   , [ 18.0, 19.0 ]
--   ]
-- >>> Prelude.sum [11.0, 12.0, 13.0, 15.0, 16.0, 17.0] / 6 :: Double
-- 14.0
--
-- @since 0.4.3
avgStencil :: (Fractional e, Index ix) => Sz ix -> Stencil ix e e
avgStencil sz = sumStencil sz / cellCount
  where
    -- Number of cells in the region, converted to a 'Stencil' through its 'Num' instance
    cellCount = fromIntegral (totalElem sz)
{-# INLINE avgStencil #-}