{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE RecordWildCards       #-}

module Grenade.Recurrent.Core.Runner (
    trainRecurrent
  , runRecurrent
  , backPropagateRecurrent
  ) where

import           Data.Singletons.Prelude
import           Grenade.Core

import           Grenade.Recurrent.Core.Layer
import           Grenade.Recurrent.Core.Network

-- | Drive a network over a sequence of examples and collect its
--   back propagated gradients.
backPropagateRecurrent :: forall shapes layers. (SingI (Last shapes), Num (RecurrentInputs layers))
                       => RecurrentNetwork layers shapes
                       -> RecurrentInputs layers
                       -> [(S (Head shapes), Maybe (S (Last shapes)))]
                       -> (RecurrentGradients layers, RecurrentInputs layers)
backPropagateRecurrent network recinputs examples =
      -- Forward pass over the whole sequence, keeping the tapes
      -- needed for the backwards pass.
  let (tapes, _, guesses)    = runRecurrentNetwork network recinputs inputs

      -- Error signal at each step, zero where no target was supplied.
      backPropagations       = zipWith makeError guesses targets

      -- Backwards pass, starting from a zero gradient for the recurrent
      -- inputs, collecting the layer gradients along the way.
      (gradients, input', _) = runRecurrentGradient network tapes 0 backPropagations

  in (gradients, input')

  where
    inputs  = fst <$> examples
    targets = snd <$> examples

    -- Difference between the network's guess and the target, or zero
    -- for steps which have no target.
    makeError :: S (Last shapes) -> Maybe (S (Last shapes)) -> S (Last shapes)
    makeError _ Nothing  = 0
    makeError y (Just t) = y - t

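-- A minimal sketch of calling 'backPropagateRecurrent' by hand; the
-- network @net@, initial recurrent inputs @recins@, and the example
-- sequence are hypothetical placeholders here. Steps without a label
-- carry 'Nothing' and contribute no error.
--
-- > let examples             = [(x1, Nothing), (x2, Nothing), (x3, Just target)]
-- > let (grads, recinGrads)  = backPropagateRecurrent net recins examples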

-- | Train a recurrent network on a sequence of examples: back propagate
--   the errors, then apply the gradients to both the network weights and
--   the learnt initial recurrent inputs.
trainRecurrent :: forall shapes layers. (SingI (Last shapes), Num (RecurrentInputs layers))
               => LearningParameters
               -> RecurrentNetwork layers shapes
               -> RecurrentInputs layers
               -> [(S (Head shapes), Maybe (S (Last shapes)))]
               -> (RecurrentNetwork layers shapes, RecurrentInputs layers)
trainRecurrent rate network recinputs examples =
  let (gradients, recinputs') = backPropagateRecurrent network recinputs examples

      newInputs               = updateRecInputs rate recinputs recinputs'

      newNetwork              = applyRecurrentUpdate rate network gradients

  in  (newNetwork, newInputs)
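
-- A minimal sketch of a training loop built on 'trainRecurrent', folding
-- a (hypothetical) list of example sequences through the network. The
-- learning parameters are illustrative only, and foldl' is Data.List's.
--
-- > let params = LearningParameters 0.01 0.9 0.0001
-- > let step (net, recins) sequence = trainRecurrent params net recins sequence
-- > let (trained, recins')          = foldl' step (net0, recins0) sequences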

-- | Descend the gradient for the learnt initial recurrent inputs,
--   applying L2 regularisation to the inputs themselves.
updateRecInputs :: LearningParameters
                -> RecurrentInputs sublayers
                -> RecurrentInputs sublayers
                -> RecurrentInputs sublayers
updateRecInputs l (() :~~+> xs) (() :~~+> ys)
  = () :~~+> updateRecInputs l xs ys

updateRecInputs l@LearningParameters {..} (x :~@+> xs) (y :~@+> ys)
  = (realToFrac (1 - learningRate * learningRegulariser) * x - realToFrac learningRate * y)
      :~@+> updateRecInputs l xs ys

updateRecInputs _ RINil RINil
  = RINil
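
-- For a single recurrent input x with gradient y, the step above is
-- plain gradient descent with weight decay on the learnt initial state:
--
-- >   x' = (1 - learningRate * learningRegulariser) * x - learningRate * y
--
-- For example, with learningRate = 0.01 and learningRegulariser = 0.0001
-- this is x' = 0.999999 * x - 0.01 * y.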

-- | Run forward propagation only, with no training; returns the updated
--   recurrent state alongside the network output.
runRecurrent :: RecurrentNetwork layers shapes
             -> RecurrentInputs layers -> S (Head shapes)
             -> (RecurrentInputs layers, S (Last shapes))
-- Feed forward layer: no recurrent state to thread through.
runRecurrent (layer :~~> n) (()    :~~+> nr) !x
  = let (_, ys)  = runForwards layer x
        (nr', o) = runRecurrent n nr ys
    in  (() :~~+> nr', o)
-- Recurrent layer: thread its state through and collect the update.
runRecurrent (layer :~@> n) (recin :~@+> nr) !x
  = let (_, recin', y) = runRecurrentForwards layer recin x
        (nr', o)       = runRecurrent n nr y
    in  (recin' :~@+> nr', o)
runRecurrent RNil RINil !x
  = (RINil, x)
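
-- A minimal sketch of threading the recurrent state through a sequence
-- with 'runRecurrent' (all names hypothetical):
--
-- > let (recins1, y1) = runRecurrent net recins0 x1
-- > let (recins2, y2) = runRecurrent net recins1 x2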