----------------------------------------------------
-- |
-- Module     :  AI.Calculation.Cost
-- License    :  GPL
--
-- Maintainer :  Kiet Lam <ktklam9@gmail.com>
--
--
-- This module provides common cost functions
-- and their derivatives
--
--
----------------------------------------------------


module AI.Calculation.Cost (
  Cost(..),
  getCostFunction,
  getCostDerivative,
) where

import Data.List (foldl')

import Data.Packed.Matrix
import Data.Packed.Vector
import Numeric.Container

import AI.Network
import AI.Signatures


-- | Represents the cost model
-- of the Neural Network.
--
-- The derived instances let callers print, compare and enumerate the
-- available cost models (e.g. @[minBound .. maxBound] :: [Cost]@).
data Cost = MeanSquared -- ^ The mean-squared cost
          | Logistic    -- ^ The logistic (cross-entropy) cost
          deriving (Show, Eq, Enum, Bounded)


-- | Looks up the cost function belonging to a cost model:
-- 'generalCost' specialised with that model's error function
-- (as chosen by 'getErrorFunction').
getCostFunction :: Cost -> CostFunction
getCostFunction cost = generalCost (getErrorFunction cost)


-- | Looks up the cost derivative belonging to a cost model.
getCostDerivative :: Cost -> CostDerivative
getCostDerivative cost =
  case cost of
    MeanSquared -> meanSquaredDerivative
    Logistic    -> logisticDerivative


-- | The general cost function. It becomes a concrete cost function by
-- partially applying an 'ErrorFunction' (see 'getCostFunction').
--
-- The network output is computed by forward propagation:
--
-- * A bias column of ones is prepended to the input matrix.
-- * For each weight matrix in turn, the current activations are
--   multiplied by it and passed element-wise through the network's
--   activation function. They are then given a new bias column.
-- * The bias column of the last layer is removed before the output is
--   compared with the expected output.
--
-- The result is the per-row error averaged over the rows (training sets)
-- of the expected-output matrix. The regularization term
-- @lambda / (2 * n) * (sum of squares of 'toWeights')@ is added to it.
-- With zero training sets both terms divide by zero, so the result is NaN.
generalCost :: ErrorFunction    -- The error function to be used
               -> Network       -- The neural network of interest
               -> Matrix Double -- The matrix of the inputs, where the ith row
                                -- is the input vector of a training set
               -> Matrix Double -- The matrix of the expected output, where the ith
                                -- row is the expected output vector of a
                                -- training set
               -> Double        -- Returns the cost by comparing the network's
                                -- output neurons and the expected output matrix
generalCost errorF nn inMatrix exMatrix =
  let activF = toActivation nn -- The activation function
      ws = toWeightMatrices nn -- The list of weight matrices
      la = toLambda nn         -- The regularization constant
      n = rows exMatrix        -- The number of training sets

      -- Prepend a column of ones (the bias neurons) to a matrix.
      -- The column length is taken from the matrix itself, so it always
      -- fits, even if inMatrix and exMatrix disagree on the row count.
      fBias m = fromColumns (fromList (replicate (rows m) 1) : toColumns m)

      -- One step of forward propagation: weight the previous layer's
      -- activations, apply the activation function, add the bias column.
      f a w = fBias (mapMatrix activF (a `multiply` w))

      -- The input layer together with its bias column
      inMatrix' = fBias inMatrix

      -- Fold f strictly over the weight matrices; the final accumulator
      -- is the output layer. fBias always adds a leading column, so
      -- 'drop 1' removes exactly the output layer's bias column.
      outMatrix = (fromColumns . drop 1 . toColumns) (foldl' f inMatrix' ws)

      -- The error of each training set: the network's output row
      -- compared with the expected output row via errorF
      errorVec = fromList [ errorF oVec eVec
                          | (oVec, eVec) <- zip (toRows outMatrix) (toRows exMatrix) ]

      -- The un-regularized cost: the average error over the training sets
      j = (1 / fromIntegral n) * sumElements errorVec

      -- All weights flattened into one vector, for regularization
      wsFlattened = toWeights nn
  in
   -- Finally, regularize the cost with the lambda constant
   j + (la / (2.0 * fromIntegral n)) * sumElements (mapVector (**2) wsFlattened)


-- | Builds an 'ErrorFunction' from a per-element error function.
--
-- The per-element function receives one calculated value and the
-- matching expected value. It is applied pair-wise across the two
-- vectors (via 'zipVectorWith'). The resulting errors are summed and
-- divided by the length of the calculated vector, which gives their average.
generalErrorFunction :: (Double -> Double -> Double) -- The function to calculate an "error"
                                                     -- Given a calculated value and an
                                                     -- expected value
                        -> ErrorFunction             -- Returns the error function
generalErrorFunction errF calVec exVec = averagingFactor * totalError
  where
    -- One over the number of calculated values
    averagingFactor = 1 / fromIntegral (dim calVec)

    -- Sum of the pair-wise errors
    totalError = sumElements (zipVectorWith errF calVec exVec)


-- | Returns the 'ErrorFunction' of a cost model: that model's
-- per-element error, lifted to vectors by 'generalErrorFunction'.
getErrorFunction :: Cost -> ErrorFunction
getErrorFunction cost = generalErrorFunction (elementError cost)
  where
    -- The per-element error used by each cost model
    elementError MeanSquared = meanSquaredError
    elementError Logistic    = logisticError


-- | The squared error of one calculated value against one expected
-- value, i.e. the square of their difference. Averaging into a mean
-- happens when this is used through 'generalErrorFunction'.
meanSquaredError :: Double -> Double -> Double
meanSquaredError cal ex = squared (cal - ex)
  where
    squared = (** 2)


-- | The derivative of the mean squared error.
--
-- Computes @2 * (df inMat) .* (actMat - exMat)@, where @df@ is the
-- network's 'derivative' field. Both 'mul' and 'sub' come from
-- "Numeric.Container"; 'mul' there is presumed to be the element-wise
-- product, since matrix products in this file use 'multiply'.
-- NOTE(review): presumably inMat holds the output layer's
-- pre-activation values and actMat its activations; confirm against
-- the back-propagation caller.
meanSquaredDerivative :: CostDerivative
meanSquaredDerivative (Network {derivative = activDeriv}) inMat actMat exMat =
  mapMatrix (* 2) (slopes `mul` residuals)
  where
    -- Activation derivative evaluated at every input entry
    slopes = mapMatrix activDeriv inMat

    -- Difference between calculated and expected values
    residuals = actMat `sub` exMat


-- | The logistic error function, also known as the cross-entropy error
-- function, for one calculated value and one expected value:
--
-- @-(ex * log cal + (1 - ex) * log (1 - cal))@
--
-- A term whose coefficient (@ex@ or @1 - ex@) is zero contributes
-- nothing, following the usual convention @0 * log 0 = 0@. Without this,
-- a saturated but correct prediction (e.g. @cal == 1@, @ex == 1@) would
-- compute @0 * log 0 = 0 * (-Infinity) = NaN@. That NaN would then poison
-- the whole summed cost. A confidently wrong prediction still yields
-- +Infinity.
logisticError :: Double -> Double -> Double
logisticError cal ex =
  negate (xLogY ex cal + xLogY (1 - ex) (1 - cal))
  where
    -- x * log y, with a zero coefficient contributing exactly 0
    xLogY x y
      | x == 0    = 0
      | otherwise = x * log y


-- | The derivative of the logistic (cross-entropy) error: the
-- activation matrix minus the expected-output matrix. The network and
-- the second matrix argument are ignored.
-- NOTE(review): presumably the activation-derivative factor cancels
-- here because the output layer is a sigmoid; confirm with the caller.
logisticDerivative :: CostDerivative
logisticDerivative _network _inMat actual expected =
  actual `sub` expected