{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}

-- |
-- Module:     AI.Instinct.ConnMatrix
-- Copyright:  (c) 2011 Ertugrul Soeylemez
-- License:    BSD3
-- Maintainer: Ertugrul Soeylemez
--
-- This module provides an efficient connection matrix type.

module AI.Instinct.ConnMatrix
    ( -- * Connection matrix
      ConnMatrix,

      -- * Construction
      buildLayered,
      buildRandom,
      buildZero,

      -- * Accessing
      cmAdd,
      cmDests,
      cmFold,
      cmMap,
      cmSize,

      -- * Modification
      addLayer
    )
    where

import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U
import Control.Applicative
import Control.Arrow
import Data.List (foldl')
import Data.Monoid
import Data.Semigroup (Semigroup (..))
import System.Random.Mersenne
import Text.Printf


-- | A connection matrix is essentially a two-dimensional array of
-- synaptic weights.  It is stored as one 'ConnVector' per destination
-- node, each holding that node's incoming weights.

newtype ConnMatrix = CM { getCM :: V.Vector ConnVector }

-- | Combines matrices with 'cmAdd'.  A 'Semigroup' instance is required
-- for 'Monoid' on modern GHC (base >= 4.11).
--
-- NOTE(review): these instances are not lawful.  'cmAdd' zips both the
-- rows and the incoming vectors, so results are truncated to the shorter
-- operand and 'mempty' (the empty matrix) absorbs rather than acts as an
-- identity.  Its left-biased connectivity also breaks associativity when
-- a middle operand is disconnected on an edge the others connect.

instance Semigroup ConnMatrix where
    (<>) = cmAdd

instance Monoid ConnMatrix where
    mempty  = CM V.empty
    mappend = cmAdd

instance Show ConnMatrix where
    -- The 6-space indent matches the width of the "%4i: " row label, so
    -- that the 9-character column labels sit above their cells.
    show (CM m) = replicate 6 ' ' ++ header ++ rows
        where
        header = concatMap (printf "%9i") [0 .. V.length m - 1]
        -- Built with a right-nested concat; a left fold with (++) would
        -- be quadratic in the number of rows.
        rows =
            concat [ printf "\n%4i: %s" i (show cv)
                   | (i, cv) <- zip [0 :: Int ..] (V.toList m) ]


-- | A connection vector contains the incoming weights of one node.  Each
-- entry is a (connected, weight) pair; a 'False' flag marks an absent
-- edge.

newtype ConnVector = CV { getCV :: U.Vector (Bool, Double) }

instance Show ConnVector where
    show = concatMap cell . U.toList . getCV
        where
        -- Every cell is 9 characters wide so the columns line up.
        cell :: (Bool, Double) -> String
        cell (True, w)  = printf "%9.5f" w
        cell (False, _) = replicate 8 ' ' ++ "."


-- | @addLayer s1 n1 s2 n2@ overwrites the @n1@ nodes starting at index
-- @s1@ so that each of them is fully connected, with random weights
-- between -1 and 1 exclusive, to the @n2@ nodes starting at index @s2@.
-- Each rewritten node's incoming vector consists of @s2@ disconnected
-- entries followed by the @n2@ connected ones.
--
-- NOTE(review): the indices are not range-checked.  If @s1 + n1@ exceeds
-- the matrix size, the resulting matrix grows instead of being rejected.

addLayer :: Int -> Int -> Int -> Int -> ConnMatrix -> IO ConnMatrix
addLayer s1 n1 s2 n2 (CM m') = do
    mt <- getStdGen
    -- m1: nodes before the layer; m3: nodes after the overwritten range.
    let (m1, m3) = second (V.drop n1) (V.splitAt s1 m')
        newRow = do
            ws <- U.replicateM n2 ((True, ) <$> random1 mt)
            return (CV (U.replicate s2 (False, 0) U.++ ws))
    m2 <- V.replicateM n1 newRow
    return (CM (m1 V.++ m2 V.++ m3))


-- | Build a layered connection matrix from a list of layer sizes, where
-- adjacent layers are fully connected: every node of a layer receives
-- input from every node of the preceding layer.  Nodes of the first
-- layer have no incoming connections.
buildLayered :: [Int] -> IO ConnMatrix
buildLayered ls = mkLayer ls 0 0 0 (buildZero size)
    where
    -- @mkLayer layers start prevStart prevSize m@ connects the next layer,
    -- beginning at node @start@, to the @prevSize@ nodes of the previous
    -- layer beginning at @prevStart@.
    mkLayer :: [Int] -> Int -> Int -> Int -> ConnMatrix -> IO ConnMatrix
    mkLayer [] _ _ _ m' = return m'
    mkLayer (l:rest) s1 s2 n2 m' =
        addLayer s1 l s2 n2 m' >>= mkLayer rest (s1 + l) s1 l

    -- Total number of nodes over all layers.
    size :: Int
    size = foldl' (+) 0 ls


-- | Build a completely random connection matrix with the given edge
-- length.  Every edge is connected; the random values will be between
-- -1 and 1 exclusive.

buildRandom :: Int -> IO ConnMatrix
buildRandom size = do
    mt <- getStdGen
    CM <$> V.replicateM size
               (CV <$> U.replicateM size ((True, ) <$> random1 mt))


-- | Build a zero connection matrix.  It will represent a completely
-- disconnected network, where all nodes are isolated: every node gets
-- an empty incoming vector.

buildZero :: Int -> ConnMatrix
buildZero size = CM $ V.replicate size (CV U.empty)


-- | Add two connection matrices.  Note that this function is
-- left-biased in that it will adopt the connectivity of the first
-- connection matrix.  An edge is connected in the result exactly when
-- it is connected in the first argument.  The second argument's weight
-- is only added when the edge is connected there too.  Both the list of
-- nodes and each incoming vector are truncated to the shorter operand.
--
-- You may want to use the 'Monoid' instance instead of this function.

cmAdd :: ConnMatrix -> ConnMatrix -> ConnMatrix
cmAdd (CM cm1) (CM cm2) =
    CM $ V.zipWith (\(CV cv1) (CV cv2) -> CV $ U.zipWith add cv1 cv2) cm1 cm2
    where
    add :: (Bool, Double) -> (Bool, Double) -> (Bool, Double)
    add x@(False, _) _          = x
    add x@(True, _) (False, _)  = x
    add (True, x1) (True, x2)   = (True, x1 + x2)


-- | Strictly fold over the outgoing connections of source node @sk@,
-- including those with zero weight.  The folding function receives the
-- accumulator, the destination node index and the weight.  Destinations
-- not connected to @sk@ (or whose incoming vector is too short) are
-- skipped.

cmDests :: forall b. Int -> (b -> Int -> Double -> b) -> b -> ConnMatrix -> b
cmDests sk f z (CM m) = V.ifoldl' acc z m
    where
    acc :: b -> Int -> ConnVector -> b
    acc s' dk (CV cv) =
        case cv U.!? sk of
            Nothing         -> s'
            Just (False, _) -> s'
            Just (True, w)  -> f s' dk w


-- | Strictly fold over the nonzero inputs of destination node @dk@.
-- The folding function receives the accumulator, the source node index
-- and the weight.  Disconnected entries and connected entries with a
-- weight of exactly zero are skipped.  An out-of-range @dk@ leaves the
-- initial accumulator unchanged, just as 'cmDests' skips missing
-- entries.

cmFold :: Int -> (b -> Int -> Double -> b) -> b -> ConnMatrix -> b
cmFold dk f z (CM m) =
    case m V.!? dk of
        Nothing      -> z
        Just (CV cv) -> U.ifoldl' step z cv
    where
    -- Both conditions matter: a 'False' flag means there is no edge at
    -- all, and a zero weight is not a "nonzero input".
    step s sk (b, w)
        | b && w /= 0 = f s sk w
        | otherwise   = s


-- | Map a function over every connected weight in the matrix.  The
-- function receives the source node index, the destination node index
-- and the current weight.  Disconnected entries are left untouched.
cmMap :: (Int -> Int -> Double -> Double) -> ConnMatrix -> ConnMatrix
cmMap f (CM rows) = CM (V.imap mapRow rows)
    where
    -- Row index = destination node; entry index = source node.
    mapRow dk (CV cv) = CV (U.imap (mapCell dk) cv)

    -- Only connected entries are transformed; the flag is kept as is.
    mapCell dk sk cell@(connected, w)
        | connected = (connected, f sk dk w)
        | otherwise = cell


-- | Edge length of a connection matrix, i.e. its number of nodes.

cmSize :: ConnMatrix -> Int
cmSize = V.length . getCM


-- | Returns a random number between -1 and 1 exclusive.  A random sign
-- is drawn first, then a random magnitude, both from the given
-- generator.

random1 :: MTGen -> IO Double
random1 mt = do
    positive  <- random mt
    magnitude <- random mt
    return $ if positive then magnitude else negate magnitude