{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
-- |
-- Module      : Data.Massiv.Array.Stencil.Internal
-- Copyright   : (c) Alexey Kuleshevich 2018-2022
-- License     : BSD3
-- Maintainer  : Alexey Kuleshevich <lehins@yandex.ru>
-- Stability   : experimental
-- Portability : non-portable
--
module Data.Massiv.Array.Stencil.Internal
  ( Stencil(..)
  , dimapStencil
  , lmapStencil
  , rmapStencil
  ) where

import Control.Applicative
import Control.DeepSeq
import Data.Massiv.Core.Common
#if MIN_VERSION_base(4,9,0)
import Numeric (expm1, log1mexp, log1p, log1pexp)
#endif

-- | Stencil is an abstract description of how to handle the elements in the neighborhood
-- of every array cell in order to compute a value for the corresponding cell of the new
-- array. Use `Data.Massiv.Array.makeStencil` and
-- `Data.Massiv.Array.makeConvolutionStencil` in order to create a stencil.
data Stencil ix e a = Stencil
  { stencilSize :: !(Sz ix)
    -- ^ Extent of the neighborhood the stencil covers.
  , stencilCenter :: !ix
    -- ^ Position of the center within that neighborhood, i.e. how many cells per
    -- dimension lie before the cell being computed.
  , stencilFunc :: (ix -> e) -> (ix -> e) -> ix -> a
    -- ^ Produces the value for an index given two element accessors. NOTE(review): the
    -- two accessors are presumably an unchecked and a checked lookup; confirm against
    -- the stencil-application code.
  }


-- | Size and center are evaluated to normal form. The stencil function can only be
-- brought to weak head normal form, since functions have no deeper structure to force.
instance Index ix => NFData (Stencil ix e a) where
  rnf stencil =
    rnf (stencilSize stencil) `seq`
      rnf (stencilCenter stencil) `seq`
        stencilFunc stencil `seq` ()

-- | Maps over the value a stencil produces; size and center are untouched. This is
-- exactly 'rmapStencil'.
instance Functor (Stencil ix e) where
  fmap f = rmapStencil f
  {-# INLINE fmap #-}


-- Profunctor

-- | A Profunctor @dimap@. It pre-composes @f@ onto both element accessors the stencil
-- function receives and post-composes @g@ onto the value it produces. Size and center
-- are left unchanged. The same caveat applies as in `lmapStencil`.
--
-- @since 0.2.3
dimapStencil :: (c -> d) -> (a -> b) -> Stencil ix d a -> Stencil ix c b
dimapStencil f g stencil = stencil {stencilFunc = adapted}
  where
    origFunc = stencilFunc stencil
    -- Each accessor is wrapped with @f@ before the original function sees it.
    adapted getU getS = g . origFunc (f . getU) (f . getS)
    {-# INLINE adapted #-}
{-# INLINE dimapStencil #-}

-- | A contravariant map over the second type parameter. In other words, it maps a
-- function over each element of the array that the stencil will be applied to.
--
-- __Note__: This map can be very inefficient. For stencils larger than 1 element in
-- size, the supplied function will be applied repeatedly to the same element. It is
-- better to map that function over the source array instead.
--
-- @since 0.2.3
lmapStencil :: (c -> d) -> Stencil ix d a -> Stencil ix c a
lmapStencil f stencil = stencil {stencilFunc = adapted}
  where
    origFunc = stencilFunc stencil
    -- Both accessors are wrapped, so @f@ runs on every element read.
    adapted getU getS = origFunc (f . getU) (f . getS)
    {-# INLINE adapted #-}
{-# INLINE lmapStencil #-}

-- | A covariant map over the rightmost type argument. In other words, the usual `fmap`
-- from `Functor`:
--
-- > fmap == rmapStencil
--
-- @since 0.2.3
rmapStencil :: (a -> b) -> Stencil ix e a -> Stencil ix e b
rmapStencil f stencil = stencil {stencilFunc = mapped}
  where
    origFunc = stencilFunc stencil
    -- Accessors are passed through untouched; only the produced value is transformed.
    mapped getU getS = f . origFunc getU getS
    {-# INLINE mapped #-}
{-# INLINE rmapStencil #-}

-- | Component-wise maximum of the centers of two stencils. It is used as the center of a
-- stencil that combines both of them.
unionStencilCenters :: Index ix => Stencil ix e1 a1 -> Stencil ix e2 a2 -> ix
unionStencilCenters Stencil {stencilCenter = c1} Stencil {stencilCenter = c2} =
  liftIndex2 max c1 c2
{-# INLINE unionStencilCenters #-}

-- | Size of a stencil that covers the footprints of both stencils once they are aligned
-- on the shared center @maxCenter@ (normally the result of 'unionStencilCenters').
-- Per dimension, the result is @maxCenter@ plus the larger of the two @size - center@
-- extents.
unionStencilSizes :: Index ix => ix -> Stencil ix e1 a1 -> Stencil ix e2 a2 -> Sz ix
unionStencilSizes maxCenter
                  Stencil {stencilSize = SafeSz sz1, stencilCenter = c1}
                  Stencil {stencilSize = SafeSz sz2, stencilCenter = c2} =
  Sz (liftIndex2 (+) maxCenter widestAfter)
  where
    -- Number of cells from each stencil's center to its far edge, center included.
    after1 = liftIndex2 (-) sz1 c1
    after2 = liftIndex2 (-) sz2 c2
    widestAfter = liftIndex2 max after1 after2
{-# INLINE unionStencilSizes #-}

-- | Merge two stencils into one whose footprint covers both of them. The result is
-- centered on 'unionStencilCenters' and sized by 'unionStencilSizes'. At every index
-- both stencil functions run with the same accessors, and their results are combined
-- with the supplied function.
liftStencilA2 ::
     Index ix => (a -> b -> c) -> Stencil ix e a -> Stencil ix e b -> Stencil ix e c
liftStencilA2 combine stA@(Stencil _ _ funA) stB@(Stencil _ _ funB) =
  Stencil joinedSz joinedCenter merged
  where
    !joinedCenter = unionStencilCenters stA stB
    !joinedSz = unionStencilSizes joinedCenter stA stB
    merged getU getS !ix = combine (funA getU getS ix) (funB getU getS ix)
    {-# INLINE merged #-}
{-# INLINE liftStencilA2 #-}

-- TODO: Test interchange law (u <*> pure y = pure ($ y) <*> u)
instance Index ix => Applicative (Stencil ix e) where
  -- A constant stencil: a single cell, centered at the origin, that ignores its input.
  pure x = Stencil oneSz zeroIndex (\_ _ _ -> x)
  {-# INLINE pure #-}
  -- Applying @id@ to the first result feeds it the second one, i.e. @f1 .. (f2 ..)@.
  (<*>) = liftStencilA2 id
  {-# INLINE (<*>) #-}
#if MIN_VERSION_base(4,10,0)
  liftA2 = liftStencilA2
  {-# INLINE liftA2 #-}
#endif

-- | Pointwise arithmetic on the values stencils produce. Binary operators combine two
-- stencils with 'liftA2', unary functions go through 'fmap', and literals become
-- constant stencils via 'pure'.
instance (Index ix, Num a) => Num (Stencil ix e a) where
  s1 + s2 = liftA2 (+) s1 s2
  {-# INLINE (+) #-}
  s1 - s2 = liftA2 (-) s1 s2
  {-# INLINE (-) #-}
  s1 * s2 = liftA2 (*) s1 s2
  {-# INLINE (*) #-}
  negate s = negate <$> s
  {-# INLINE negate #-}
  abs s = abs <$> s
  {-# INLINE abs #-}
  signum s = signum <$> s
  {-# INLINE signum #-}
  fromInteger n = pure (fromInteger n)
  {-# INLINE fromInteger #-}

-- | Pointwise fractional arithmetic on the values stencils produce. Division combines two
-- stencils with 'liftA2', 'recip' goes through 'fmap', and rational literals become
-- constant stencils via 'pure'.
instance (Index ix, Fractional a) => Fractional (Stencil ix e a) where
  s1 / s2 = liftA2 (/) s1 s2
  {-# INLINE (/) #-}
  recip s = recip <$> s
  {-# INLINE recip #-}
  fromRational r = pure (fromRational r)
  {-# INLINE fromRational #-}

-- | Every 'Floating' operation is lifted pointwise over the value a stencil produces.
-- Constants become constant stencils via 'pure', unary functions go through 'fmap', and
-- binary functions combine two stencils with 'liftA2'.
instance (Index ix, Floating a) => Floating (Stencil ix e a) where
  pi = pure pi
  {-# INLINE pi #-}
  exp = fmap exp
  {-# INLINE exp #-}
  log = fmap log
  {-# INLINE log #-}
  sqrt = fmap sqrt
  {-# INLINE sqrt #-}
  (**) = liftA2 (**)
  {-# INLINE (**) #-}
  logBase = liftA2 logBase
  {-# INLINE logBase #-}
  sin = fmap sin
  {-# INLINE sin #-}
  cos = fmap cos
  {-# INLINE cos #-}
  tan = fmap tan
  {-# INLINE tan #-}
  asin = fmap asin
  {-# INLINE asin #-}
  acos = fmap acos
  {-# INLINE acos #-}
  atan = fmap atan
  {-# INLINE atan #-}
  sinh = fmap sinh
  {-# INLINE sinh #-}
  cosh = fmap cosh
  {-# INLINE cosh #-}
  tanh = fmap tanh
  {-# INLINE tanh #-}
  asinh = fmap asinh
  {-# INLINE asinh #-}
  acosh = fmap acosh
  {-# INLINE acosh #-}
  atanh = fmap atanh
  {-# INLINE atanh #-}
#if MIN_VERSION_base(4,9,0)
  -- Without these definitions the class defaults are used, e.g. @log1p x = log (1 + x)@
  -- and @expm1 x = exp x - 1@. Those build extra stencil combinators and, more
  -- importantly, bypass the element type's precision-preserving implementations (adding
  -- 1 before taking the log loses accuracy for small values). Delegating pointwise to
  -- the element type keeps that accuracy.
  log1p = fmap log1p
  {-# INLINE log1p #-}
  expm1 = fmap expm1
  {-# INLINE expm1 #-}
  log1pexp = fmap log1pexp
  {-# INLINE log1pexp #-}
  log1mexp = fmap log1mexp
  {-# INLINE log1mexp #-}
#endif