{-# LANGUAGE ForeignFunctionInterface #-}
module Grenade.Layers.Internal.Update (
    decendMatrix
  , decendVector
  ) where

import           Data.Maybe ( fromJust )
import qualified Data.Vector.Storable as U ( unsafeToForeignPtr0, unsafeFromForeignPtr0 )

import           Foreign ( mallocForeignPtrArray, withForeignPtr )
import           Foreign.Ptr ( Ptr )
import           GHC.TypeLits

import           Numeric.LinearAlgebra ( Vector, flatten )
import           Numeric.LinearAlgebra.Static
import qualified Numeric.LinearAlgebra.Devel as U

import           System.IO.Unsafe ( unsafePerformIO )

decendMatrix :: (KnownNat rows, KnownNat columns) => Double -> Double -> Double -> L rows columns -> L rows columns -> L rows columns -> (L rows columns, L rows columns)
decendMatrix rate momentum regulariser weights gradient lastUpdate =
  let (rows, cols) = size weights
      len          = rows * cols
      -- Most gradients come in in ColumnMajor,
      -- so we'll transpose here before flattening them
      -- into a vector to prevent a copy.
      --
      -- This gives ~15% speed improvement for LSTMs.
      weights'     = flatten . tr . extract $ weights
      gradient'    = flatten . tr . extract $ gradient
      lastUpdate'  = flatten . tr . extract $ lastUpdate
      (vw, vm)     = decendUnsafe len rate momentum regulariser weights' gradient' lastUpdate'

      -- Note that it's ColumnMajor, as we did a transpose before
      -- using the internal vectors.
      mw           = U.matrixFromVector U.ColumnMajor rows cols vw
      mm           = U.matrixFromVector U.ColumnMajor rows cols vm
  in  (fromJust . create $ mw, fromJust . create $ mm)

decendVector :: (KnownNat r) => Double -> Double -> Double -> R r -> R r -> R r -> (R r, R r)
decendVector rate momentum regulariser weights gradient lastUpdate =
  let len          = size weights
      weights'     = extract weights
      gradient'    = extract gradient
      lastUpdate'  = extract lastUpdate
      (vw, vm)     = decendUnsafe len rate momentum regulariser weights' gradient' lastUpdate'
  in  (fromJust $ create vw, fromJust $ create vm)

decendUnsafe :: Int -> Double -> Double -> Double -> Vector Double -> Vector Double -> Vector Double -> (Vector Double, Vector Double)
decendUnsafe len rate momentum regulariser weights gradient lastUpdate =
  unsafePerformIO $ do
    outWPtr <- mallocForeignPtrArray len
    outMPtr <- mallocForeignPtrArray len
    let (wPtr, _) = U.unsafeToForeignPtr0 weights
    let (gPtr, _) = U.unsafeToForeignPtr0 gradient
    let (lPtr, _) = U.unsafeToForeignPtr0 lastUpdate

    withForeignPtr wPtr $ \wPtr' ->
      withForeignPtr gPtr $ \gPtr' ->
        withForeignPtr lPtr $ \lPtr' ->
          withForeignPtr outWPtr $ \outWPtr' ->
            withForeignPtr outMPtr $ \outMPtr' ->
              decend_cpu len rate momentum regulariser wPtr' gPtr' lPtr' outWPtr' outMPtr'

    return (U.unsafeFromForeignPtr0 outWPtr len, U.unsafeFromForeignPtr0 outMPtr len)

foreign import ccall unsafe
    decend_cpu
      :: Int -> Double -> Double -> Double -> Ptr Double -> Ptr Double -> Ptr Double -> Ptr Double -> Ptr Double -> IO ()