{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- |
-- Module      : Data.Massiv.Array.Stencil.Internal
-- Copyright   : (c) Alexey Kuleshevich 2018-2022
-- License     : BSD3
-- Maintainer  : Alexey Kuleshevich <lehins@yandex.ru>
-- Stability   : experimental
-- Portability : non-portable
module Data.Massiv.Array.Stencil.Internal (
  Stencil (..),
  dimapStencil,
  lmapStencil,
  rmapStencil,
) where

import Control.Applicative
import Control.DeepSeq
import Data.Massiv.Core.Common
import Numeric (expm1, log1mexp, log1p, log1pexp)

-- | A stencil is an abstract description of how to combine the elements in the
-- neighborhood of every array cell in order to compute the value of the
-- corresponding cell in a new array. Use `Data.Massiv.Array.makeStencil` or
-- `Data.Massiv.Array.makeConvolutionStencil` to create one.
data Stencil ix e a = Stencil
  { stencilSize :: !(Sz ix)
  -- ^ Size of the neighborhood window covered by the stencil.
  , stencilCenter :: !ix
  -- ^ Index of the center cell within the stencil window.
  , stencilFunc :: (ix -> e) -> (ix -> e) -> ix -> a
  -- ^ Produces the result for the cell at the given index. Both function
  -- arguments look up source elements by index. NOTE(review): the two lookups
  -- presumably differ in how they treat out-of-bounds indices; confirm against
  -- the stencil constructors.
  }

-- | Fully evaluates the size and the center. The stencil function itself can
-- only be forced to weak head normal form.
instance Index ix => NFData (Stencil ix e a) where
  rnf (Stencil size center func) = rnf size `seq` rnf center `seq` func `seq` ()

-- | Maps a function over the stencil's result; identical to 'rmapStencil'.
instance Functor (Stencil ix e) where
  fmap = rmapStencil
  {-# INLINE fmap #-}

-- Profunctor

-- | A Profunctor @dimap@. The first function is pre-composed onto every element
-- lookup the stencil performs; the second is post-composed onto its result. The
-- same caveat as for `lmapStencil` applies.
--
-- @since 0.2.3
dimapStencil :: (c -> d) -> (a -> b) -> Stencil ix d a -> Stencil ix c b
dimapStencil f g stencil@Stencil {stencilFunc = run} =
  stencil {stencilFunc = \getU getS -> g . run (f . getU) (f . getS)}
{-# INLINE dimapStencil #-}

-- | A contravariant map over the second type parameter: the function is applied
-- to each element of the array the stencil is applied to.
--
-- __Note__: This map can be very inefficient. For stencils larger than one
-- element, the supplied function is applied to the same element repeatedly. It
-- is usually better to map that function over the source array instead.
--
-- @since 0.2.3
lmapStencil :: (c -> d) -> Stencil ix d a -> Stencil ix c a
lmapStencil f stencil@Stencil {stencilFunc = run} =
  stencil {stencilFunc = \getU getS -> run (f . getU) (f . getS)}
{-# INLINE lmapStencil #-}

-- | A covariant map over the rightmost type parameter, i.e. over the stencil's
-- result. This is the usual `fmap` from `Functor`:
--
-- > fmap == rmapStencil
--
-- @since 0.2.3
rmapStencil :: (a -> b) -> Stencil ix e a -> Stencil ix e b
rmapStencil f stencil@Stencil {stencilFunc = run} =
  stencil {stencilFunc = \getU getS -> f . run getU getS}
{-# INLINE rmapStencil #-}

-- | Component-wise maximum of the centers of two stencils.
unionStencilCenters :: Index ix => Stencil ix e1 a1 -> Stencil ix e2 a2 -> ix
unionStencilCenters Stencil {stencilCenter = c1} Stencil {stencilCenter = c2} =
  liftIndex2 max c1 c2
{-# INLINE unionStencilCenters #-}

-- | Size of a combined stencil: @maxCenter@ plus the component-wise larger of
-- the two stencils' extents past their own centers (@size - center@).
unionStencilSizes :: Index ix => ix -> Stencil ix e1 a1 -> Stencil ix e2 a2 -> Sz ix
unionStencilSizes
  maxCenter
  Stencil {stencilSize = SafeSz sz1, stencilCenter = c1}
  Stencil {stencilSize = SafeSz sz2, stencilCenter = c2} =
    Sz (liftIndex2 (+) maxCenter widestExtent)
    where
      widestExtent = liftIndex2 max (liftIndex2 (-) sz1 c1) (liftIndex2 (-) sz2 c2)
{-# INLINE unionStencilSizes #-}

-- TODO: Test interchange law (u <*> pure y = pure ($ y) <*> u)
-- | Combining two stencils runs both with the same lookups and index and merges
-- their results. The combined center is 'unionStencilCenters' of the two, and the
-- combined size is 'unionStencilSizes' computed from that center.
instance Index ix => Applicative (Stencil ix e) where
  -- A single-cell stencil (size 'oneSz', center 'zeroIndex') that ignores its inputs.
  pure x = Stencil oneSz zeroIndex (\_ _ _ -> x)
  {-# INLINE pure #-}

  sF@Stencil {stencilFunc = runF} <*> sX@Stencil {stencilFunc = runX} =
    Stencil size center combined
    where
      combined getU getS !ix = runF getU getS ix (runX getU getS ix)
      {-# INLINE combined #-}
      !size = unionStencilSizes center sF sX
      !center = unionStencilCenters sF sX
  {-# INLINE (<*>) #-}

#if MIN_VERSION_base(4,10,0)
  liftA2 f sA@Stencil {stencilFunc = runA} sB@Stencil {stencilFunc = runB} =
    Stencil size center combined
    where
      combined getU getS !ix = f (runA getU getS ix) (runB getU getS ix)
      {-# INLINE combined #-}
      !size = unionStencilSizes center sA sB
      !center = unionStencilCenters sA sB
  {-# INLINE liftA2 #-}
#endif

-- | Point-wise arithmetic on stencil results. Binary operators combine two
-- stencils with 'liftA2', unary ones use 'fmap', and integer literals become
-- single-cell stencils via 'pure'.
instance (Index ix, Num a) => Num (Stencil ix e a) where
  (+) = liftA2 (+)
  {-# INLINE (+) #-}
  (-) = liftA2 (-)
  {-# INLINE (-) #-}
  (*) = liftA2 (*)
  {-# INLINE (*) #-}
  negate = fmap negate
  {-# INLINE negate #-}
  abs = fmap abs
  {-# INLINE abs #-}
  signum = fmap signum
  {-# INLINE signum #-}
  fromInteger n = pure (fromInteger n)
  {-# INLINE fromInteger #-}

-- | Point-wise fractional operations on stencil results. Rational literals become
-- single-cell stencils via 'pure'.
instance (Index ix, Fractional a) => Fractional (Stencil ix e a) where
  (/) = liftA2 (/)
  {-# INLINE (/) #-}
  recip = fmap recip
  {-# INLINE recip #-}
  fromRational r = pure (fromRational r)
  {-# INLINE fromRational #-}

-- | Point-wise 'Floating' operations on stencil results. Unary functions are
-- mapped with 'fmap', binary ones are lifted with 'liftA2', and 'pi' is a
-- single-cell stencil via 'pure'.
--
-- 'log1p', 'expm1', 'log1pexp' and 'log1mexp' are defined explicitly. Otherwise
-- the class defaults (e.g. @log1p x = log (1 + x)@) would be used. Those build an
-- extra 'pure' stencil and, worse, discard the precision these functions exist to
-- provide for arguments near zero.
instance (Index ix, Floating a) => Floating (Stencil ix e a) where
  pi = pure pi
  {-# INLINE pi #-}
  exp = fmap exp
  {-# INLINE exp #-}
  log = fmap log
  {-# INLINE log #-}
  sqrt = fmap sqrt
  {-# INLINE sqrt #-}
  (**) = liftA2 (**)
  {-# INLINE (**) #-}
  logBase = liftA2 logBase
  {-# INLINE logBase #-}
  sin = fmap sin
  {-# INLINE sin #-}
  cos = fmap cos
  {-# INLINE cos #-}
  tan = fmap tan
  {-# INLINE tan #-}
  asin = fmap asin
  {-# INLINE asin #-}
  acos = fmap acos
  {-# INLINE acos #-}
  atan = fmap atan
  {-# INLINE atan #-}
  sinh = fmap sinh
  {-# INLINE sinh #-}
  cosh = fmap cosh
  {-# INLINE cosh #-}
  tanh = fmap tanh
  {-# INLINE tanh #-}
  asinh = fmap asinh
  {-# INLINE asinh #-}
  acosh = fmap acosh
  {-# INLINE acosh #-}
  atanh = fmap atanh
  {-# INLINE atanh #-}
  -- Map the element type's precise implementations rather than the class defaults.
  log1p = fmap log1p
  {-# INLINE log1p #-}
  expm1 = fmap expm1
  {-# INLINE expm1 #-}
  log1pexp = fmap log1pexp
  {-# INLINE log1pexp #-}
  log1mexp = fmap log1mexp
  {-# INLINE log1mexp #-}