{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}

-- | Connection matrices.
--
-- A 'ConnMatrix' is a vector of rows; each row is an unboxed vector of
-- @(connected, weight)@ cells.  This module exports constructors
-- ('buildLayered', 'buildRandom', 'buildZero', 'addLayer') and traversals
-- ('cmAdd', 'cmDests', 'cmFold', 'cmMap', 'cmSize') over that structure.
--
-- The two extensions above are required by code in this module: tuple
-- sections such as @(True, )@ and a local type signature that refers to a
-- type variable bound by an explicit @forall@.
module AI.Instinct.ConnMatrix
  ( ConnMatrix
  , buildLayered
  , buildRandom
  , buildZero
  , cmAdd
  , cmDests
  , cmFold
  , cmMap
  , cmSize
  , addLayer
  ) where
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U
import Control.Applicative
import Control.Arrow
import Data.List (foldl')
import Data.Monoid
import System.Random.Mersenne
import Text.Printf
-- | A connection matrix, stored as a boxed vector of 'ConnVector' rows.
-- 'CM' wraps the row vector and 'getCM' unwraps it.
newtype ConnMatrix = CM
  { getCM :: V.Vector ConnVector
  }
-- | Matrices combine with 'cmAdd'.  'Semigroup' is a superclass of
-- 'Monoid' since base 4.11, so the combining operation is defined here
-- and the 'Monoid' instance below delegates to it.
instance Semigroup ConnMatrix where
  (<>) = cmAdd

-- | 'mempty' is the matrix with no rows; 'mappend' is '<>'.
--
-- NOTE(review): 'cmAdd' zips the two row vectors, so combining any matrix
-- with 'mempty' looks like it yields a matrix with no rows rather than the
-- other operand.  Confirm whether any caller relies on the identity law.
instance Monoid ConnMatrix where
  mempty = CM V.empty
  mappend = (<>)
-- | Render the matrix as a text table: one header line of column indices
-- followed by one line per row, each prefixed with its row index.
instance Show ConnMatrix where
  show (CM m) = replicate 6 ' ' ++ header ++ rows
    where
      -- Column indices, one 9-character field each, which is the same
      -- width as a "%9.5f" weight cell.  The 6-space prefix above has the
      -- same width as the "%4i: " row prefix so the columns line up.
      header :: String
      header = concatMap (printf "%9i") (take (V.length m) [0 :: Int ..])

      -- All row lines concatenated in index order.
      rows :: String
      rows = concat (V.toList (V.imap row m))

      -- A newline, the 4-character row index, ": ", then the row itself.
      row :: Int -> ConnVector -> String
      row i cv = printf "\n%4i: %s" i (show cv)
-- | One row of a 'ConnMatrix': an unboxed vector of cells, each pairing a
-- 'Bool' flag with a 'Double' weight.  'CV' wraps the cells and 'getCV'
-- unwraps them.
newtype ConnVector = CV
  { getCV :: U.Vector (Bool, Double)
  }
-- | Render a row as a sequence of 9-character cells.  A cell whose flag is
-- 'True' shows its weight with five decimals; a cell whose flag is 'False'
-- shows a right-aligned dot.
instance Show ConnVector where
  show (CV cells) = concatMap cell (U.toList cells)
    where
      cell :: (Bool, Double) -> String
      cell (True, w) = printf "%9.5f" w
      -- Pad the placeholder to the same 9-character width as a weight so
      -- that columns stay aligned across rows.
      cell (False, _) = printf "%9s" "."
-- | Replace a run of rows in a matrix with freshly randomised rows.
--
-- @addLayer s1 n1 s2 n2 m@ keeps the first @s1@ rows of @m@, drops the
-- next @n1@ rows, and inserts @n1@ new rows in their place.  Each new row
-- consists of @s2@ cells @(False, 0)@ followed by @n2@ cells
-- @(True, w)@, where every @w@ is drawn with 'random1' from the generator
-- returned by 'getStdGen'.
addLayer :: Int -> Int -> Int -> Int -> ConnMatrix -> IO ConnMatrix
addLayer s1 n1 s2 n2 (CM rows) = do
  gen <- getStdGen
  let (before, rest) = V.splitAt s1 rows
      after = V.drop n1 rest
      -- Leading cells of every new row: flagged False with weight 0.
      padding = U.replicate s2 (False, 0)
      -- One new row: the padding followed by n2 random flagged cells.
      freshRow = do
        weights <- U.replicateM n2 (fmap (\w -> (True, w)) (random1 gen))
        return (CV (padding U.++ weights))
  fresh <- V.replicateM n1 freshRow
  return (CM (V.concat [before, fresh, after]))
-- | Build a matrix from a list of layer sizes.
--
-- Starts from 'buildZero' with as many rows as the sum of the sizes and
-- calls 'addLayer' once per layer, in order.  For each layer the new rows
-- get their flagged cells in the column range occupied by the previous
-- layer; the first layer is added with zero padding and zero weights.
buildLayered :: [Int] -> IO ConnMatrix
buildLayered ls = go ls 0 0 0 (buildZero (sum ls))
  where
    -- Arguments: remaining layer sizes, first row of the current layer,
    -- first row of the previous layer, size of the previous layer, and
    -- the matrix built so far.
    go :: [Int] -> Int -> Int -> Int -> ConnMatrix -> IO ConnMatrix
    go layers start prevStart prevSize acc =
      case layers of
        [] -> return acc
        width : rest -> do
          acc' <- addLayer start width prevStart prevSize acc
          go rest (start + width) start width acc'
-- | Build a @size@ by @size@ matrix in which every cell is flagged 'True'
-- and carries a weight drawn with 'random1' from the generator returned by
-- 'getStdGen'.
buildRandom :: Int -> IO ConnMatrix
buildRandom size = do
  gen <- getStdGen
  let cell = fmap (\w -> (True, w)) (random1 gen)
      row = fmap CV (U.replicateM size cell)
  fmap CM (V.replicateM size row)
-- | Build a matrix of @size@ rows in which every row has no cells at all.
buildZero :: Int -> ConnMatrix
buildZero size = CM (V.replicate size emptyRow)
  where
    emptyRow :: ConnVector
    emptyRow = CV U.empty
-- | Combine two matrices cell by cell, pairing rows and cells positionally
-- with @zipWith@.
--
-- The combination is biased towards the first matrix: when both cells are
-- flagged 'True' the weights are summed; in every other case the cell of
-- the first matrix is kept unchanged.
cmAdd :: ConnMatrix -> ConnMatrix -> ConnMatrix
cmAdd (CM rowsA) (CM rowsB) = CM (V.zipWith addRows rowsA rowsB)
  where
    addRows :: ConnVector -> ConnVector -> ConnVector
    addRows (CV a) (CV b) = CV (U.zipWith addCell a b)

    addCell :: (Bool, Double) -> (Bool, Double) -> (Bool, Double)
    addCell left@(lFlag, lw) (rFlag, rw)
      | lFlag && rFlag = (True, lw + rw)
      | otherwise = left
-- | Left-fold over one column of the matrix.
--
-- For every row index @dk@, the cell at column @sk@ is inspected.  If it
-- exists and is flagged 'True', the accumulator is updated with
-- @f acc dk w@; rows that are too short to have that column, or whose cell
-- is flagged 'False', leave the accumulator untouched.
cmDests :: forall b. Int -> (b -> Int -> Double -> b) -> b -> ConnMatrix -> b
cmDests sk f z (CM m) = V.ifoldl' step z m
  where
    step :: b -> Int -> ConnVector -> b
    step acc dk (CV row)
      | Just (True, w) <- row U.!? sk = f acc dk w
      | otherwise = acc
-- | Left-fold over one row of the matrix.
--
-- For every column index @sk@ of row @dk@, a cell flagged 'True' updates
-- the accumulator with @f acc sk w@; a cell flagged 'False' is skipped.
-- This matches 'cmDests' and 'cmMap', which also act only on flagged
-- cells.  A flagged cell is visited even when its weight is 0.
--
-- The row is looked up with @V.!@, so @dk@ must be a valid row index.
cmFold :: Int -> (b -> Int -> Double -> b) -> b -> ConnMatrix -> b
cmFold dk f z (CM m) = U.ifoldl' step z (getCV (m V.! dk))
  where
    step acc sk (connected, w)
      | connected = f acc sk w
      | otherwise = acc
-- | Rewrite the weight of every cell flagged 'True'.
--
-- The function receives the column index, the row index and the current
-- weight, in that order, and returns the new weight.  Cells flagged
-- 'False' are left exactly as they are.
cmMap :: (Int -> Int -> Double -> Double) -> ConnMatrix -> ConnMatrix
cmMap f (CM rows) = CM (V.imap mapRow rows)
  where
    mapRow :: Int -> ConnVector -> ConnVector
    mapRow dk (CV cells) = CV (U.imap (mapCell dk) cells)

    mapCell :: Int -> Int -> (Bool, Double) -> (Bool, Double)
    mapCell dk sk cell@(connected, w)
      | connected = (connected, f sk dk w)
      | otherwise = cell
-- | Number of rows in the matrix.
cmSize :: ConnMatrix -> Int
cmSize = V.length . getCM
-- | Draw a random signed weight from the given generator.
--
-- Two values are drawn, in this order: a 'Bool' that decides the sign and
-- a 'Double' that gives the magnitude.  When the 'Bool' is 'False' the
-- magnitude is negated, so the drawn sign actually affects the result.
--
-- NOTE(review): the range of the result depends on the range of the
-- generator's 'Double' values; presumably that is [0, 1), which would give
-- weights in (-1, 1) -- confirm against the mersenne-random documentation.
random1 :: MTGen -> IO Double
random1 mt = do
  positive <- random mt
  magnitude <- random mt
  return (if positive then magnitude else negate magnitude)