-- |
-- Module     :  AI.Calculation.Gradients
-- License    :  GPL
-- Maintainer :  Kiet Lam <ktklam9@gmail.com>
-- This module provides ways to calculate the gradient
-- vector of the weights of the neural network.
-- Backpropagation should always be preferred over
-- the numerical gradient method.

module AI.Calculation.Gradients (
  ) where

import Numeric.Container

import AI.Signatures
import AI.Network

-- | Calculate the analytical gradient of the weights of the network
-- by using backpropagation.
--
-- Arguments, in the order 'GradientFunction' supplies them:
--
-- * an argument this function ignores (presumably the cost function,
--   which backpropagation does not need; NOTE(review): confirm against
--   "AI.Signatures");
--
-- * @outputDeltasF@, which is given the network, the output layer's
--   weighted inputs, the output layer's activations (both without the
--   bias column) and the expected outputs, and must return the
--   output-layer deltas;
--
-- * the network @nn@;
--
-- * @inMatrix@, one training input per row;
--
-- * @exMatrix@, the expected output of each training input, one per row.
--
-- Returns the gradient averaged over the training set, plus the
-- @(lambda / n)@-scaled weights as regularization, flattened into a
-- single vector with the first layer's weight matrix first.
backpropagation :: GradientFunction
backpropagation _ outputDeltasF nn inMatrix exMatrix =
  let n = rows exMatrix -- Number of training examples

      -- Important information from the neural network
      activF = toActivation nn
      df = toDerivative nn
      ws = toWeightMatrices nn
      la = toLambda nn

      -- Prepend a column of ones (the bias neurons) to a matrix
      fBias = fromColumns . (fromList (replicate n 1) :) . toColumns

      -- Drop the bias column again
      fRemoveBias = fromColumns . tail . toColumns

      -- One step of forward propagation from the (biased) activations p
      -- of a layer through the weight matrix w. Besides the activations of
      -- the next layer we also return its weighted inputs, so that
      -- activation functions with non-trivial derivatives can be
      -- differentiated at the weighted input. Both results carry a bias
      -- column. The product is computed once and shared.
      f p w = let z = p `multiply` w
              in (fBias (mapMatrix activF z), fBias z)

      -- Prepare the inMatrix to be used
      inMatrix' = fBias inMatrix

      -- Forward propagation as a left fold over the weight matrices,
      -- accumulating (activations, weighted inputs) for every layer.
      -- The list ends up in reverse order: the output layer is at the head
      -- and the input layer is last. The input layer has no weighted
      -- inputs, so the input itself fills that slot.
      activations = foldl (\(a@(p,_):rest) w -> f p w : a : rest) [(inMatrix',inMatrix')] ws

      -- The user-supplied function computes the output-layer deltas from
      -- the output layer's weighted inputs and activations (bias removed)
      -- and the expected outputs.
      initialDeltas = outputDeltasF nn ((fRemoveBias . snd . head) activations)
                      ((fRemoveBias . fst . head) activations) exMatrix

      -- Backpropagation: because the layers are stored in reverse order,
      -- folding over the reversed weights pushes the deltas from the
      -- output layer towards the input layer. Each step computes
      --   delta_{l-1} = removeBias ((delta_l * trans W_l) .* df z_{l-1})
      -- where df is applied to the weighted inputs, not to the
      -- activations. For a sigmoid, this uses
      -- sigmoid z * (1 - sigmoid z) rather than a * (1 - a).
      -- The result is in forward order: [delta_input, ..., delta_output].
      -- The input-layer delta is dropped by the zip below and, being lazy,
      -- is never evaluated.
      allDeltas = foldl (\(d:ds) (zs,w) -> (fRemoveBias ((d `multiply` trans w) `mul` mapMatrix df zs)) : d : ds)
                  [initialDeltas] (zip ((tail . map snd) activations) (reverse ws))

      -- Gradient of each weight matrix averaged over the training set,
      -- from the output layer backwards. Summing the per-example outer
      -- products (row r of A) `outer` (row r of D) over all rows r is
      -- exactly trans A `multiply` D. One matrix product therefore replaces
      -- n outer products, n - 1 matrix additions and an explicit zero
      -- matrix.
      gradsSums = [ mapMatrix (/ fromIntegral n) (trans as `multiply` deltas)
                  | (as, deltas) <- zip ((tail . map fst) activations) (reverse allDeltas) ]

   -- Flatten the gradients (first layer first) into one vector and add
   -- regularization.
   -- NOTE(review): every entry of toWeights nn is penalised, including the
   -- bias weights; confirm this matches the cost function.
  in zipVectorWith (\x y -> x + (la / fromIntegral n) * y)
                   ((join . map flatten) (reverse gradsSums))
                   (toWeights nn)

