-- |
-- Module:     AI.Instinct.ConnMatrix
-- Copyright:  (c) 2011 Ertugrul Soeylemez
-- License:    BSD3
-- Maintainer: Ertugrul Soeylemez <es@ertes.de>
--
-- This module provides an efficient connection matrix type.

module AI.Instinct.ConnMatrix
    ( -- * Connection matrix
      ConnMatrix,

      -- * Construction
      buildLayered,
      buildRandom,
      buildZero,

      -- * Accessing
      cmAdd,
      cmDests,
      cmFold,
      cmMap,
      cmSize,

      -- * Modification
      addLayer
    )
    where

import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U
import Control.Applicative
import Control.Arrow
import Data.List (foldl')
import Data.Monoid
import System.Random.Mersenne
import Text.Printf


-- | A connection matrix is essentially a two-dimensional array of
-- synaptic weights.  It is stored as one 'ConnVector' per destination
-- node, holding that node's incoming weights.  Rows are not required
-- to have the same length (e.g. rows produced by 'buildZero' are empty).

newtype ConnMatrix = CM { getCM :: V.Vector ConnVector }

instance Monoid ConnMatrix where
    -- The empty matrix (no nodes at all).
    --
    -- NOTE(review): 'cmAdd' takes the connectivity of its left operand,
    -- so @mempty `mappend` m@ is the empty matrix rather than @m@; the
    -- left-identity law does not hold for this instance.  Also, base >=
    -- 4.11 additionally requires a Semigroup instance.
    mempty = CM mempty

    -- Entrywise, left-biased addition; see 'cmAdd'.
    mappend m1 m2 = cmAdd m1 m2

-- Renders a header line of node indices followed by one line per
-- destination node: its index and then its incoming weights (see the
-- 'Show' instance of 'ConnVector').
instance Show ConnMatrix where
    show (CM m) = "      " ++ header ++ rows
        where
        -- One right-aligned column label per node.
        header = concatMap (printf "%9i") $ take (V.length m) [0 :: Int ..]

        -- A right fold keeps the '++' chain right-associated, so building
        -- the string is linear in its length; the previous left fold
        -- re-copied the growing prefix for every row (quadratic).
        rows = V.foldr (++) [] . V.imap (\i -> printf "\n%4i: %s" i . show) $ m


-- | A connection vector contains the incoming weights of one node,
-- indexed by source node.  Each entry pairs a connection flag ('True'
-- means connected) with a weight.

newtype ConnVector = CV { getCV :: U.Vector (Bool, Double) }

-- Every entry occupies a 9-character column: the weight with five
-- decimals when connected, a lone dot otherwise.
instance Show ConnVector where
    show (CV entries) = concatMap cell (U.toList entries)
        where
        cell (connected, w)
            | connected = printf "%9.5f" w
            | otherwise = "        ."


-- | @addLayer s1 n1 s2 n2 m@ overwrites the @n1@ nodes starting from
-- @s1@ so that each of them is fully connected, with random weights
-- from 'random1', to the @n2@ nodes starting from @s2@.  The previous
-- inputs of the overwritten nodes are discarded; all other rows are
-- kept unchanged.
--
-- Each new row consists of @s2@ unconnected entries followed by @n2@
-- connected ones.  If @s1 + n1@ exceeds the current size, the matrix
-- grows accordingly.
--
-- NOTE(review): when @s1@ itself exceeds the current size, the new rows
-- are appended directly after the last existing row, not at index @s1@.

addLayer :: Int -> Int -> Int -> Int -> ConnMatrix -> IO ConnMatrix
addLayer s1 n1 s2 n2 (CM m') = do
    mt <- getStdGen
    let -- Rows before @s1@, and rows after the @n1@ replaced ones.
        (before, after) = second (V.drop n1) (V.splitAt s1 m')

        -- One replacement row: @s2@ unconnected slots, then @n2@
        -- connected slots with fresh random weights.
        newRow = do
            weights <- U.replicateM n2 ((,) True <$> random1 mt)
            return $ CV (U.replicate s2 (False, 0) U.++ weights)

    fresh <- V.replicateM n1 newRow
    return $ CM (before V.++ fresh V.++ after)


-- | Build a layered connection matrix, where adjacent layers are fully
-- connected.  @buildLayered [n0, n1, ...]@ creates @n0 + n1 + ...@
-- nodes, numbered consecutively layer by layer.  Every node of a layer
-- receives an input, with a random weight, from every node of the
-- preceding layer; the first layer has no inputs.

buildLayered :: [Int] -> IO ConnMatrix
buildLayered ls = connect ls 0 0 0 (buildZero total)
    where
    total :: Int
    total = foldl' (+) 0 ls

    -- @connect layers start prevStart prevSize m@ wires the next layer
    -- (beginning at node @start@) to the previous one (@prevSize@ nodes
    -- beginning at @prevStart@), then continues with the rest.
    connect :: [Int] -> Int -> Int -> Int -> ConnMatrix -> IO ConnMatrix
    connect [] _ _ _ m = return m
    connect (n:rest) start prevStart prevSize m = do
        m'' <- addLayer start n prevStart prevSize m
        connect rest (start + n) start n m''


-- | Build a completely random connection matrix with the given edge
-- length: every node is connected to every node, itself included,
-- with a weight drawn from 'random1' (between -1 and 1 exclusive).

buildRandom :: Int -> IO ConnMatrix
buildRandom size = do
    mt <- getStdGen
    let entry = (,) True <$> random1 mt
        row   = CV <$> U.replicateM size entry
    CM <$> V.replicateM size row


-- | Build a zero connection matrix with the given number of nodes.
-- Every row is empty, so the network is completely disconnected and
-- all nodes are isolated.

buildZero :: Int -> ConnMatrix
buildZero n = CM (V.replicate n emptyRow)
    where
    emptyRow = CV U.empty


-- | Add two connection matrices.  This function is left-biased: the
-- result has exactly the shape and connectivity of the first matrix.
-- A weight is summed only where both matrices have a connection.  Every
-- other entry of the first matrix is kept unchanged, including entries
-- (and whole rows) that the second matrix does not have at all.
--
-- You may want to use the 'Monoid' instance instead of this function.

cmAdd :: ConnMatrix -> ConnMatrix -> ConnMatrix
cmAdd (CM cm1) (CM cm2) =
    -- Rows of the first matrix beyond the end of the second are kept
    -- as-is instead of being truncated away by 'V.zipWith'.
    CM $ V.zipWith addRow cm1 cm2 V.++ V.drop (V.length cm2) cm1

    where
    -- Entries present in both rows are combined.  The tail of the first
    -- row that has no counterpart is kept unchanged: a missing entry in
    -- the second row behaves like an unconnected one.
    addRow :: ConnVector -> ConnVector -> ConnVector
    addRow (CV cv1) (CV cv2) =
        CV $ U.zipWith add cv1 cv2 U.++ U.drop (U.length cv2) cv1

    -- Left-biased entry addition: the left entry wins unless both sides
    -- are connected.
    add :: (Bool, Double) -> (Bool, Double) -> (Bool, Double)
    add x@(False, _) _ = x
    add x@(True, _) (False, _) = x
    add (True, x1) (True, x2) = (True, x1 + x2)


-- | @cmDests sk f z m@ strictly folds @f@ over the outgoing connections
-- of source node @sk@.  For every destination row @dk@, in ascending
-- order, that has a connected entry at position @sk@, it calls
-- @f acc dk w@.  Connections with weight zero are included; unconnected
-- entries and rows too short to have position @sk@ are skipped.

cmDests :: forall b. Int -> (b -> Int -> Double -> b) -> b -> ConnMatrix -> b
cmDests sk f z = V.ifoldl' step z . getCM
    where
    step :: b -> Int -> ConnVector -> b
    step acc dk (CV cv) = maybe acc (visit acc dk) (cv U.!? sk)

    visit :: b -> Int -> (Bool, Double) -> b
    visit acc dk (connected, w)
        | connected = f acc dk w
        | otherwise = acc


-- | @cmFold dk f z m@ strictly folds @f@ over the nonzero inputs of node
-- @dk@.  It calls @f acc sk w@ for every source @sk@, in ascending
-- order, that is connected to @dk@ with a nonzero weight @w@.
-- Unconnected entries and zero weights are skipped.
--
-- @dk@ must be a valid node index: @V.!@ raises an error otherwise.

cmFold :: Int -> (b -> Int -> Double -> b) -> b -> ConnMatrix -> b
cmFold dk f z (CM m) =
    U.ifoldl' step z . getCV $ m V.! dk
    where
    -- Previously unconnected entries fell through to @f@; like
    -- 'cmDests' and 'cmMap', only connected entries are visited now.
    step s sk (connected, w)
        | connected && w /= 0 = f s sk w
        | otherwise           = s


-- | Map over the connected entries of the whole matrix.  The function
-- is called as @f sk dk w@, with source index @sk@, destination (row)
-- index @dk@ and current weight @w@, and its result becomes the new
-- weight.  Unconnected entries are left untouched.

cmMap :: (Int -> Int -> Double -> Double) -> ConnMatrix -> ConnMatrix
cmMap f (CM rows) = CM (V.imap mapRow rows)
    where
    mapRow dk (CV cv) = CV (U.imap (mapEntry dk) cv)

    mapEntry dk sk entry@(connected, w)
        | connected = (True, f sk dk w)
        | otherwise = entry


-- | Edge length of a connection matrix, i.e. its number of nodes
-- (rows).

cmSize :: ConnMatrix -> Int
cmSize = V.length . getCM


-- | Returns a random number between -1 and 1 exclusive.  It draws a
-- random sign first and then a random magnitude, both from the given
-- generator.

random1 :: MTGen -> IO Double
random1 mt = do
    positive <- random mt
    magnitude <- random mt
    return $ if positive then magnitude else negate magnitude