module Hopfield.Measurement (
BasinMeasure
, hammingDistribution
, sampleHammingRange
, sampleHammingDistance
, samplePatternRing
, samplePatternBasin
, measurePatternBasin
, checkFixed
, measureError
) where
import Control.Monad (liftM, replicateM)
import Control.Monad.Random (MonadRandom)
import Data.List
import Data.Maybe
import qualified Data.Vector as V
import Math.Combinatorics.Exact.Binomial (choose)
import Numeric.Probability.Distribution (Spread, relative)
import Numeric.Probability.Random (T, pick)
import Hopfield.Hopfield
import Hopfield.Util ((./.), toArray, shuffle, runT)
-- | A measurement taken on one stored pattern's basin of attraction: given
-- the network data and the pattern, produce a result of type @a@ in the
-- (typically random) monad @m@. Used by 'samplePatternBasin' and
-- 'measurePatternBasin'.
type BasinMeasure m a = HopfieldData -> Pattern -> m a
-- | @hammingDistribution n (mini, maxi)@ is a random variable over the
-- Hamming distances @r@ in the inclusive range @[mini, maxi]@. Each @r@ is
-- weighted proportionally to @n `choose` r@, the number of length-@n@
-- patterns lying at exactly distance @r@ from a fixed pattern.
hammingDistribution :: Int -> (Int, Int) -> T Int
hammingDistribution n (mini, maxi) = pick (weighted distances)
  where
    distances = enumFromTo mini maxi
    counts    = map (fromIntegral . choose n) distances
    -- 'relative' normalises the raw counts into probabilities.
    weighted  = relative counts :: Spread Double Int
-- | Samples a single pattern whose Hamming distance from @pat@ is first
-- drawn from @dist@ (for example one built by 'hammingDistribution').
sampleHammingRange :: MonadRandom m => Pattern -> T Int -> m Pattern
sampleHammingRange pat dist = do
  r <- runT dist
  samples <- sampleHammingDistance pat r 1
  -- A refutable pattern in a do-bind (@(s:_) <- ...@) would demand a
  -- 'MonadFail' constraint on GHC >= 8.8, which 'MonadRandom' does not give.
  -- Exactly one sample is requested, so the empty branch is unreachable.
  case samples of
    (sample:_) -> return sample
    []         -> error "sampleHammingRange: sampleHammingDistance returned no samples"
-- | @sampleHammingDistance pat r numSamples@ draws @numSamples@ patterns at
-- Hamming distance @r@ from @pat@.
--
-- For each sample, a shuffled list of @r@ copies of @-1@ and @n - r@ copies
-- of @1@ (with @n@ the length of @pat@) is multiplied element-wise into
-- @pat@. This negates exactly @r@ randomly chosen entries.
--
-- Assumes entries of @pat@ are ±1 and @0 <= r <= n@; TODO confirm with
-- callers.
sampleHammingDistance :: MonadRandom m => Pattern -> Int -> Int -> m [Pattern]
sampleHammingDistance pat r numSamples
  = liftM (map (V.fromList . multByPat)) coeffSamples
  where
    n = V.length pat
    -- r sign flips (-1) and n - r untouched positions (1); the shuffle picks
    -- which positions get flipped.
    basePerm = toArray $ replicate r (-1) ++ replicate (n - r) 1
    coeffSamples = replicateM numSamples $ shuffle basePerm
    multByPat coeffs = zipWith (*) coeffs (V.toList pat)
-- | @samplePatternRing hs pat r@ draws 100 patterns at Hamming distance @r@
-- from @pat@. It runs 'repeatedUpdate' with the network's weights on each,
-- and returns the fraction whose result equals @pat@.
samplePatternRing :: MonadRandom m => HopfieldData -> Pattern -> Int -> m Double
samplePatternRing hs pat r = do
  probes <- sampleHammingDistance pat r 100
  settled <- mapM (repeatedUpdate (weights hs)) probes
  let hits  = length [ p | p <- settled, p == pat ]
      total = length probes
  return (hits ./. total)
-- | Restoration rates of @pat@ for every Hamming distance from 1 up to the
-- pattern length (see 'samplePatternRing'). Element @i@ of the result
-- corresponds to distance @i + 1@.
samplePatternBasin :: (MonadRandom m) => BasinMeasure m [Double]
samplePatternBasin hs pat = mapM ring distances
  where
    ring      = samplePatternRing hs pat
    distances = [1 .. V.length pat]
-- | Estimated basin radius of @pat@: the number of leading distances
-- (starting at 1) whose sampled restoration rate is not below 0.9.
-- If no distance falls below 0.9, the pattern length is returned.
measurePatternBasin :: (MonadRandom m) => BasinMeasure m Int
measurePatternBasin hs pat = liftM radius (samplePatternBasin hs pat)
  where
    -- Index i of the first failing rate corresponds to distance i + 1, so
    -- distances 1..i all passed and i is the radius.
    radius rates = fromMaybe (V.length pat) (findIndex (< 0.9) rates)
-- | @compTerm hs index n@ is the error term for neuron @n@ of the stored
-- pattern at position @index@ in @patterns hs@:
--
-- > -(x_n) * (computeH w x n - x_n)
--
-- 'checkFixed' treats a term greater than 1 as that neuron not being stable.
--
-- NOTE(review): if 'computeH' returns the ±1 sign of the local field, the
-- term is 0 when the neuron agrees with the pattern and 2 when it would
-- flip. Confirm against Hopfield.Hopfield.
--
-- Partial: @index@ must be a valid position in @patterns hs@ (uses '!!').
compTerm :: HopfieldData -> Int -> Int -> Int
compTerm hs n' n = negate ((pat V.! n) * (computeH (weights hs) pat n - pat V.! n))
  where
    pat = patterns hs !! n'
-- | @checkFixed hs index@ is 'True' when every neuron of the stored pattern
-- at position @index@ has a 'compTerm' of at most 1, i.e. the pattern is
-- judged to be a fixed point. An empty pattern is vacuously fixed.
checkFixed :: HopfieldData -> Int -> Bool
checkFixed hs index = all (\i -> compTerm hs index i <= 1) [0 .. len - 1]
  where
    len = V.length (patterns hs !! index)
-- | Fraction of the stored patterns that are not fixed points according to
-- 'checkFixed'.
--
-- NOTE(review): with no stored patterns this computes @0 ./. 0@. It is
-- presumably NaN, depending on '(./.)'; confirm whether callers can hit it.
measureError :: HopfieldData -> Double
measureError hs = num_errors ./. num_pats
  where
    fixed_points = map (checkFixed hs) [0 .. num_pats - 1]
    num_errors   = length $ filter not fixed_points
    num_pats     = length $ patterns hs