-- |
-- Module:     AI.Instinct.Brain
-- Copyright:  (c) 2011 Ertugrul Soeylemez
-- License:    BSD3
-- Maintainer: Ertugrul Soeylemez <es@ertes.de>
--
-- This module provides artificial neural networks.

module AI.Instinct.Brain
    ( -- * Brains
      Brain(..),
      Pattern,

      -- * Initialization
      NetInit(..),
      buildNet,

      -- * High level
      runNet,
      runNetList,

      -- * Low level
      activation,
      netInput,
      netInputFrom,

      -- * Utility functions
      listPat,
      patError
    )
    where

import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U
import AI.Instinct.Activation
import AI.Instinct.ConnMatrix
import Text.Printf


-- | A 'Brain' value is an artificial neural network.

data Brain = Brain
    { brainAct     :: Activation
      -- ^ Activation function of the network.
    , brainConns   :: ConnMatrix
      -- ^ Connection matrix holding the weighted links between neurons.
    , brainInputs  :: Int
      -- ^ Number of input neurons.  Neurons with an index below this
      -- value read their signal directly from the input pattern.
    , brainOutputs :: Int
      -- ^ Number of output neurons.  These are the last neurons of the
      -- network's activation vector.
    }

instance Show Brain where
    -- Renders a summary line, a rule of 72 dashes, and finally the
    -- connection matrix using its own 'Show' instance.
    show (Brain { brainAct = actF, brainConns = cm,
                  brainInputs = il, brainOutputs = ol }) =
        header ++ show cm
        where
        header :: String
        header =
            printf "Neural network: %i input(s), %i output(s), %s\n%s\n"
                   il ol (show actF) rule

        rule :: String
        rule = replicate 72 '-'


-- | Network builder configuration.  See 'buildNet'.

data NetInit
    -- | Recipe for a multi-layer perceptron, i.e. a neural network
    -- made up of neuron layers in which adjacent layers are (in this
    -- case fully) connected.
    = InitMLP
      { mlpActFunc :: Activation
        -- ^ Network's activation function.
      , mlpLayers  :: [Int]
        -- ^ Layer sizes, ordered from the input layer to the output
        -- layer.
      }
    deriving (Read, Show)


-- | A signal pattern: an unboxed vector of 'Double' values.  Patterns
-- are what gets fed into a network (see 'activation' and 'runNet') and
-- what 'runNet' returns as the network's output.

type Pattern = U.Vector Double


-- | Feeds the given input vector into the network and calculates the
-- activation vector.
--
-- The result has one entry per neuron ('cmSize' of the connection
-- matrix).  A neuron whose index is below the number of input neurons
-- takes its activation directly from the input pattern.  Every other
-- neuron applies the network's activation function to the sum, folded
-- by 'cmFold' for that neuron, of each connection weight multiplied by
-- the activation of the connection's source neuron.
--
-- The input pattern is indexed with 'U.!', so it must contain at least
-- as many entries as the network has input neurons.

activation :: Brain -> Pattern -> V.Vector Double
activation (Brain actF cm il _) inP = av
    where
    af = actFunc actF

    actOf :: Int -> Double
    actOf dk
        | dk < il   = inP U.! dk
        -- Source activations are read back from the lazily built result
        -- vector rather than by calling 'actOf' again.  That way every
        -- neuron is evaluated at most once, no matter how many other
        -- neurons depend on it; direct recursion would redo the shared
        -- work for every dependant.
        | otherwise = af $ cmFold dk (\s sk w -> s + w * (av V.! sk)) 0 cm

    -- A boxed vector keeps its elements as unevaluated thunks, which is
    -- what allows the elements to refer back to the vector itself.
    av :: V.Vector Double
    av = V.generate (cmSize cm) actOf


-- | Build a random neural network from the given description.
--
-- The first entry of the layer list becomes the number of input
-- neurons and the last entry the number of output neurons.  The
-- connection matrix is produced by 'buildLayered' from the full layer
-- list.
--
-- Fails with a user error ('IOError') if the layer list is empty, since
-- such a description has neither an input nor an output layer.

buildNet :: NetInit -> IO Brain
buildNet (InitMLP actF ls) =
    case ls of
      [] ->
          -- Reject the description right away instead of handing out a
          -- network whose size fields would crash when first inspected.
          ioError (userError "AI.Instinct.Brain.buildNet: empty layer list")

      (il:_) -> do
          cm <- buildLayered ls
          return Brain { brainAct     = actF,
                         brainConns   = cm,
                         brainInputs  = il,
                         -- 'last' is safe here: the list is non-empty.
                         brainOutputs = last ls }


-- | Turn a list of signal values into a 'Pattern', preserving the
-- order of the elements.

listPat :: [Double] -> Pattern
listPat signals = U.fromList signals


-- | Calculate the net input vector, i.e. the values just before
-- applying the activation function.
--
-- The activation vector is first computed with 'activation' and the
-- net inputs are then derived from it by 'netInputFrom'.  Input
-- neurons report their value from the input pattern unchanged.

netInput :: Brain -> Pattern -> V.Vector Double
netInput b inP = netInputFrom b (activation b inP) inP


-- | Calculate the net input vector from the given activation vector.
--
-- The result has one entry per neuron ('cmSize' of the connection
-- matrix).  Input neurons take their value straight from the input
-- pattern.  Every other neuron gets the sum, folded by 'cmFold', of
-- each connection weight multiplied by the supplied activation of the
-- connection's source neuron.

netInputFrom :: Brain -> V.Vector Double -> Pattern -> V.Vector Double
netInputFrom (Brain _ cm il _) av inP = V.generate (cmSize cm) netIn
    where
    -- One fold step: add the weighted activation of a source neuron.
    addConn :: Double -> Int -> Double -> Double
    addConn acc src weight = acc + weight * (av V.! src)

    netIn :: Int -> Double
    netIn n
        | n >= il   = cmFold n addConn 0 cm
        | otherwise = inP U.! n


-- | The total discrepancy between the two given patterns: the sum of
-- the squared component-wise differences.  Can be used to calculate
-- the total network error.
--
-- Components are paired positionally by 'U.zipWith', so surplus
-- components of a longer pattern do not contribute.

patError :: Pattern -> Pattern -> Double
patError p1 p2 = U.sum (U.map square diffs)
    where
    diffs :: Pattern
    diffs = U.zipWith (-) p1 p2

    square :: Double -> Double
    square d = d * d


-- | Pass the given input pattern through the given neural network and
-- return its output.
--
-- The output consists of the activations of the network's output
-- neurons, which are the last entries of the activation vector
-- computed by 'activation'.

runNet :: Brain -> Pattern -> Pattern
runNet b@(Brain _ cm _ ol) inP = V.convert outputActs
    where
    -- Activations of every neuron in the network.
    allActs = activation b inP

    -- Index of the first output neuron.
    firstOut = cmSize cm - ol

    outputActs = V.drop firstOut allActs


-- | List-based convenience wrapper around 'runNet': the input list is
-- converted to a 'Pattern' and the resulting output pattern is
-- converted back to a list.  If you care for performance, use 'runNet'.

runNetList :: Brain -> [Double] -> [Double]
runNetList b inputs = U.toList outP
    where
    outP = runNet b (U.fromList inputs)