----------------------------------------------------
-- |
-- Module     : AI.Calculation.Cost
-- License    : GPL
--
-- Maintainer : Kiet Lam
--
--
-- This module provides common cost functions
-- and their derivatives
--
--
----------------------------------------------------

module AI.Calculation.Cost (
  Cost(..),
  getCostFunction,
  getCostDerivative,
) where

import Data.Packed.Matrix
import Data.Packed.Vector
import Numeric.Container

import AI.Signatures
import AI.Network


-- | Represents the cost model
-- of the Neural Network
data Cost = MeanSquared -- ^ The mean-squared cost
          | Logistic    -- ^ The logistic cost


-- | Gets the cost function associated
-- with the cost model
getCostFunction :: Cost -> CostFunction
getCostFunction = generalCost . getErrorFunction


-- | Gets the cost derivative associated
-- with the cost model
getCostDerivative :: Cost -> CostDerivative
getCostDerivative MeanSquared = meanSquaredDerivative
getCostDerivative Logistic    = logisticDerivative


-- The general cost function that can be extended by
-- partial function application
generalCost :: ErrorFunction -- The error function to be used
            -> Network       -- The neural network of interest
            -> Matrix Double -- The matrix of inputs, where the ith row
                             -- is the input vector of a training set
            -> Matrix Double -- The matrix of expected outputs, where the
                             -- ith row is the expected output vector of a
                             -- training set
            -> Double        -- Returns the cost by comparing the network's
                             -- output neurons against the expected output
                             -- matrix
generalCost errorF nn inMatrix exMatrix =
  let activF = toActivation nn     -- The activation function
      ws     = toWeightMatrices nn -- The list of weight matrices
      la     = toLambda nn         -- The regularization constant
      n      = rows exMatrix       -- The number of training sets

      -- Prepend a column of ones so that each forward propagation
      -- step sees a bias neuron
      fBias = \m -> (fromColumns . (fromList (replicate n 1) :) . toColumns) m

      -- One step of forward propagation: multiply the activations by
      -- the weight matrix, apply the activation function element-wise,
      -- and re-attach the bias column. We propagate forward by folding
      -- this over the weight matrices.
      f = \a w -> fBias (mapMatrix activF (a `multiply` w))

      -- Prepare the inMatrix to be used
      inMatrix' = fBias inMatrix

      -- Fold f over the weight matrices, accumulating the activation
      -- matrix. The resulting activation matrix is our output, with
      -- the bias neurons of the output layer dropped.
      outMatrix = (fromColumns . tail . toColumns) $ foldl f inMatrix' ws

      -- Convert outMatrix and exMatrix into lists of rows, zip them
      -- into (oVec, eVec) pairs, and compute the error between each
      -- oVec and eVec using errorF; the result is a list of errors.
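      -- (Illustration, not in the original: with two training sets,
      -- where outMatrix has rows [o1, o2] and exMatrix has rows
      -- [e1, e2], the comprehension below produces the two-element
      -- list [errorF o1 e1, errorF o2 e2] before conversion.)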
      -- Then we create a vector from the list, resulting in errorVec
      errorVec = fromList [errorF oVec eVec |
                           (oVec, eVec) <- zip (toRows outMatrix)
                                               (toRows exMatrix)]

      -- The un-regularized cost: sum the errors and take the
      -- average by dividing by the number of training sets
      j = (1 / fromIntegral n) * sumElements errorVec

      -- The vector representation of the weights,
      -- in preparation for regularization
      wsFlattened = toWeights nn
  in
    -- Finally, regularize the cost using the lambda constant
    j + (la / (2.0 * fromIntegral n)) * (sumElements $ mapVector (**2) wsFlattened)


-- The general error function.
-- It requires a function that calculates an error
-- given a calculated value and an expected value
generalErrorFunction :: (Double -> Double -> Double) -- Calculates an "error"
                                                     -- given a calculated
                                                     -- value and an expected
                                                     -- value
                     -> ErrorFunction                -- Returns the error
                                                     -- function
generalErrorFunction errF calVec exVec =
  let n = dim calVec -- The size of the vectors

      -- The error vector between the calculated vector and the
      -- expected vector, obtained by zipping with the error function
      errorVec = zipVectorWith errF calVec exVec
  in
    -- Take the average of the sum of the errors
    (1 / fromIntegral n) * (sumElements errorVec)


-- Returns the error function for a given cost model
getErrorFunction :: Cost -> ErrorFunction
getErrorFunction MeanSquared = generalErrorFunction meanSquaredError
getErrorFunction Logistic    = generalErrorFunction logisticError


-- The mean-squared error function
meanSquaredError :: Double -> Double -> Double
meanSquaredError cal ex = (cal - ex) ** 2


-- The derivative of the mean-squared error function
-- with respect to each parameter
meanSquaredDerivative :: CostDerivative
meanSquaredDerivative (Network {derivative = df}) inMat actMat exMat =
  mapMatrix (*2) ((mapMatrix df inMat) `mul` (actMat `sub` exMat))


-- The logistic error function,
-- also known as the cross-entropy error function
logisticError :: Double -> Double -> Double
logisticError cal ex =
  (-1) * (ex * (log cal) + (1 - ex) * (log (1 - cal)))


-- The derivative of the logistic error function
-- with respect to each parameter
logisticDerivative :: CostDerivative
logisticDerivative _ _ actMat exMat = actMat `sub` exMat
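

-- A minimal usage sketch (an assumption, not part of the original
-- module). Here 'CostFunction' is taken to be
-- 'Network -> Matrix Double -> Matrix Double -> Double', matching the
-- type of 'generalCost' above, and 'nn', 'ins', 'exs' and 'batchCost'
-- are hypothetical names:
--
-- > batchCost :: Network -> Matrix Double -> Matrix Double -> Double
-- > batchCost nn ins exs = getCostFunction Logistic nn ins exs
-- >
-- > -- Each row of 'ins' (resp. 'exs') is the input (resp. expected
-- > -- output) vector of one training set; 'batchCost' then yields
-- > -- the regularized cross-entropy cost over the whole batch.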