-- | Approximate the gradient of the cost with respect to every weight by
-- central finite differences:
--
-- > dJ/dw_i ~ (J(w + e*u_i) - J(w - e*u_i)) / (2*e)
--
-- Here @u_i@ is the @i@-th unit vector and @e@ is the supplied epsilon.
--
-- NOTE: This should only be used as a last resort, for example to check
-- that 'backpropagation' gives sensible gradients. It evaluates the cost
-- function twice per weight, i.e. @2 * dim (toWeights nn)@ cost
-- evaluations (each presumably a full forward propagation). That is far
-- more expensive than backpropagation's single forward/backward pass.
-- Analytical gradients are also generally more accurate.
--
-- A small epsilon gives a closer approximation. One that is too small
-- loses precision to floating-point cancellation.
numericalGradients :: Double              -- ^ The epsilon
                      -> GradientFunction -- ^ Returns a gradient function
                                          --   that calculates the numerical
                                          --   gradients of the weights
numericalGradients epsilon costF _ nn inMat exMat =
  let -- Vector representation of all the weights
      params = toWeights nn

      -- Cost of the network once its weights are replaced by ws
      cost ws = costF (setWeights nn ws) inMat exMat

      -- Costs obtained by applying g to each weight in turn, one entry per
      -- weight. The perturbed weight vectors are produced lazily, index by
      -- index, instead of first being stored as rows of an n x n matrix.
      -- This keeps memory from growing quadratically with the number of
      -- weights, and an empty weight vector simply yields an empty result.
      costsWith g = fromList [cost (modifyElementAt g params i) | i <- [0 .. dim params - 1]]

      cost1 = costsWith (+ epsilon)        -- f(x + e) for every weight
      cost2 = costsWith (subtract epsilon) -- f(x - e) for every weight
   -- Use (f(x+e) - f(x-e)) / (2*e) to get the gradients
  in mapVector (/ (2 * epsilon)) (cost1 `sub` cost2)

-- | Build a matrix by calling @f vec i@ once for every index @i@ of @vec@
-- and stacking the results as rows, so row @i@ is @f vec i@.
-- With 'modifyElementAt', for instance, row @i@ is @vec@ with only its
-- @i@-th element transformed.
--
-- The concatenated rows are reshaped to width @dim vec@. Each @f vec i@ is
-- therefore expected to have the same length as @vec@, giving a square
-- matrix.
-- NOTE(review): for an empty @vec@ this asks 'reshape' for width 0;
-- confirm how the hmatrix version in use handles that.
mapElementToVector :: (Vector Double -> Int -> Vector Double)
                      -> Vector Double
                      -> Matrix Double
mapElementToVector f vec = reshape n flattened
  where
    -- Size of the vector, which is also the width of every row
    n = dim vec

    -- Apply f once per index and join the resulting rows into one
    -- long vector
    flattened = join [f vec i | i <- [0 .. n - 1]]

-- | Return a copy of the vector in which only the element at the given
-- index has been passed through @f@; every other element is kept as is.
-- If no element has that index, the vector comes back unchanged.
modifyElementAt :: (Double -> Double) -- ^ The function to be applied
                   -> Vector Double   -- ^ The vector of interest
                   -> Int             -- ^ The index of the element to be modified
                   -> Vector Double   -- ^ The resulting modified vector
modifyElementAt f vec index = mapVectorWithIndex step vec
  where
    -- Transform the element at the targeted position, pass the rest through
    step i x
      | i == index = f x
      | otherwise  = x

-- | Function form of @if ... then ... else@: @if' c t e@ yields @t@ when
-- @c@ is 'True' and @e@ otherwise. Only the selected branch is evaluated.
if' :: Bool -> a -> a -> a
if' cond whenTrue whenFalse
  | cond      = whenTrue
  | otherwise = whenFalse