{-# LANGUAGE ForeignFunctionInterface #-}
module Grenade.Layers.Internal.Pooling (
    poolForward
  , poolBackward
  ) where

import qualified Data.Vector.Storable as U ( unsafeToForeignPtr0, unsafeFromForeignPtr0 )

import           Foreign ( mallocForeignPtrArray, withForeignPtr )
import           Foreign.Ptr ( Ptr )

import           Numeric.LinearAlgebra ( Matrix , flatten )
import qualified Numeric.LinearAlgebra.Devel as U

import           System.IO.Unsafe ( unsafePerformIO )

poolForward :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
poolForward channels height width kernelRows kernelColumns strideRows strideColumns dataIm =
  let vec             = flatten dataIm
      rowOut          = (height - kernelRows) `div` strideRows + 1
      colOut          = (width - kernelColumns) `div` strideColumns + 1
      numberOfPatches = rowOut * colOut
  in unsafePerformIO $ do
    outPtr <- mallocForeignPtrArray (numberOfPatches * channels)
    let (inPtr, _) = U.unsafeToForeignPtr0 vec

    withForeignPtr inPtr $ \inPtr' ->
      withForeignPtr outPtr $ \outPtr' ->
        pool_forwards_cpu inPtr' channels height width kernelRows kernelColumns strideRows strideColumns outPtr'

    let matVec = U.unsafeFromForeignPtr0 outPtr (numberOfPatches * channels)
    return $ U.matrixFromVector U.RowMajor (rowOut * channels) colOut matVec

foreign import ccall unsafe
    pool_forwards_cpu
      :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()

poolBackward :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
poolBackward channels height width kernelRows kernelColumns strideRows strideColumns dataIm dataGrad =
  let vecIm     = flatten dataIm
      vecGrad   = flatten dataGrad
  in unsafePerformIO $ do
    outPtr <- mallocForeignPtrArray (height * width * channels)
    let (imPtr, _) = U.unsafeToForeignPtr0 vecIm
    let (gradPtr, _) = U.unsafeToForeignPtr0 vecGrad

    withForeignPtr imPtr $ \imPtr' ->
      withForeignPtr gradPtr $ \gradPtr' ->
        withForeignPtr outPtr $ \outPtr' ->
          pool_backwards_cpu imPtr' gradPtr' channels height width kernelRows kernelColumns strideRows strideColumns outPtr'

    let matVec = U.unsafeFromForeignPtr0 outPtr (height * width * channels)
    return $ U.matrixFromVector U.RowMajor (height * channels) width matVec

foreign import ccall unsafe
    pool_backwards_cpu
      :: Ptr Double -> Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()