{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
-- |
-- Module      : Data.Massiv.Array.Stencil.Unsafe
-- Copyright   : (c) Alexey Kuleshevich 2018-2021
-- License     : BSD3
-- Maintainer  : Alexey Kuleshevich <lehins@yandex.ru>
-- Stability   : experimental
-- Portability : non-portable
--
module Data.Massiv.Array.Stencil.Unsafe
  ( -- * Stencil
    makeUnsafeStencil
  , makeUnsafeConvolutionStencil
  , makeUnsafeCorrelationStencil
  , unsafeTransformStencil
  -- ** Deprecated
  , unsafeMapStencil
  ) where

import Data.Massiv.Array.Delayed.Windowed (Array(..), DW, Window(..),
                                           insertWindow)
import Data.Massiv.Array.Stencil.Internal
import Data.Massiv.Core.Common
import GHC.Exts (inline)


-- | This is an unsafe version of `Data.Massiv.Array.Stencil.mapStencil`, which does not
-- take a `Stencil`, but instead accepts all necessary information as separate arguments:
-- the border resolution strategy, the size of the stencil, its center and the stencil
-- function itself. Nothing is validated, so a stencil function that reaches outside of the
-- area described by the size and center can read out of bounds memory.
--
-- Inside the window (the region where the whole stencil fits into the source array)
-- elements are read with `unsafeIndex`; everywhere else reads go through `borderIndex`.
--
-- @since 0.5.0
unsafeMapStencil ::
     Manifest r ix e
  => Border e
  -> Sz ix
  -> ix
  -> (ix -> (ix -> e) -> a)
  -> Array r ix e
  -> Array DW ix a
unsafeMapStencil bord kernelSz kernelCenter relF !arr = insertWindow borderArr interior
  where
    !arrSz = size arr
    -- Fallback array used outside of the window: every read is resolved by the border.
    !borderArr = DArray (getComp arr) arrSz (applyAt (borderIndex bord arr))
    -- Along each dimension: array size minus (stencil size - 1).
    !interiorSz =
      Sz (liftIndex2 (-) (unSz arrSz) (liftIndex (subtract 1) (unSz kernelSz)))
    !interior =
      Window
        { windowStart = kernelCenter
        , windowSize = interiorSz
        , windowIndex = applyAt (unsafeIndex arr)
        , windowUnrollIx2 = fmap (unSz . fst) (pullOutSzM kernelSz 2)
        }
    -- Evaluate the user's stencil at @ix@, turning the relative offsets it asks for into
    -- absolute indices before they reach the supplied element getter.
    applyAt getElem !ix = inline (relF ix) (\ !offset -> getElem (liftIndex2 (+) ix offset))
    {-# INLINE applyAt #-}
{-# INLINE unsafeMapStencil #-}
{-# DEPRECATED unsafeMapStencil "In favor of `Data.Massiv.Array.mapStencil` that is applied to stencil created with `makeUnsafeStencil`" #-}


-- | Similar to `Data.Massiv.Array.Stencil.makeStencil`, but there are no guarantees that the
-- stencil will not read out of bounds memory. This stencil is also a bit more powerful in the
-- sense that it gets an extra piece of information, namely the exact index of the element it
-- is constructing.
--
-- Offsets handed to the element getter are relative to that index and are never checked
-- against the stencil size.
--
-- @since 0.3.0
makeUnsafeStencil
  :: Index ix
  => Sz ix -- ^ Size of the stencil
  -> ix -- ^ Center of the stencil
  -> (ix -> (ix -> e) -> a)
  -- ^ Stencil function.
  -> Stencil ix e a
makeUnsafeStencil !sSz !sCenter relStencil =
  Stencil {stencilSize = sSz, stencilCenter = sCenter, stencilFunc = absStencil}
  where
    -- Only the first element getter is used; the second one is deliberately ignored.
    absStencil uget _ !ix = inline relStencil ix (\ixD -> uget (liftIndex2 (+) ix ixD))
    {-# INLINE absStencil #-}
{-# INLINE makeUnsafeStencil #-}

-- | Same as `Data.Massiv.Array.Stencil.makeConvolutionStencil`, but will result in
-- reading memory out of bounds and potential segfaults if supplied arguments are not valid.
--
-- The supplied center is mirrored within the stencil (@(size - 1) - center@ along every
-- dimension), and each kernel offset is subtracted from the current index rather than added
-- to it. Products of elements and kernel values are accumulated starting from @0@.
--
-- @since 0.6.0
makeUnsafeConvolutionStencil
  :: (Index ix, Num e)
  => Sz ix
  -> ix
  -> ((ix -> e -> e -> e) -> e -> e)
  -> Stencil ix e e
makeUnsafeConvolutionStencil !sz !sCenter relStencil = Stencil sz flippedCenter convolve
  where
    !flippedCenter = liftIndex2 (-) (liftIndex (subtract 1) (unSz sz)) sCenter
    -- Reads go through the first element getter; the second one is ignored.
    convolve uget _ !ix = inline relStencil step 0
      where
        step !ixD !kVal !acc = uget (liftIndex2 (-) ix ixD) * kVal + acc
    {-# INLINE convolve #-}
{-# INLINE makeUnsafeConvolutionStencil #-}

-- | Same as `Data.Massiv.Array.Stencil.makeCorrelationStencil`, but will result in
-- reading memory out of bounds and potential segfaults if supplied arguments are not
-- valid.
--
-- Each kernel offset is added to the current index, and products of elements and kernel
-- values are accumulated starting from @0@.
--
-- @since 0.6.0
makeUnsafeCorrelationStencil
  :: (Index ix, Num e)
  => Sz ix
  -> ix
  -> ((ix -> e -> e -> e) -> e -> e)
  -> Stencil ix e e
makeUnsafeCorrelationStencil !sSz !sCenter relStencil = Stencil sSz sCenter stencil
  where
    -- Read through the first element getter and ignore the second one, exactly like
    -- `makeUnsafeStencil` and `makeUnsafeConvolutionStencil` do. Reading through the
    -- second getter made this the only "unsafe" constructor that did not use the same
    -- getter as its siblings.
    stencil uget _ !ix =
      (inline relStencil $ \ !ixD !kVal !acc -> uget (liftIndex2 (+) ix ixD) * kVal + acc) 0
    {-# INLINE stencil #-}
{-# INLINE makeUnsafeCorrelationStencil #-}


-- | Perform an arbitrary transformation of a stencil. This stencil modifier can be used, for
-- example, to turn a vector stencil into a matrix stencil, or to transpose a matrix stencil.
-- It is really easy to get this wrong, so be extremely careful.
--
-- The size and the center of the original stencil are mapped with the two forward modifiers,
-- while its stencil function is rewritten by the supplied function modifier. The resulting
-- stencil is not validated in any way.
--
-- ====__Examples__
--
-- Convert a 1D stencil into a row or column 2D stencil:
--
-- >>> import Data.Massiv.Array
-- >>> import Data.Massiv.Array.Unsafe
-- >>> let arr = compute $ iterateN 3 succ 0 :: Array P Ix2 Int
-- >>> arr
-- Array P Seq (Sz (3 :. 3))
--   [ [ 1, 2, 3 ]
--   , [ 4, 5, 6 ]
--   , [ 7, 8, 9 ]
--   ]
-- >>> let rowStencil = unsafeTransformStencil (\(Sz n) -> Sz (1 :. n)) (0 :.) $ \ f uget getVal (i :. j) -> f (uget  . (i :.)) (getVal . (i :.)) j
-- >>> applyStencil noPadding (rowStencil (sumStencil (Sz1 3))) arr
-- Array DW Seq (Sz (3 :. 1))
--   [ [ 6 ]
--   , [ 15 ]
--   , [ 24 ]
--   ]
-- >>> let columnStencil = unsafeTransformStencil (\(Sz n) -> Sz (n :. 1)) (:. 0) $ \ f uget getVal (i :. j) -> f (uget . (:. j)) (getVal . (:. j)) i
-- >>> applyStencil noPadding (columnStencil (sumStencil (Sz1 3))) arr
-- Array DW Seq (Sz (1 :. 3))
--   [ [ 12, 15, 18 ]
--   ]
--
-- @since 0.5.4
unsafeTransformStencil ::
     (Sz ix' -> Sz ix)
  -- ^ Forward modifier for the size
  -> (ix' -> ix)
  -- ^ Forward index modifier
  -> (((ix' -> e) -> (ix' -> e) -> ix' -> a)
      -> (ix -> e) -> (ix -> e) -> ix -> a)
  -- ^ Inverse stencil function modifier
  -> Stencil ix' e a
  -- ^ Original stencil.
  -> Stencil ix e a
unsafeTransformStencil transformSize transformIndex transformFunc (Stencil origSz origCenter origFunc) =
  Stencil (transformSize origSz) (transformIndex origCenter) (transformFunc origFunc)
{-# INLINE unsafeTransformStencil #-}



{-

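Invalid stencil transformer function, kept commented out: it passes an inverse index
modifier (`ix -> ix'`) where `unsafeTransformStencil` expects a stencil function modifier.

TODO: figure out if there is a safe way to do stencil index transformation.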


transformStencil ::
     (Default e, Index ix)
  => (Sz ix' -> Sz ix)
  -- ^ Forward modifier for the size
  -> (ix' -> ix)
  -- ^ Forward index modifier
  -> (ix -> ix')
  -- ^ Inverse index modifier
  -> Stencil ix' e a
  -- ^ Original stencil.
  -> Stencil ix e a
transformStencil transformSize transformIndex transformIndex' stencil =
  validateStencil def $! unsafeTransformStencil transformSize transformIndex transformIndex' stencil
{-# INLINE transformStencil #-}


-}