module Hopfield.TestUtil where
import Control.Applicative
import Control.Monad
import Control.Monad.Random
import Data.Vector ((!))
import qualified Data.Vector as V
import Test.QuickCheck
import Hopfield.Hopfield
import Hopfield.Measurement
import Hopfield.Boltzmann.RestrictedBoltzmannMachine
import Hopfield.Util
-- | Selects which element generator 'patternGen' uses:
-- 'H' draws elements from 'signGen', 'BM' draws them from 'binaryGen'.
data Type
  = H
  | BM
-- | Generates a vector by generating an arbitrary list of elements and
-- packing it with 'V.fromList'.
instance (Arbitrary a) => Arbitrary (V.Vector a) where
  arbitrary = V.fromList <$> arbitrary
-- | Restricts a list generator to results that are not empty, using
-- 'suchThat' with the predicate @not . null@.
nonempty :: forall a. Gen [a] -> Gen [a]
nonempty listGen = listGen `suchThat` (not . null)
-- | Applies a pure function to every element of a list that is produced
-- inside a monad, keeping the monadic context unchanged.
mapMonad :: Monad m => (a -> b) -> m [a] -> m [b]
mapMonad f = liftM (map f)
-- | Turns a generator of lists into a generator of vectors by converting
-- each generated list with 'V.fromList'.
toGenVector :: Gen [a] -> Gen (V.Vector a)
toGenVector = fmap V.fromList
-- | Generates either @-1@ or @1@.
--
-- A value is drawn from @{0, 1}@ and mapped through @n * 2 - 1@,
-- so 0 becomes -1 and 1 stays 1.
signGen :: Gen Int
signGen = do
  n <- choose (0, 1)
  return $ n * 2 - 1
-- | Generates either @0@ or @1@.
binaryGen :: Gen Int
binaryGen = choose (0, 1)
-- | Generates a pattern with exactly @n@ elements.
--
-- The element generator depends on the requested 'Type':
-- 'signGen' for 'H' and 'binaryGen' for 'BM'.
patternGen :: Type -> Int -> Gen Pattern
patternGen t n = toGenVector (vectorOf n elemGen)
  where
    elemGen = case t of
      H -> signGen
      BM -> binaryGen
-- | Generates a pattern whose length is first drawn from the given
-- inclusive bounds and then passed on to 'patternGen'.
patternRangeGen :: Type -> (Int, Int) -> Gen Pattern
patternRangeGen t bounds = do
  size <- choose bounds
  patternGen t size
-- | Generates a list of values from @g@ whose length is drawn from
-- the range @0 .. n@ (both ends included).
boundedListGen :: Gen a -> Int -> Gen [a]
boundedListGen g n = choose (0, n) >>= \len -> vectorOf len g
-- | Generates a non-empty list of patterns that all share one size.
--
-- The common pattern size is drawn once from @1 .. maxPatSize@; the list
-- itself comes from 'boundedListGen' (length @0 .. maxPatListSize@) filtered
-- through 'nonempty'.
--
-- NOTE(review): with @maxPatListSize <= 0@ only empty lists can be produced,
-- so the 'nonempty' filter can presumably never succeed - confirm callers
-- always pass a positive bound.
patListGen :: Type -> Int -> Int -> Gen [Pattern]
patListGen t maxPatSize maxPatListSize =
  choose (1, maxPatSize) >>= \patSize ->
    nonempty (boundedListGen (patternGen t patSize) maxPatListSize)
-- | Generates two lists of patterns of the same pattern size.
--
-- The first list comes from 'patListGen' (so it is non-empty); the second
-- holds between 0 and @m2@ patterns whose size equals the size of the first
-- pattern in the first list.
patternsTupleGen :: Type -> Int -> Int -> Gen ([Pattern], [Pattern])
patternsTupleGen t m1 m2 = do
  training <- patListGen t m1 m2
  count <- choose (0, m2)
  -- 'head' is safe here: 'patListGen' only yields non-empty lists.
  let patSize = V.length (head training)
  others <- vectorOf count (patternGen t patSize)
  return (training, others)
-- | Generates a list made only of copies of @x@; the number of copies is
-- an arbitrary 'Int' (non-positive counts give the empty list via 'replicate').
sameElemList :: a -> Gen [a]
sameElemList x = fmap (\count -> replicate count x) arbitrary
-- | Vector counterpart of 'sameElemList': a vector made only of copies
-- of @x@.
sameElemVector :: a -> Gen (V.Vector a)
sameElemVector x = toGenVector (sameElemList x)
-- | Builds an @n@ x @n@ matrix (as a list of rows) with zeros on the
-- diagonal and the same value @w@ everywhere else.
--
-- @w@ is @1 ./. n@, where '(./.)' comes from "Hopfield.Util"
-- (presumably a division yielding a 'Double' - confirm against its definition).
allWeightsSame :: Int -> [[Double]]
allWeightsSame n =
  [ [ if i == j then 0 else w | i <- [0 .. n - 1] ] | j <- [0 .. n - 1] ]
  where
    w = (1 :: Int) ./. n
