{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE RecordWildCards       #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE ViewPatterns          #-}
{-# LANGUAGE ScopedTypeVariables   #-}

module Grenade.Recurrent.Layers.LSTM (
    LSTM (..)
  , LSTMWeights (..)
  , randomLSTM
  ) where

import           Control.Monad.Random ( MonadRandom, getRandom )

-- import           Data.List ( foldl1' )
import           Data.Proxy
import           Data.Serialize
import           Data.Singletons.TypeLits

import qualified Numeric.LinearAlgebra as LA
import           Numeric.LinearAlgebra.Static

import           Grenade.Core
import           Grenade.Recurrent.Core
import           Grenade.Layers.Internal.Update


-- | Long Short-Term Memory recurrent unit, peephole formulation.
--
--   Every gate looks at the previous cell state rather than the previous
--   output. Consequently the only state carried between time steps is the
--   cell state; the previous output is neither stored nor used.
data LSTM (input :: Nat) (output :: Nat) where
  LSTM :: ( KnownNat input
          , KnownNat output
          ) => !(LSTMWeights input output) -- current weights
            -> !(LSTMWeights input output) -- momentum, one entry per weight
            -> LSTM input output

-- | Parameter set of an 'LSTM' layer. The same record shape is used for the
--   weights, the momentums and the gradients.
--
--   Naming scheme: @W@ fields multiply the input, @U@ fields multiply the
--   previous cell state and @B@ fields are biases. The suffix picks the path:
--   forget (f), input (i), output (o) or cell (c). The cell path has no @U@.
data LSTMWeights (input :: Nat) (output :: Nat) where
  LSTMWeights :: ( KnownNat input
                 , KnownNat output
                 ) => {
                   lstmWf :: !(L output input)  -- W_f : input      -> forget gate
                 , lstmUf :: !(L output output) -- U_f : cell state -> forget gate
                 , lstmBf :: !(R output)        -- b_f : forget gate bias
                 , lstmWi :: !(L output input)  -- W_i : input      -> input gate
                 , lstmUi :: !(L output output) -- U_i : cell state -> input gate
                 , lstmBi :: !(R output)        -- b_i : input gate bias
                 , lstmWo :: !(L output input)  -- W_o : input      -> output gate
                 , lstmUo :: !(L output output) -- U_o : cell state -> output gate
                 , lstmBo :: !(R output)        -- b_o : output gate bias
                 , lstmWc :: !(L output input)  -- W_c : input      -> cell candidate
                 , lstmBc :: !(R output)        -- b_c : cell candidate bias
                 } -> LSTMWeights input output

-- Only the layer name is rendered; the parameters are intentionally omitted.
instance Show (LSTM i o) where
  showsPrec _ (LSTM _ _) = showString "LSTM"

instance (KnownNat i, KnownNat o) => UpdateLayer (LSTM i o) where
  -- Gradients share the exact layout of the weights and momentums, so the
  -- same record type is reused. (This pattern recurs across layers and might
  -- deserve to be enforced.)
  type Gradient (LSTM i o) = (LSTMWeights i o)

  -- One descent step, applied independently to every weight matrix and bias
  -- vector. Each step returns the new parameter paired with its new momentum;
  -- the two halves are then reassembled into the next weights and momentums.
  runUpdate LearningParameters {..} (LSTM weights momenta) grads =
    let (wf, mwf) = stepMatrix lstmWf weights momenta grads
        (uf, muf) = stepMatrix lstmUf weights momenta grads
        (bf, mbf) = stepVector lstmBf weights momenta grads
        (wi, mwi) = stepMatrix lstmWi weights momenta grads
        (ui, mui) = stepMatrix lstmUi weights momenta grads
        (bi, mbi) = stepVector lstmBi weights momenta grads
        (wo, mwo) = stepMatrix lstmWo weights momenta grads
        (uo, muo) = stepMatrix lstmUo weights momenta grads
        (bo, mbo) = stepVector lstmBo weights momenta grads
        (wc, mwc) = stepMatrix lstmWc weights momenta grads
        (bc, mbc) = stepVector lstmBc weights momenta grads
        nextWeights = LSTMWeights wf uf bf wi ui bi wo uo bo wc bc
        nextMomenta = LSTMWeights mwf muf mbf mwi mui mbi mwo muo mbo mwc mbc
    in  LSTM nextWeights nextMomenta
    where
      -- Project one matrix field out of the weight, momentum and gradient
      -- records and run a single descent step on it.
      stepMatrix :: forall x r c. (KnownNat r, KnownNat c)
                 => (x -> L r c) -> x -> x -> x -> (L r c, L r c)
      stepMatrix field ws ms gs =
        decendMatrix learningRate learningMomentum learningRegulariser (field ws) (field gs) (field ms)

      -- Same as 'stepMatrix', for bias vectors.
      stepVector :: forall x n. (KnownNat n)
                 => (x -> R n) -> x -> x -> x -> (R n, R n)
      stepVector field ws ms gs =
        decendVector learningRate learningMomentum learningRegulariser (field ws) (field gs) (field ms)

  -- Disabled sketch of a batched 'runUpdates'. It folds (with foldl1') the
  -- gradients into one sum and performs a single 'runUpdate'. The original
  -- note hoped to cut data copies with a mutable bucket per field; the sketch
  -- below does not do that. NOTE(review): presumably the class's default
  -- 'runUpdates' applies while this stays commented out — confirm.
  -- runUpdates rate lstm gs =
  --   let combinedGradient = foldl1' uu gs
  --   in  runUpdate rate lstm combinedGradient
  --     where
  --   uu :: (KnownNat i, KnownNat o) => LSTMWeights i o -> LSTMWeights i o -> LSTMWeights i o
  --   uu a b =
  --     let wf = u lstmWf a b
  --         uf = u lstmUf a b
  --         bf = v lstmBf a b
  --         wi = u lstmWi a b
  --         ui = u lstmUi a b
  --         bi = v lstmBi a b
  --         wo = u lstmWo a b
  --         uo = u lstmUo a b
  --         bo = v lstmBo a b
  --         wc = u lstmWc a b
  --         bc = v lstmBc a b
  --     in LSTMWeights wf uf bf wi ui bi wo uo bo wc bc
  --   u :: forall x ix out. (KnownNat ix, KnownNat out) => (x -> (L out ix)) -> x -> x -> L out ix
  --   u e (e -> a) (e -> b) = tr $ tr a + tr b

  --   v :: forall x ix. (x -> (R ix)) -> x -> x -> R ix
  --   v e (e -> a) (e -> b) = a + b
  createRandom = randomLSTM

instance (KnownNat i, KnownNat o) => RecurrentUpdateLayer (LSTM i o) where
  -- Only the cell state is carried from one step to the next; the previous
  -- output is not fed back. The recurrent shape is therefore a vector with
  -- the output width.
  type RecurrentShape (LSTM i o) = 'D1 o

instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) where

  -- The tape keeps the previous cell state and the input. From these, the
  -- backward pass recomputes every intermediate value it needs.
  type RecTape (LSTM i o) ('D1 i) ('D1 o) = (S ('D1 o), S ('D1 i))

  -- Forward propagation for the LSTM layer.
  -- The size of the cell state is also the size of the output.
  runRecurrentForwards (LSTM weights _) (S1D cell) (S1D input) =
    let (f_s, i_s, o_s, c_s) = lstmGatePreActivations weights cell input
        -- Forget, input and output gate activations
        f_t = sigmoid f_s
        i_t = sigmoid i_s
        o_t = sigmoid o_s
        -- Cell candidate (input to the cell state)
        c_x = tanh c_s
        -- New cell state
        c_t = f_t * cell + i_t * c_x
        -- Output (it's sometimes recommended to use tanh c_t)
        h_t = o_t * c_t
    in ((S1D cell, S1D input), S1D c_t, S1D h_t)

  -- Run a backpropagation step for an LSTM layer.
  -- The derivatives are written out by hand, so take extra care when
  -- changing anything here.
  --
  -- There's a test version using the AD library without hmatrix in the test
  -- suite. These should match always.
  runRecurrentBackwards (LSTM weights@(LSTMWeights {..}) _) (S1D cell, S1D input) (S1D cellGrad) (S1D h_t') =
    -- The Wengert tape is not kept during the forward pass, so the gate
    -- activations are recomputed here (a form of checkpointing). They come
    -- from the same helper the forward pass uses, so the two cannot drift.
    let (f_s, i_s, o_s, c_s) = lstmGatePreActivations weights cell input
        f_t = sigmoid f_s
        i_t = sigmoid i_s
        o_t = sigmoid o_s
        c_x = tanh c_s
        c_t = f_t * cell + i_t * c_x

        -- Reverse mode AD derivatives.
        -- Gradient reaching the new cell state: via h_t = o_t * c_t, plus
        -- whatever flows back from the next time step.
        c_t' = h_t' * o_t + cellGrad

        f_t' = c_t' * cell
        f_s' = sigmoid' f_s * f_t'

        o_t' = h_t' * c_t
        o_s' = sigmoid' o_s * o_t'

        i_t' = c_t' * c_x
        i_s' = sigmoid' i_s * i_t'

        c_x' = c_t' * i_t
        c_s' = tanh' c_s * c_x'

        -- The derivatives to pass sideways (recurrent) and downwards
        cell'  = tr lstmUf #> f_s' + tr lstmUo #> o_s' + tr lstmUi #> i_s' + c_t' * f_t
        input' = tr lstmWf #> f_s' + tr lstmWo #> o_s' + tr lstmWi #> i_s' + tr lstmWc #> c_s'

        -- Gradient matrices for the input weights
        lstmWf' = f_s' `outer` input
        lstmWi' = i_s' `outer` input
        lstmWo' = o_s' `outer` input
        lstmWc' = c_s' `outer` input

        -- Gradient matrices for the cell state weights
        lstmUf' = f_s' `outer` cell
        lstmUi' = i_s' `outer` cell
        lstmUo' = o_s' `outer` cell

        -- Bias gradients are the pre-activation gradients themselves
        lstmBf' = f_s'
        lstmBi' = i_s'
        lstmBo' = o_s'
        lstmBc' = c_s'

        gradients = LSTMWeights lstmWf' lstmUf' lstmBf' lstmWi' lstmUi' lstmBi' lstmWo' lstmUo' lstmBo' lstmWc' lstmBc'
    in  (gradients, S1D cell', S1D input')

-- | Gate pre-activations (values before the non-linearity) for one time step,
--   returned as (forget, input, output, cell candidate).
--
--   Both 'runRecurrentForwards' and 'runRecurrentBackwards' use this, so the
--   backward pass always recomputes exactly what the forward pass computed.
lstmGatePreActivations :: (KnownNat i, KnownNat o)
                       => LSTMWeights i o -- ^ layer weights
                       -> R o             -- ^ previous cell state
                       -> R i             -- ^ input
                       -> (R o, R o, R o, R o)
lstmGatePreActivations LSTMWeights {..} cell input =
  ( lstmBf + lstmWf #> input + lstmUf #> cell
  , lstmBi + lstmWi #> input + lstmUi #> cell
  , lstmBo + lstmWo #> input + lstmUo #> cell
  , lstmBc + lstmWc #> input
  )

-- | Build an 'LSTM' layer with random weights and zeroed momentums.
--   'createRandom' from 'UpdateLayer' delegates to this.
--
--   Weight matrices come from 'uniformSample' over [-1, 1]. The input,
--   output and cell biases come from a 'Uniform' random vector rescaled with
--   @* 2 - 1@. The forget-gate bias is instead fixed at 1 to encourage early
--   learning:
--
--   https://github.com/karpathy/char-rnn/commit/0dfeaa454e687dd0278f036552ea1e48a0a408c9
--
randomLSTM :: forall m i o. (MonadRandom m, KnownNat i, KnownNat o)
           => m (LSTM i o)
randomLSTM = do
    -- Kept as three separate bindings: each is used at one field type only.
    let drawInput     = (\seed -> uniformSample seed (-1) 1) <$> getRandom
        drawRecurrent = (\seed -> uniformSample seed (-1) 1) <$> getRandom
        drawBias      = (\seed -> randomVector seed Uniform * 2 - 1) <$> getRandom

    -- Seeds are drawn in record field order (the forget bias needs none).
    wf <- drawInput
    uf <- drawRecurrent
    wi <- drawInput
    ui <- drawRecurrent
    bi <- drawBias
    wo <- drawInput
    uo <- drawRecurrent
    bo <- drawBias
    wc <- drawInput
    bc <- drawBias

    let zeroInput     = konst 0
        zeroRecurrent = konst 0
        zeroBias      = konst 0
        initial  = LSTMWeights wf uf (konst 1) wi ui bi wo uo bo wc bc
        momentum = LSTMWeights zeroInput zeroRecurrent zeroBias
                               zeroInput zeroRecurrent zeroBias
                               zeroInput zeroRecurrent zeroBias
                               zeroInput zeroBias
    return (LSTM initial momentum)

-- Maths: element-wise activations and their derivatives. They are written
-- against 'Floating', so they work on plain scalars and, as used in this
-- module, on hmatrix vectors.
--
-- TODO: move to not here

-- | Logistic sigmoid, @1 / (1 + exp (-x))@.
sigmoid :: Floating a => a -> a
sigmoid x = 1 / (1 + exp (negate x))

-- | Derivative of 'sigmoid' at @x@, i.e. @s * (1 - s)@ with @s = sigmoid x@.
sigmoid' :: Floating a => a -> a
sigmoid' x =
  let s = sigmoid x
  in  s * (1 - s)

-- | Derivative of 'tanh' at @t@, i.e. @1 - tanh t * tanh t@.
tanh' :: Floating a => a -> a
tanh' t =
  let s = tanh t
  in  1 - s * s

-- Only the learned weights are serialised, in field order:
-- Wf, Uf, Bf, Wi, Ui, Bi, Wo, Uo, Bo, Wc, Bc. Each matrix is flattened
-- (hmatrix 'LA.flatten', row by row) and each vector is written as-is; both
-- go out as length-prefixed lists. Momentums are not written; 'get' restores
-- them as zeros.
instance (KnownNat i, KnownNat o) => Serialize (LSTM i o) where
  put (LSTM LSTMWeights {..} _) = do
    u lstmWf
    u lstmUf
    v lstmBf
    u lstmWi
    u lstmUi
    v lstmBi
    u lstmWo
    u lstmUo
    v lstmBo
    u lstmWc
    v lstmBc
      where
    -- Write a matrix as its flattened element list.
    u :: forall a b. (KnownNat a, KnownNat b) => Putter  (L b a)
    u = putListOf put . LA.toList . LA.flatten . extract
    -- Write a vector as its element list.
    v :: forall a. (KnownNat a) => Putter (R a)
    v = putListOf put . LA.toList . extract

  get = do
    lstmWf <- u
    lstmUf <- u
    lstmBf <- v
    lstmWi <- u
    lstmUi <- u
    lstmBi <- v
    lstmWo <- u
    lstmUo <- u
    lstmBo <- v
    lstmWc <- u
    lstmBc <- v
    return $ LSTM (LSTMWeights {..}) (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0)
      where
    -- Read a @b x a@ matrix. The element count is validated before calling
    -- 'LA.reshape': hmatrix's reshape can call 'error' when the count does
    -- not fit the column width. That would escape as an exception instead of
    -- failing inside 'Get'.
    u :: forall a b. (KnownNat a, KnownNat b) => Get  (L b a)
    u = do
      let cols = fromIntegral $ natVal (Proxy :: Proxy a)
          rows = fromIntegral $ natVal (Proxy :: Proxy b)
      xs <- getListOf get
      if length xs /= rows * cols
        then fail "Vector of incorrect size"
        else maybe (fail "Vector of incorrect size") return . create . LA.reshape cols $ LA.fromList xs
    -- Read a vector; 'create' rejects a wrong length.
    v :: forall a. (KnownNat a) => Get (R a)
    v = maybe (fail "Vector of incorrect size") return . create . LA.fromList =<< getListOf get

    -- Zero momentums, one binding per field type.
    w0 = konst 0
    u0 = konst 0
    v0 = konst 0