module Grenade.Layers.Internal.Pooling (
poolForward
, poolBackward
) where
import qualified Data.Vector.Storable as U ( unsafeToForeignPtr0, unsafeFromForeignPtr0 )
import Foreign ( mallocForeignPtrArray, withForeignPtr )
import Foreign.Ptr ( Ptr )
import Numeric.LinearAlgebra ( Matrix , flatten )
import qualified Numeric.LinearAlgebra.Devel as U
import System.IO.Unsafe ( unsafePerformIO )
-- | Run the forward pooling pass over an image by delegating to the C routine
--   @pool_forwards_cpu@.
--
--   Arguments, in order: channel count, image height, image width, kernel
--   rows, kernel columns, row stride, column stride, and the input image.
--
--   The input matrix is flattened to a storable vector and handed to C along
--   with the geometry arguments. The result is read back from a freshly
--   allocated buffer of @rowOut * colOut * channels@ doubles and returned as a
--   row-major matrix with @rowOut * channels@ rows and @colOut@ columns, where
--
--   > rowOut = (height - kernelRows)    `div` strideRows    + 1
--   > colOut = (width  - kernelColumns) `div` strideColumns + 1
--
--   No argument is validated here: a zero stride makes 'div' throw, and the
--   size of @dataIm@ is never compared against the geometry arguments.
--   NOTE(review): presumably @dataIm@ must hold @channels * height * width@
--   elements for the C side to stay in bounds -- confirm against the C source.
poolForward :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
poolForward channels height width kernelRows kernelColumns strideRows strideColumns dataIm =
  let vec             = flatten dataIm
      -- Number of kernel positions along each axis. The subtraction is
      -- required: the kernel must fit inside the image at every position.
      rowOut          = (height - kernelRows) `div` strideRows + 1
      colOut          = (width - kernelColumns) `div` strideColumns + 1
      numberOfPatches = rowOut * colOut
      outLength       = numberOfPatches * channels
  -- unsafePerformIO: the action only reads the input vector and writes into a
  -- buffer allocated right here. NOTE(review): this is pure only if the C
  -- routine has no other side effects -- confirm against the C source.
  in unsafePerformIO $ do
    outPtr <- mallocForeignPtrArray outLength
    let (inPtr, _) = U.unsafeToForeignPtr0 vec

    withForeignPtr inPtr $ \inPtr' ->
      withForeignPtr outPtr $ \outPtr' ->
        pool_forwards_cpu inPtr' channels height width kernelRows kernelColumns strideRows strideColumns outPtr'

    -- Wrap the filled buffer as a vector without copying, then as a matrix.
    let matVec = U.unsafeFromForeignPtr0 outPtr outLength
    return $ U.matrixFromVector U.RowMajor (rowOut * channels) colOut matVec
-- | C entry point for the forward pooling pass.
--
--   Argument order, as used by the caller in this module: input buffer,
--   channels, height, width, kernel rows, kernel columns, row stride,
--   column stride, output buffer.
--
--   NOTE(review): the integer arguments are marshalled as Haskell 'Int';
--   confirm the C prototype uses an integer type of matching width.
foreign import ccall unsafe
  pool_forwards_cpu
    :: Ptr Double
    -> Int
    -> Int
    -> Int
    -> Int
    -> Int
    -> Int
    -> Int
    -> Ptr Double
    -> IO ()
-- | Run the backward pooling pass by delegating to the C routine
--   @pool_backwards_cpu@.
--
--   Arguments, in order: channel count, image height, image width, kernel
--   rows, kernel columns, row stride, column stride, the image matrix, and
--   the gradient matrix.
--
--   Both matrices are flattened and passed to C together with the geometry
--   arguments. The result is read from a freshly allocated buffer of
--   @height * width * channels@ doubles and returned as a row-major matrix
--   with @height * channels@ rows and @width@ columns.
--
--   No argument is validated here.
--   NOTE(review): the buffer is not initialised on the Haskell side;
--   presumably the C routine writes every element -- confirm.
poolBackward :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
poolBackward channels height width kernelRows kernelColumns strideRows strideColumns dataIm dataGrad =
  unsafePerformIO $ do
    resultFP <- mallocForeignPtrArray resultLen
    withForeignPtr imageFP $ \imageP ->
      withForeignPtr gradFP $ \gradP ->
        withForeignPtr resultFP $ \resultP ->
          pool_backwards_cpu imageP gradP channels height width kernelRows kernelColumns strideRows strideColumns resultP
    -- Wrap the filled buffer as a vector without copying, then as a matrix.
    pure . U.matrixFromVector U.RowMajor (height * channels) width $
      U.unsafeFromForeignPtr0 resultFP resultLen
  where
    -- One output cell per input cell.
    resultLen    = height * width * channels
    (imageFP, _) = U.unsafeToForeignPtr0 (flatten dataIm)
    (gradFP, _)  = U.unsafeToForeignPtr0 (flatten dataGrad)
-- | C entry point for the backward pooling pass.
--
--   Argument order, as used by the caller in this module: image buffer,
--   gradient buffer, channels, height, width, kernel rows, kernel columns,
--   row stride, column stride, output buffer.
--
--   NOTE(review): the integer arguments are marshalled as Haskell 'Int';
--   confirm the C prototype uses an integer type of matching width.
foreign import ccall unsafe
  pool_backwards_cpu
    :: Ptr Double
    -> Ptr Double
    -> Int
    -> Int
    -> Int
    -> Int
    -> Int
    -> Int
    -> Int
    -> Ptr Double
    -> IO ()