-- | Generates a single value from @g@ and repeats it between 0 and @n@
-- times. All elements of the resulting list are therefore equal.
boundedReplicateGen :: Int -> Gen a -> Gen [a]
boundedReplicateGen n g = do
  count <- choose (0, n)
  x <- g
  return (replicate count x)
-- | Replaces the element at (0-based) position @n@ of a list with @r@.
--
-- Calls 'error' with @"index greater than list size"@ when the list runs
-- out before the position is reached, and with @"negative index"@ when a
-- negative position is used on a non-empty list.
replaceAtN :: Int -> a -> [a] -> [a]
replaceAtN _ _ [] = error "index greater than list size"
replaceAtN 0 r (_ : xs) = r : xs
replaceAtN n r (x : xs)
  | n > 0 = x : replaceAtN (n - 1) r xs
  | otherwise = error "negative index"
-- | For the stored pattern at position @index@ of the network, returns the
-- result of 'computeH' for neuron @n@ minus the pattern's own value at @n@.
--
-- NOTE(review): 'computeH' is defined in "Hopfield.Hopfield"; this assumes it
-- takes weights, a pattern and a neuron index and returns an 'Int' - confirm.
crosstalk :: HopfieldData -> Int -> Int -> Int
crosstalk hs index n = computeH (weights hs) pat n - pat ! n
  where
    pat = patterns hs !! index
-- | Property over all training patterns of a network built with @method@.
--
-- For each training pattern (visited together with its index) the check
-- holds when a randomly seeded 'update' returns 'Nothing' for that pattern,
-- or when 'checkFixed' reports 'False' for its index. The overall result is
-- the conjunction over every training pattern.
trainingPatsAreFixedPoints :: LearningType -> [Pattern] -> Gen Bool
trainingPatsAreFixedPoints method pats =
  and <$> mapM checkFixedPoint (zip [0 ..] pats)
  where
    hs = buildHopfieldData method pats
    ws = weights hs
    -- Pairing each pattern with its index covers 0 .. length pats - 1
    -- without repeated (!!) lookups.
    checkFixedPoint (index, pat) = do
      i <- arbitrary
      let unchanged = evalRand (update ws pat) (mkStdGen i) == Nothing
      return $ unchanged || not (checkFixed hs index)
-- | Property: one 'update' step never increases the energy of a pattern
-- (up to a small numerical tolerance).
--
-- The network weights are built from @training_pats@ with @method@; every
-- pattern in @pats@ is updated once with a freshly seeded generator. A
-- pattern for which 'update' returns 'Nothing' passes trivially.
energyDecreasesAfterUpdate :: LearningType -> ([Pattern], [Pattern]) -> Gen Bool
energyDecreasesAfterUpdate method (training_pats, pats) =
  and <$> forM pats (\pat -> do
    i <- arbitrary
    return $ evalRand (energyDecreases pat) (mkStdGen i))
  where
    ws = weights $ buildHopfieldData method training_pats
    -- Accept a non-increase, or an increase no larger than 1e-8 to absorb
    -- floating point noise.
    check pat afterPat =
      energy ws pat >= energy ws afterPat
        || energy ws afterPat - energy ws pat <= 0.00000001
    energyDecreases :: (MonadRandom m) => Pattern -> m Bool
    energyDecreases pat = do
      maybe_pat <- update ws pat
      case maybe_pat of
        Nothing -> return True
        Just updatedPattern -> return $ check pat updatedPattern
-- | Property: after 'repeatedUpdate' has been applied to a pattern, one
-- further 'update' returns 'Nothing'.
--
-- The weights are built from @training_pats@ with @method@; each pattern in
-- @pats@ is checked with its own randomly seeded generator and the results
-- are combined with 'and'.
repeatedUpdateCheck :: LearningType -> ([Pattern], [Pattern]) -> Gen Bool
repeatedUpdateCheck method (training_pats, pats) =
  and <$> mapM (evalRandGen . stopped) pats
  where
    ws = weights $ buildHopfieldData method training_pats
    -- Run 'repeatedUpdate' first, then ask for one more update step.
    stopped pat = do
      settled <- repeatedUpdate ws pat
      next <- update ws settled
      return (next == Nothing)
