----------------------------------------------------------
-- |
-- Module     :  AI.Network
-- License    :  GPL
--
-- Maintainer :  Kiet Lam <ktklam9@gmail.com>
--
--
-- Efficient representation of an Artificial Neural Network,
-- using a single vector to store the weights between consecutive layers.
--
-- This module provides the neural network data representation
-- that will be used extensively.
--
--
---------------------------------------------------------


module AI.Network (
  Network(..),
  toActivation, toDerivative,
  toLambda, toWeights,
  toWeightMatrices, toArchitecture,
  setActivation, setDerivative,
  setLambda, setWeights,
  setArchitecture
  ) where

import Data.Packed.Vector
import Data.Packed.Matrix


-- | An Artificial Neural Network: the per-neuron activation function and its
-- derivative, a regularization constant, all inter-layer weights packed into
-- one vector, and the size of every layer.
data Network = Network
  { activation   :: Double -> Double -- ^ The activation function applied by
                                     --   each neuron
  , derivative   :: Double -> Double -- ^ The derivative of the activation
                                     --   function
  , lambda       :: Double           -- ^ The regularization constant
  , weights      :: Vector Double    -- ^ The weights between each pair of
                                     --   layers, stored in a single vector
  , architecture :: [Int]            -- ^ The number of neurons in each layer.
                                     --
                                     --   e.g., a 2-3-1 network has the
                                     --   architecture @[2,3,1]@.
                                     --
                                     --   NOTE: The library automatically adds
                                     --   a bias neuron to each layer, so do
                                     --   not include it explicitly here.
  }


-- | The activation function used by each neuron of the network.
toActivation :: Network -> (Double -> Double)
toActivation = activation

-- | The derivative of the network's activation function.
toDerivative :: Network -> (Double -> Double)
toDerivative = derivative

-- | The network's regularization constant.
toLambda :: Network -> Double
toLambda = lambda

-- | Every weight of the network, packed into a single flat vector.
-- See 'toWeightMatrices' for a per-layer view of the same data.
toWeights :: Network -> Vector Double
toWeights = weights

-- | Split the flat weight vector into one matrix per pair of consecutive
-- layers. This is often more convenient than the bare vector returned by
-- 'toWeights'.
--
-- For consecutive layers of sizes @from@ and @to@, the next
-- @(from + 1) * to@ weights are reshaped into a matrix with @to@ columns
-- (and therefore @from + 1@ rows, the extra row covering the bias neuron).
--
-- An architecture with fewer than two layers yields @[]@. 'takesV' raises an
-- error if the weight vector is shorter than the architecture requires.
toWeightMatrices :: Network -> [Matrix Double]
toWeightMatrices (Network {weights = ws, architecture = arch}) =
  zipWith reshape outSizes (takesV counts ws)
  where
    -- (from, to) sizes of each pair of adjacent layers; 'drop 1' (unlike
    -- 'tail') is total, so an empty architecture no longer crashes.
    layerPairs = zip arch (drop 1 arch)
    -- Column count of each matrix: the size of the receiving layer.
    outSizes = map snd layerPairs
    -- Number of weights per matrix; +1 for the sending layer's bias neuron.
    counts = [(from + 1) * to | (from, to) <- layerPairs]

-- | The number of neurons in each layer, e.g. @[2,3,1]@ for a 2-3-1 network.
toArchitecture :: Network -> [Int]
toArchitecture = architecture


-- | Return a copy of the network with its activation function replaced.
--
-- Only 'activation' changes. 'derivative' is left as it was, so also call
-- 'setDerivative' if it no longer matches the new function.
setActivation :: Network -> (Double -> Double) -> Network
-- Record-update syntax keeps every other field automatically, so this
-- setter does not need editing when fields are added to 'Network'.
setActivation network f = network {activation = f}

-- | Return a copy of the network with the derivative of its activation
-- function replaced. All other fields are kept unchanged.
setDerivative :: Network -> (Double -> Double) -> Network
-- Record-update syntax keeps every other field automatically.
setDerivative network df = network {derivative = df}

-- | Return a copy of the network with its regularization constant replaced.
-- All other fields are kept unchanged.
setLambda :: Network -> Double -> Network
-- Record-update syntax keeps every other field automatically.
setLambda network la = network {lambda = la}

-- | Return a copy of the network with its flat weight vector replaced.
--
-- The new vector's length is not checked against 'architecture'. A
-- mismatch only surfaces later, e.g. in 'toWeightMatrices'.
setWeights :: Network -> Vector Double -> Network
-- Record-update syntax keeps every other field automatically.
setWeights network w = network {weights = w}

-- | Return a copy of the network with its layer sizes replaced.
--
-- The weight vector is kept as-is. It is not resized or regenerated to fit
-- the new architecture.
setArchitecture :: Network -> [Int] -> Network
-- Record-update syntax keeps every other field automatically.
setArchitecture network a = network {architecture = a}