{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Massiv.Array.Stencil.Internal
  ( Stencil(..)
  , Value(..)
  , dimapStencil
  , lmapStencil
  , rmapStencil
  , validateStencil
  ) where
import Control.Applicative
import Control.DeepSeq
import Numeric (expm1, log1mexp, log1p, log1pexp)

import Data.Massiv.Array.Delayed.Pull
import Data.Massiv.Core.Common
import Data.Massiv.Core.Index.Internal
-- | A stencil: a neighborhood window of a fixed size, the position of its
-- center inside that window, and a function that computes one result from
-- the elements it reads.
data Stencil ix e a = Stencil
  { stencilSize   :: !(Sz ix)
  -- ^ Size of the window the stencil is expected to read from
  -- ('validateStencil' rejects reads outside of it).
  , stencilCenter :: !ix
  -- ^ Position of the center element within the window.
  , stencilFunc   :: (ix -> Value e) -> ix -> Value a
  -- ^ Given an element getter and a position, produce the result for that
  -- position. The getter appears to be indexed in the same coordinates as
  -- the position (see 'validateStencil') — NOTE(review): confirm at callers.
  }
-- | Fully evaluates the size and the center. The stencil function is a
-- closure, so it can only be forced to weak head normal form.
instance Index ix => NFData (Stencil ix e a) where
  rnf (Stencil size center func) =
    rnf size `seq` rnf center `seq` func `seq` ()
-- | A thin wrapper around a single element read or produced by a stencil.
-- Its 'Functor', 'Applicative' and numeric instances in this module apply
-- the corresponding operation directly to the wrapped element.
newtype Value e = Value { unValue :: e } deriving (Show, Bounded)
-- | Applies the function to the wrapped element.
instance Functor Value where
  fmap f = Value . f . unValue
  {-# INLINE fmap #-}
-- | Identity-like applicative: 'pure' wraps an element, and '<*>' applies
-- the wrapped function to the wrapped argument.
instance Applicative Value where
  pure x = Value x
  {-# INLINE pure #-}
  Value g <*> Value x = Value (g x)
  {-# INLINE (<*>) #-}
-- | Combines the wrapped elements with their own '<>'.
instance Semigroup a => Semigroup (Value a) where
  (<>) (Value x) (Value y) = Value (x <> y)
  {-# INLINE (<>) #-}
-- | The empty 'Value' wraps the wrapped type's 'mempty'. 'mappend' defers
-- to the wrapped type's own 'mappend'.
instance Monoid a => Monoid (Value a) where
  mempty = Value mempty
  {-# INLINE mempty #-}
  mappend (Value x) (Value y) = Value (mappend x y)
  {-# INLINE mappend #-}
-- | Element-wise arithmetic on the wrapped value. Every method forwards to
-- the corresponding method of @e@.
instance Num e => Num (Value e) where
  (+) = liftA2 (+)
  {-# INLINE (+) #-}
  -- Defined explicitly rather than relying on the class default
  -- @x - y = x + negate y@. That default throws for types without a total
  -- 'negate' (e.g. @Natural@, where @negate 3@ underflows), so
  -- @Value 5 - Value 3@ would fail even though @5 - 3@ is fine. It would
  -- also bypass any specialised '(-)' that @e@ provides.
  (-) = liftA2 (-)
  {-# INLINE (-) #-}
  (*) = liftA2 (*)
  {-# INLINE (*) #-}
  negate = fmap negate
  {-# INLINE negate #-}
  abs = fmap abs
  {-# INLINE abs #-}
  signum = fmap signum
  {-# INLINE signum #-}
  fromInteger = Value . fromInteger
  {-# INLINE fromInteger #-}
-- | Division, reciprocal and rational literals on the wrapped element.
instance Fractional e => Fractional (Value e) where
  Value x / Value y = Value (x / y)
  {-# INLINE (/) #-}
  recip (Value x) = Value (recip x)
  {-# INLINE recip #-}
  fromRational r = Value (fromRational r)
  {-# INLINE fromRational #-}
-- | Floating-point functions applied to the wrapped element. Every method,
-- including the precision-sensitive 'log1p' family, forwards to @e@'s own
-- implementation.
instance Floating e => Floating (Value e) where
  pi = pure pi
  {-# INLINE pi #-}
  exp = fmap exp
  {-# INLINE exp #-}
  log = fmap log
  {-# INLINE log #-}
  sqrt = fmap sqrt
  {-# INLINE sqrt #-}
  (**) = liftA2 (**)
  {-# INLINE (**) #-}
  logBase = liftA2 logBase
  {-# INLINE logBase #-}
  sin = fmap sin
  {-# INLINE sin #-}
  cos = fmap cos
  {-# INLINE cos #-}
  tan = fmap tan
  {-# INLINE tan #-}
  asin = fmap asin
  {-# INLINE asin #-}
  acos = fmap acos
  {-# INLINE acos #-}
  atan = fmap atan
  {-# INLINE atan #-}
  sinh = fmap sinh
  {-# INLINE sinh #-}
  cosh = fmap cosh
  {-# INLINE cosh #-}
  tanh = fmap tanh
  {-# INLINE tanh #-}
  asinh = fmap asinh
  {-# INLINE asinh #-}
  acosh = fmap acosh
  {-# INLINE acosh #-}
  atanh = fmap atanh
  {-# INLINE atanh #-}
  -- The class defaults for these four (e.g. @log1p x = log (1 + x)@) lose
  -- precision for arguments near zero. Forward them so that types with
  -- accurate implementations (such as 'Double') keep that accuracy when
  -- wrapped.
  log1p = fmap log1p
  {-# INLINE log1p #-}
  expm1 = fmap expm1
  {-# INLINE expm1 #-}
  log1pexp = fmap log1pexp
  {-# INLINE log1pexp #-}
  log1mexp = fmap log1mexp
  {-# INLINE log1mexp #-}
-- | Maps over the stencil's result; equivalent to 'rmapStencil'.
instance Functor (Stencil ix e) where
  fmap f stencil = rmapStencil f stencil
  {-# INLINE fmap #-}
-- | Adapts both ends of a stencil. The first function converts every
-- element the stencil reads; the second converts the stencil's result.
-- Size and center are unchanged.
dimapStencil :: (c -> d) -> (a -> b) -> Stencil ix d a -> Stencil ix c b
dimapStencil onInput onOutput stencil@Stencil {stencilFunc = run} =
  stencil {stencilFunc = adapted}
  where
    adapted getter = fmap onOutput . run (fmap onInput . getter)
    {-# INLINE adapted #-}
{-# INLINE dimapStencil #-}
-- | Converts every element the stencil reads before the stencil sees it.
-- Size, center and result are unchanged.
lmapStencil :: (c -> d) -> Stencil ix d a -> Stencil ix c a
lmapStencil onInput stencil@Stencil {stencilFunc = run} =
  stencil {stencilFunc = adapted}
  where
    adapted getter = run (fmap onInput . getter)
    {-# INLINE adapted #-}
{-# INLINE lmapStencil #-}
-- | Converts the stencil's result. Size, center and the elements it reads
-- are unchanged.
rmapStencil :: (a -> b) -> Stencil ix e a -> Stencil ix e b
rmapStencil onOutput stencil@Stencil {stencilFunc = run} =
  stencil {stencilFunc = adapted}
  where
    adapted getter = fmap onOutput . run getter
    {-# INLINE adapted #-}
{-# INLINE rmapStencil #-}
-- | 'pure' is a stencil of size 'oneSz', centered at 'zeroIndex', that reads
-- nothing and always yields the given value.
--
-- '<*>' overlays two stencils. Both functions run at the same position with
-- the same getter, and the first result is applied to the second. The
-- combined center is the component-wise maximum of the two centers. The
-- combined size covers the larger extent of either operand on each side of
-- its center.
instance Index ix => Applicative (Stencil ix e) where
  pure x = Stencil oneSz zeroIndex (\_ _ -> Value x)
  {-# INLINE pure #-}
  (<*>) (Stencil (SafeSz szL) centerL runL) (Stencil (SafeSz szR) centerR runR) =
    Stencil combinedSz center run
    where
      run getter !ix = runL getter ix <*> runR getter ix
      {-# INLINE run #-}
      !center = liftIndex2 max centerL centerR
      -- Distance from each center (inclusive) to the far edge of its stencil.
      reachL = liftIndex2 (-) szL centerL
      reachR = liftIndex2 (-) szR centerR
      !combinedSz = Sz (liftIndex2 (+) center (liftIndex2 max reachL reachR))
  {-# INLINE (<*>) #-}
-- | Point-wise arithmetic on stencil results. Binary operators merge the two
-- stencils through '<*>', so the result covers both neighborhoods. Unary
-- operators map over the result, and integer literals become 'pure'
-- stencils.
instance (Index ix, Num a) => Num (Stencil ix e a) where
  x + y = (+) <$> x <*> y
  {-# INLINE (+) #-}
  x - y = (-) <$> x <*> y
  {-# INLINE (-) #-}
  x * y = (*) <$> x <*> y
  {-# INLINE (*) #-}
  negate x = negate <$> x
  {-# INLINE negate #-}
  abs x = abs <$> x
  {-# INLINE abs #-}
  signum x = signum <$> x
  {-# INLINE signum #-}
  fromInteger n = pure (fromInteger n)
  {-# INLINE fromInteger #-}
-- | Point-wise division on stencil results. '/' merges both stencils through
-- '<*>', 'recip' maps over the result, and rational literals become 'pure'
-- stencils.
instance (Index ix, Fractional a) => Fractional (Stencil ix e a) where
  x / y = (/) <$> x <*> y
  {-# INLINE (/) #-}
  recip x = recip <$> x
  {-# INLINE recip #-}
  fromRational r = pure (fromRational r)
  {-# INLINE fromRational #-}
-- | Floating-point functions applied point-wise to stencil results. Unary
-- functions map over the result, binary ones merge both stencils through
-- '<*>', and 'pi' is a 'pure' stencil.
instance (Index ix, Floating a) => Floating (Stencil ix e a) where
  pi = pure pi
  {-# INLINE pi #-}
  exp = fmap exp
  {-# INLINE exp #-}
  log = fmap log
  {-# INLINE log #-}
  sqrt = fmap sqrt
  {-# INLINE sqrt #-}
  (**) = liftA2 (**)
  {-# INLINE (**) #-}
  logBase = liftA2 logBase
  {-# INLINE logBase #-}
  sin = fmap sin
  {-# INLINE sin #-}
  cos = fmap cos
  {-# INLINE cos #-}
  tan = fmap tan
  {-# INLINE tan #-}
  asin = fmap asin
  {-# INLINE asin #-}
  acos = fmap acos
  {-# INLINE acos #-}
  atan = fmap atan
  {-# INLINE atan #-}
  sinh = fmap sinh
  {-# INLINE sinh #-}
  cosh = fmap cosh
  {-# INLINE cosh #-}
  tanh = fmap tanh
  {-# INLINE tanh #-}
  asinh = fmap asinh
  {-# INLINE asinh #-}
  acosh = fmap acosh
  {-# INLINE acosh #-}
  atanh = fmap atanh
  {-# INLINE atanh #-}
  -- Forward the precision-sensitive methods instead of inheriting the class
  -- defaults (e.g. @log1p x = log (1 + x)@). The defaults lose accuracy near
  -- zero even when @a@ provides an accurate implementation.
  log1p = fmap log1p
  {-# INLINE log1p #-}
  expm1 = fmap expm1
  {-# INLINE expm1 #-}
  log1pexp = fmap log1pexp
  {-# INLINE log1pexp #-}
  log1mexp = fmap log1mexp
  {-# INLINE log1mexp #-}
-- | Reads an element of a delayed array. If the index falls outside the
-- array's size, it throws 'IndexOutOfBoundsException' instead.
safeStencilIndex :: Index ix => Array D ix e -> ix -> e
safeStencilIndex DArray {dSize = sz, dIndex = indexFn} ix =
  if isSafeIndex sz ix
    then indexFn ix
    else throw (IndexOutOfBoundsException sz ix)
-- | Sanity-checks a stencil by running it once before handing it back.
--
-- The stencil function is evaluated at 'stencilCenter' against a stand-in
-- array of exactly 'stencilSize' elements, every one equal to the supplied
-- default. A read outside that window throws 'IndexOutOfBoundsException'.
-- The check runs when the returned stencil is forced. Only the weak head
-- normal form of the result is demanded, so reads still suspended inside a
-- lazy result are not exercised.
validateStencil
  :: Index ix
  => e -> Stencil ix e a -> Stencil ix e a
validateStencil defVal stencil@(Stencil size center run) =
  probe `seq` stencil
  where
    standIn = DArray Seq size (const defVal)
    probe = run (Value . safeStencilIndex standIn) center
{-# INLINE validateStencil #-}