-- | Generates a non-empty list of 'BM' patterns (via 'patListGen') paired
-- with an integer drawn from @1 .. max_hidden@.
boltzmannBuildGen :: Int -> Int -> Int -> Gen ([Pattern], Int)
boltzmannBuildGen maxPatSize maxPatListSize max_hidden =
  liftM2 (,) (patListGen BM maxPatSize maxPatListSize) (choose (1, max_hidden))
-- | Property: the data built by @buildBoltzmannData'@ (run with a random
-- seed) reports the same patterns and the same hidden-count that were
-- passed in.
buildBoltzmannCheck :: ([Pattern], Int) -> Gen Bool
buildBoltzmannCheck (pats, nr_h) = do
  seed <- arbitrary
  let built = evalRand (buildBoltzmannData' pats nr_h) (mkStdGen seed)
      samePatterns = patternsB built == pats
      sameHidden = nr_hiddenB built == nr_h
  return (samePatterns && sameHidden)
-- | Generates a non-empty list of 'BM' training patterns, an integer drawn
-- from @1 .. max_hidden@, and one extra 'BM' pattern with the same length as
-- the first training pattern.
boltzmannAndPatGen :: Int -> Int -> Int -> Gen ([Pattern], Int, Pattern)
boltzmannAndPatGen maxPatSize maxPatListSize max_hidden = do
  trainingPats <- patListGen BM maxPatSize maxPatListSize
  hiddenCount <- choose (1, max_hidden)
  -- Index 0 exists because 'patListGen' never yields an empty list.
  let patSize = V.length (trainingPats !! 0)
  probe <- patternGen BM patSize
  return (trainingPats, hiddenCount, probe)
-- | Property: every value returned by 'getActivationProbability' (called
-- with @Matching Visible@, the built weights and @pat@) for the indices
-- @0 .. nr_h - 1@ lies in the closed interval @[0, 1]@.
--
-- The Boltzmann data is built from @pats@ and @nr_h@ with a random seed.
probabilityCheck :: ([Pattern], Int, Pattern) -> Gen Bool
probabilityCheck (pats, nr_h, pat) = do
  seed <- arbitrary
  let bd = evalRand (buildBoltzmannData' pats nr_h) (mkStdGen seed)
      ws = weightsB bd
  return $ all (isProbability . getActivationProbability Matching Visible ws pat) [0 .. nr_h - 1]
  where
    isProbability x = x >= 0 && x <= 1
-- | Property for @updateNeuron'@ with a fixed value @r@ in place of the
-- random draw.
--
-- @r@ must be 0 or 1 (anything else calls 'error'). A neuron index is drawn
-- from @0 .. nr_h - 1@, the Boltzmann data is built with a random seed, and
-- the result of @updateNeuron'@ is expected to equal @1 - r@.
updateNeuronCheck :: Int -> ([Pattern], Int, Pattern) -> Gen Bool
updateNeuronCheck r (pats, nr_h, pat)
  | not (r == 0 || r == 1) = error "r has to be 0 or 1 for updateNeuronCheck"
  | otherwise = do
      i <- choose (0, nr_h - 1)
      seed <- arbitrary
      let bd = evalRand (buildBoltzmannData' pats nr_h) (mkStdGen seed)
      return $ updateNeuron' (fromIntegral r) Matching Visible (weightsB bd) pat i == (1 - r)
-- | Generates a pair @(i, j)@ where @i@ is drawn from @1 .. 100@ and @j@ is
-- one or two more than @ceiling (log2 i)@.
buildIntTuple :: Gen (Int, Int)
buildIntTuple =
  choose (1, 100) >>= \value ->
    let bitsNeeded = ceiling . log2 . fromIntegral $ value
     in (,) value <$> choose (bitsNeeded + 1, bitsNeeded + 2)
-- | Property: folding the digits returned by @toBinary x y@ back into a
-- number gives @x@ again.
--
-- The digits are reversed and the digit at position @pos@ of the reversed
-- list is weighted with @2 ^ pos@.
binaryCheck :: (Int, Int) -> Bool
binaryCheck (x, y) = x == refold
  where
    digits = toBinary x y
    weigh b pos = b * 2 ^ pos
    refold = sum (zipWith weigh (reverse digits) [(0 :: Int) ..])
-- | Runs a 'Rand' computation inside 'Gen': an arbitrary 'Int' is used to
-- seed a 'StdGen' via 'mkStdGen', and the computation is evaluated with it.
evalRandGen :: Rand StdGen a -> Gen a
evalRandGen e = fmap (evalRand e . mkStdGen) arbitrary