module Data.Massiv.Array.Stencil.Internal where
import Control.Applicative
import Control.DeepSeq
import Data.Massiv.Core.Common
import Data.Massiv.Array.Delayed.Internal
import Data.Default.Class (Default (def))
-- | A stencil: a function that computes a result of type @a@ at an index by
-- reading elements of type @e@ through a getter, bundled with the size of its
-- window, the index of its center and a border policy.
data Stencil ix e a = Stencil
  { stencilBorder :: Border e
    -- ^ Border policy carried with the stencil. This module only passes it
    -- along; presumably consumed by the code that applies the stencil
    -- (confirm).
  , stencilSize, stencilCenter :: !ix
    -- ^ Size of the stencil window and the index of its center (both
    -- strict). 'validateStencil' checks that reads made from the center stay
    -- inside a window of this size.
  , stencilFunc :: (ix -> Value e) -> ix -> Value a
    -- ^ The stencil body: given an element getter and an index, produce the
    -- value at that index.
  }
-- | Fully evaluates the border, size and center. The stencil function can
-- only be forced to weak head normal form.
instance (NFData e, Index ix) => NFData (Stencil ix e a) where
  rnf stencil =
    rnf (stencilBorder stencil) `seq`
    rnf (stencilSize stencil) `seq`
    rnf (stencilCenter stencil) `seq`
    stencilFunc stencil `seq` ()
-- | Zero-cost wrapper around an element read by, or produced by, a stencil.
newtype Value e = Value
  { unValue :: e
  } deriving (Eq, Ord, Show, Bounded)
-- | Applies a function to the wrapped element.
instance Functor Value where
  fmap f = Value . f . unValue
-- | 'Value' is the identity applicative: 'pure' wraps and '<*>' applies the
-- wrapped function to the wrapped argument.
instance Applicative Value where
  pure x = Value x
  Value g <*> Value x = Value (g x)
-- | Element-wise numeric operations, lifted through the wrapper.
instance Num e => Num (Value e) where
  (+) = liftA2 (+)
  -- Defined explicitly: the class default @x - y = x + negate y@ would throw
  -- for element types with a partial 'negate' (e.g. 'Natural'), even when the
  -- element type's own '(-)' would succeed.
  (-) = liftA2 (-)
  (*) = liftA2 (*)
  negate = fmap negate
  abs = fmap abs
  signum = fmap signum
  fromInteger = pure . fromInteger
-- | Element-wise fractional operations on the wrapped element.
instance Fractional e => Fractional (Value e) where
  Value x / Value y = Value (x / y)
  recip (Value x) = Value (recip x)
  fromRational r = Value (fromRational r)
-- | Element-wise floating-point operations on the wrapped element. Methods
-- not listed here use the 'Floating' class defaults.
instance Floating e => Floating (Value e) where
  pi = Value pi
  exp (Value x) = Value (exp x)
  log (Value x) = Value (log x)
  sqrt (Value x) = Value (sqrt x)
  Value x ** Value y = Value (x ** y)
  logBase (Value b) (Value x) = Value (logBase b x)
  sin (Value x) = Value (sin x)
  cos (Value x) = Value (cos x)
  tan (Value x) = Value (tan x)
  asin (Value x) = Value (asin x)
  acos (Value x) = Value (acos x)
  atan (Value x) = Value (atan x)
  sinh (Value x) = Value (sinh x)
  cosh (Value x) = Value (cosh x)
  tanh (Value x) = Value (tanh x)
  asinh (Value x) = Value (asinh x)
  acosh (Value x) = Value (acosh x)
  atanh (Value x) = Value (atanh x)
-- | Post-processes the stencil's result with a function. Border, size and
-- center are left untouched.
instance Functor (Stencil ix e) where
  fmap f st =
    st {stencilFunc = \getVal -> Value . f . unValue . stencilFunc st getVal}
