-- |
-- Module:     AI.Instinct.Train.Delta
-- Copyright:  (c) 2011 Ertugrul Soeylemez
-- License:    BSD3
-- Maintainer: Ertugrul Soeylemez <es@ertes.de>
--
-- Delta rule aka backpropagation algorithm.

module AI.Instinct.Train.Delta
    ( -- * Backpropagation training
      TrainPat,
      train,
      trainAtomic,
      trainPat,

      -- * Low level
      learnPat,

      -- * Utility functions
      totalError,
      tpList
    )
    where

import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U
import AI.Instinct.Activation
import AI.Instinct.Brain
import AI.Instinct.ConnMatrix
import Control.Arrow
import Data.List


-- | A training pattern is a tuple of an input pattern and the output
-- pattern the network is expected to produce for it, in that order.

type TrainPat = (Pattern, Pattern)


-- | Calculate the weight deltas for a single training pattern.  The
-- second argument specifies the learning rate.  The result contains
-- only the deltas; they still have to be added to the current weights
-- (as 'trainPat' does via 'cmAdd').

learnPat :: Brain -> Double -> TrainPat -> ConnMatrix
learnPat b@(Brain actF cm _ ol) rate (inP, expP) = cmMap weightDelta cm

    where
    -- Delta of one connection: learning rate times the error term at
    -- the second index times the activation at the first index.  The
    -- third argument handed over by 'cmMap' is ignored.
    -- NOTE(review): presumably the indices are source and destination
    -- neuron and the ignored value is the current weight -- confirm
    -- against 'cmMap'.
    weightDelta src dst _ = rate * (errTerms V.! dst) * (acts V.! src)

    acts      = activation b inP
    netIns    = netInputFrom b acts inP
    nodeCount = cmSize cm

    -- The last @ol@ entries of the activation vector are compared
    -- against the expected pattern; 'firstOut' is the index of the
    -- first of them.
    firstOut  = nodeCount - ol
    netOut    = U.convert (V.drop firstOut acts)

    slopeAt :: Double -> Double
    slopeAt = actDeriv actF

    -- One error term per index.  The vector refers to itself: entries
    -- below 'firstOut' are folded from other entries via 'cmDests'.
    -- This relies on the lazy elements of boxed vectors and only
    -- terminates when no entry (transitively) depends on itself.
    errTerms :: V.Vector Double
    errTerms = V.generate nodeCount errTerm

    errTerm :: Int -> Double
    errTerm k
        | k < firstOut = slope * cmDests k backProp 0 cm
        | otherwise    = slope * ((expP U.! outIx) - (netOut U.! outIx))
        where
        slope = slopeAt (netIns V.! k)
        outIx = k - firstOut
        backProp acc dst w = acc + (errTerms V.! dst) * w


-- | Calculate the total error of a neural network with respect to the
-- given list of training patterns:  the network is run on every input
-- pattern and the 'patError' of its result against the expected
-- pattern is summed up.  An empty list yields 0.

totalError :: Brain -> [TrainPat] -> Double
totalError b pats = foldl' addPatErr 0 pats
    where
    addPatErr acc (inP, expP) = acc + patError (runNet b inP) expP


-- | Convenience function:  Construct a training pattern from two
-- lists, the first giving the input values and the second the
-- expected output values.

tpList :: [Double] -> [Double] -> (Pattern, Pattern)
tpList ins outs = (U.fromList *** U.fromList) (ins, outs)


-- | Non-atomic version of 'trainAtomic'.  The patterns are trained one
-- after another with 'trainPat', so each pattern already sees the
-- weights adjusted by the patterns before it, instead of all weights
-- being adjusted at the end of the epoch.

train :: Brain -> Double -> [TrainPat] -> Brain
train brain rate pats = foldl' step brain pats
    where
    step current pat = trainPat current rate pat


-- | Train a list of patterns with the specified learning rate.  The
-- weight deltas of every pattern are calculated against the original,
-- unmodified network and accumulated onto its connection matrix, so
-- the weights are effectively adjusted at the end of the epoch.
-- Returns the updated neural network.

trainAtomic :: Brain -> Double -> [TrainPat] -> Brain
trainAtomic b'@(Brain _ cm' _ _) rate ps = b' { brainConns = trained }
    where
    trained = foldl' addDeltas cm' ps

    -- Deltas always come from the original network b', not from the
    -- matrix accumulated so far.
    addDeltas acc pat = cmAdd acc (learnPat b' rate pat)


-- | Train a single pattern:  the weight deltas calculated by
-- 'learnPat' are added to the network's connection matrix.  The second
-- argument specifies the learning rate.

trainPat :: Brain -> Double -> TrainPat -> Brain
trainPat b@(Brain _ cm _ _) rate pat = b { brainConns = adjusted }
    where
    adjusted = cmAdd cm (learnPat b rate pat)