----------------------------------------------------
-- |
-- Module     :  AI.Training
-- License    :  GPL
--
-- Maintainer :  Kiet Lam <ktklam9@gmail.com>
--
--
-- This module provides training algorithms to train
-- a neural network given training data.
--
-- Users should prefer LBFGS, which is implemented through
-- custom bindings to the C library liblbfgs.
--
-- GSL's multivariate minimization algorithms are known to be inefficient
-- <http://www.alglib.net/optimization/lbfgsandcg.php#header6>,
-- and LBFGS outperformed them in many of the author's tests.
--
--
----------------------------------------------------


module AI.Training (
  TrainingAlgorithm(..),
  trainNetwork
  ) where

import Numeric.GSL.Minimization
import Data.Packed.Vector
import Data.Packed.Matrix

import AI.Training.Internal
import AI.Signatures
import AI.Calculation
import AI.Network


-- | The training algorithm used by 'trainNetwork'.
--
-- NOTE: These are all batch training algorithms: the cost and gradient
-- are always computed over the whole input and expected-output matrices.
--
-- Besides 'Show', 'Read' and 'Enum', the type derives 'Eq' (so algorithms
-- can be compared) and 'Bounded' (so every algorithm can be listed with
-- @[minBound .. maxBound]@).
data TrainingAlgorithm = GradientDescent   -- ^ GSL steepest descent (hmatrix binding)
                       | ConjugateGradient -- ^ GSL Polak-Ribiere conjugate gradient (hmatrix binding)
                       | BFGS              -- ^ GSL BFGS, @VectorBFGS2@ variant (hmatrix binding)
                       | LBFGS             -- ^ Limited-memory BFGS via a home-made binding to liblbfgs
                         deriving (Show, Read, Eq, Enum, Bounded)


-- Adapts a network cost function to the @Vector Double -> Double@ shape
-- that hmatrix's multivariate minimizers expect. The candidate weight
-- vector is loaded into the network (via 'setWeights') before the cost
-- is evaluated on the given input and expected-output matrices.
vectorWeightToCost :: CostFunction     -- The cost function
                      -> Network       -- The neural network
                      -> Matrix Double -- The input matrix
                      -> Matrix Double -- The expected output matrix
                      -> Vector Double -- The candidate weight vector
                      -> Double        -- Returns the calculated cost
vectorWeightToCost costF nn inMat exMat ws = costF candidate inMat exMat
  where
    -- The network carrying the candidate weights
    candidate = setWeights nn ws


-- Adapts a gradient function to the @Vector Double -> Vector Double@ shape
-- that hmatrix's multivariate minimizers expect. The cost function and its
-- derivative both come from the given cost model. The candidate weight
-- vector is loaded into the network (via 'setWeights') before the gradient
-- is computed.
vectorWeightToGradients :: GradientFunction -- The function that can calculate the
                                            -- gradient vector given a cost model
                           -> Cost          -- The cost model
                           -> Network       -- The neural network
                           -> Matrix Double -- The input matrix
                           -> Matrix Double -- The expected output matrix
                           -> Vector Double -- The candidate weight vector
                           -> Vector Double -- Returns the vector gradients
vectorWeightToGradients gradF cost nn inMat exMat ws =
  gradF costFn costDeriv candidate inMat exMat
  where
    costFn    = getCostFunction cost
    costDeriv = getCostDerivative cost
    -- The network carrying the candidate weights
    candidate = setWeights nn ws


-- | Train the neural network with the given training algorithm, training
-- parameters and training data.
--
-- Minimization starts from the network's current weights. The initial
-- step size and the tolerance passed to the selected minimizer are both
-- fixed at @0.1@.
trainNetwork :: TrainingAlgorithm   -- ^ The training algorithm to use
                -> Cost             -- ^ The cost model of the neural network
                -> GradientFunction -- ^ The function that can calculate the
                                    --   gradients vector
                -> Network          -- ^ The network to be trained
                -> Double           -- ^ The precision of the training with regards
                                    --   to the cost function
                -> Int              -- ^ The maximum number of iterations
                -> Matrix Double    -- ^ The input matrix
                -> Matrix Double    -- ^ The expected output matrix
                -> Network          -- ^ Returns the trained network
trainNetwork algo cost gradF nn prec iterations inMat exMat =
  setWeights nn optimalWeights
  where
    -- Objective and gradient, both expressed as functions of the weights
    costOfWeights = vectorWeightToCost (getCostFunction cost) nn inMat exMat
    gradOfWeights = vectorWeightToGradients gradF cost nn inMat exMat

    -- Fixed settings handed to the minimizer
    firstStep   = 0.1
    minimizeTol = 0.1

    -- Run the chosen minimizer from the network's current weights
    optimalWeights =
      getTrainAlgo algo prec iterations firstStep minimizeTol
                   costOfWeights gradOfWeights (toWeights nn)


-- Auxiliary function for trainNetwork. It runs the minimizer selected by
-- the 'TrainingAlgorithm' with the given precision, iteration limit,
-- initial step size, tolerance, objective, gradient and starting vector,
-- and returns the resulting weight vector. For the GSL-backed algorithms
-- only the first component of 'minimizeVD''s result (the solution) is kept.
getTrainAlgo :: TrainingAlgorithm
                -> Double
                -> Int
                -> Double
                -> Double
                -> (Vector Double -> Double)
                -> (Vector Double -> Vector Double)
                -> Vector Double
                -> Vector Double
getTrainAlgo algo prec iter step tol f df initVec =
  case gslMethod algo of
    Just method -> fst (minimizeVD method prec iter step tol f df initVec)
    Nothing     -> minimizeLBFGS prec iter step tol f df initVec
  where
    -- GSL method backing each algorithm; LBFGS uses the liblbfgs binding
    gslMethod GradientDescent   = Just SteepestDescent
    gslMethod ConjugateGradient = Just ConjugatePR
    gslMethod BFGS              = Just VectorBFGS2
    gslMethod LBFGS             = Nothing