-- | Combines stencils point-wise.
--
-- * 'pure' builds a stencil of size 1 in every dimension, centered at
--   'zeroIndex', with an 'Edge' border. It ignores the getter.
--
-- * @s1 '<*>' s2@ runs both stencil functions at the same index and applies
--   the result of @s1@ to the result of @s2@. The combined center is the
--   component-wise maximum of both centers. The combined size is that center
--   plus the larger of the two extents past each center (@size - center@), so
--   the window covers both operands. The border is taken from the right-hand
--   operand @s2@. The result is checked with 'validateStencil', using 'def' as
--   the fill element.
instance (Default e, Index ix) => Applicative (Stencil ix e) where
  pure a = Stencil Edge (pureIndex 1) zeroIndex (const (const (Value a)))
  (<*>) (Stencil _ sSz1 sC1 f1) (Stencil sB sSz2 sC2 f2) =
    validateStencil def (Stencil sB newSz maxCenter stF)
    where
      stF gV !ix = Value (unValue (f1 gV ix) (unValue (f2 gV ix)))
      -- Extent of each stencil past its own center; the combined window
      -- must reach the farther of the two.
      !newSz =
        liftIndex2
          (+)
          maxCenter
          (liftIndex2 max (liftIndex2 (-) sSz1 sC1) (liftIndex2 (-) sSz2 sC2))
      !maxCenter = liftIndex2 max sC1 sC2
-- | Point-wise arithmetic on stencil results, built from the 'Applicative'
-- instance (so binary operators combine sizes and centers as '<*>' does).
instance (Index ix, Default e, Num a) => Num (Stencil ix e a) where
  (+) = liftA2 (+)
  -- Subtraction lifted point-wise, like the other binary operators.
  (-) = liftA2 (-)
  (*) = liftA2 (*)
  negate = fmap negate
  abs = fmap abs
  signum = fmap signum
  fromInteger = pure . fromInteger
-- | Point-wise fractional operations on stencil results. Binary operators
-- combine their operands through '<*>'.
instance (Index ix, Default e, Fractional a) => Fractional (Stencil ix e a) where
  s1 / s2 = (/) <$> s1 <*> s2
  recip s = recip <$> s
  fromRational r = pure (fromRational r)
-- | Point-wise floating-point operations on stencil results. Unary functions
-- map over the result; binary ones combine their operands through '<*>'.
instance (Index ix, Default e, Floating a) => Floating (Stencil ix e a) where
  pi = pure pi
  exp s = exp <$> s
  log s = log <$> s
  sqrt s = sqrt <$> s
  s1 ** s2 = (**) <$> s1 <*> s2
  logBase s1 s2 = logBase <$> s1 <*> s2
  sin s = sin <$> s
  cos s = cos <$> s
  tan s = tan <$> s
  asin s = asin <$> s
  acos s = acos <$> s
  atan s = atan <$> s
  sinh s = sinh <$> s
  cosh s = cosh <$> s
  tanh s = tanh <$> s
  asinh s = asinh <$> s
  acosh s = acosh <$> s
  atanh s = atanh <$> s
-- | Bounds-checked read from a delayed array. Raises an error that reports
-- the index and the array size when 'isSafeIndex' rejects the index. The
-- message refers to the "stencil size" because the array is the validation
-- window built by 'validateStencil'.
safeStencilIndex :: Index ix => Array D ix e -> ix -> e
safeStencilIndex arr ix =
  if isSafeIndex sz ix
    then dUnsafeIndex arr ix
    else error
      ("Index is out of bounds: " ++ show ix ++ " for stencil size: " ++ show sz)
  where
    sz = dSize arr
-- | Sanity-checks a stencil and returns it unchanged.
--
-- The check runs the stencil function once at 'stencilCenter'. Its getter
-- reads from a constant array of size 'stencilSize' filled with the given
-- element, through 'safeStencilIndex', so any read outside that window
-- raises an error. The check runs when the returned stencil is forced. It
-- forces the result only to weak head normal form, so an out-of-bounds read
-- that stays inside an unevaluated part of the result is not detected.
validateStencil
  :: Index ix
  => e -> Stencil ix e a -> Stencil ix e a
validateStencil d s = stencilFunc s checkedGet (stencilCenter s) `seq` s
  where
    checkedGet = Value . safeStencilIndex (DArray Seq (stencilSize s) (const d))