{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
module Data.Massiv.Array.Stencil.Unsafe
  ( 
    makeUnsafeStencil
  , makeUnsafeConvolutionStencil
  , makeUnsafeCorrelationStencil
  , unsafeTransformStencil
  ) where
import Data.Massiv.Array.Stencil.Internal
import Data.Massiv.Core.Common
import GHC.Exts (inline)
-- | Construct a 'Stencil' from its size, its center and a function that
-- computes the output value for a cell.
--
-- The supplied function receives the index of the cell being computed together
-- with a getter that takes an offset /relative/ to that cell. The getter adds the
-- offset to the current index (via 'liftIndex2' @(+)@) and reads through the
-- first accessor handed to 'stencilFunc'. That accessor is named
-- @unsafeGetVal@ here and presumably performs no bounds check (confirm in
-- "Data.Massiv.Array.Stencil.Internal").
--
-- Nothing in this function verifies that the offsets used stay within the
-- declared size and center. Keeping reads inside the stencil is entirely the
-- caller's responsibility.
makeUnsafeStencil
  :: Index ix
  => Sz ix
  -- ^ Size of the stencil
  -> ix
  -- ^ Center of the stencil
  -> (ix -> (ix -> e) -> a)
  -- ^ Stencil function: receives the current index and a getter that accepts
  -- an offset relative to that index
  -> Stencil ix e a
makeUnsafeStencil !sSz !sCenter relStencil = Stencil sSz sCenter stencil
  where
    -- Only the first accessor is used; the second one is ignored.
    -- 'inline' is applied to 'relStencil' itself (not to the result of calling
    -- it) so that GHC can actually inline the user-supplied function once this
    -- INLINE definition has been inlined at a call site.
    stencil unsafeGetVal _getVal !ix =
      inline relStencil ix (unsafeGetVal . liftIndex2 (+) ix)
    {-# INLINE stencil #-}
{-# INLINE makeUnsafeStencil #-}
-- | Construct a convolution 'Stencil' whose output is a weighted sum of
-- neighbouring cells.
--
-- The kernel description is a fold driver. It is given a step function
-- @\\offset kernelValue acc -> ...@ and must return an @e -> e@ accumulator
-- transformer, which is then applied to the initial accumulator @0@. A typical
-- description chains one step per kernel element, e.g.
-- @\\f -> f o1 k1 . f o2 k2 . f o3 k3@.
--
-- Each step reads the cell at @ix - offset@, so the kernel is mirrored as
-- convolution requires. It multiplies that cell by the kernel value and adds
-- the product to the accumulator.
--
-- Because the kernel is mirrored, the center stored in the resulting stencil is
-- @(size - 1) - center@ in every dimension. Reads go through the first accessor
-- handed to 'stencilFunc', and no bounds checking is done here.
makeUnsafeConvolutionStencil
  :: (Index ix, Num e)
  => Sz ix
  -- ^ Size of the stencil
  -> ix
  -- ^ Center of the kernel
  -> ((ix -> e -> e -> e) -> e -> e)
  -- ^ Kernel description: receives the accumulating step function
  -> Stencil ix e e
makeUnsafeConvolutionStencil !sz !sCenter relStencil =
  Stencil sz sInvertCenter stencil
  where
    -- Mirroring the kernel also mirrors its center: (size - 1) - center.
    !sInvertCenter = liftIndex2 (-) (liftIndex (subtract 1) (unSz sz)) sCenter
    -- Only the first accessor is used; the second one is ignored.
    stencil uget _ !ix =
      (inline relStencil $ \ !ixD !kVal !acc ->
         uget (liftIndex2 (-) ix ixD) * kVal + acc) 0
    {-# INLINE stencil #-}
{-# INLINE makeUnsafeConvolutionStencil #-}
-- | Construct a correlation 'Stencil' whose output is a weighted sum of
-- neighbouring cells.
--
-- The kernel description is a fold driver. It is given a step function
-- @\\offset kernelValue acc -> ...@ and must return an @e -> e@ accumulator
-- transformer, which is then applied to the initial accumulator @0@.
--
-- Each step reads the cell at @ix + offset@; unlike convolution, the kernel is
-- /not/ mirrored. It multiplies that cell by the kernel value and adds the
-- product to the accumulator. The size and center are stored unchanged.
makeUnsafeCorrelationStencil
  :: (Index ix, Num e)
  => Sz ix
  -- ^ Size of the stencil
  -> ix
  -- ^ Center of the kernel
  -> ((ix -> e -> e -> e) -> e -> e)
  -- ^ Kernel description: receives the accumulating step function
  -> Stencil ix e e
makeUnsafeCorrelationStencil !sSz !sCenter relStencil =
  Stencil sSz sCenter stencil
  where
    -- NOTE(review): 'makeUnsafeStencil' and 'makeUnsafeConvolutionStencil' both
    -- read through the FIRST accessor passed to 'stencilFunc'. This function
    -- reads through the SECOND one and ignores the first. Confirm whether that
    -- is intentional for an "unsafe" constructor; behaviour is kept as-is here.
    stencil _ getVal !ix =
      (inline relStencil $ \ !ixD !kVal !acc ->
         getVal (liftIndex2 (+) ix ixD) * kVal + acc) 0
    {-# INLINE stencil #-}
{-# INLINE makeUnsafeCorrelationStencil #-}
-- | Rebuild a 'Stencil' over a different index type by transforming each of its
-- three components independently:
--
-- * the size with the first function,
-- * the center with the second,
-- * the stencil function with the third.
--
-- No consistency between the three transformations is checked. The caller
-- must make sure the resulting size, center and function agree with one
-- another.
unsafeTransformStencil
  :: (Sz ix' -> Sz ix)
  -- ^ Transformation of the stencil size
  -> (ix' -> ix)
  -- ^ Transformation of the stencil center
  -> (((ix' -> e) -> (ix' -> e) -> ix' -> a)
      -> (ix -> e) -> (ix -> e) -> ix -> a)
  -- ^ Transformation of the stencil function
  -> Stencil ix' e a
  -- ^ Source stencil
  -> Stencil ix e a
unsafeTransformStencil transformSize transformIndex transformFunc
  Stencil {stencilSize = srcSize, stencilCenter = srcCenter, stencilFunc = srcFunc} =
  -- Explicit field bindings avoid shadowing the record selectors that are
  -- being assigned below.
  Stencil
    { stencilSize = transformSize srcSize
    , stencilCenter = transformIndex srcCenter
    , stencilFunc = transformFunc srcFunc
    }
{-# INLINE unsafeTransformStencil #-}