{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE BangPatterns #-}

-- | Base Hopfield model, providing training and running.
module Hopfield.Hopfield (
    Pattern
  , Weights
  , LearningType (Hebbian, Storkey)
  -- * Hopfield data structure
  , HopfieldData ()
  , weights
  , patterns
  , buildHopfieldData
  -- * Running
  , update
  , addPatterns
  , repeatedUpdate
  , updateChain
  , matchPattern
  , computeH
  -- * Energy
  , energy
) where

import           Control.Monad
import           Control.Monad.Random (MonadRandom)
import           Data.Maybe
import           Data.Vector ((!))
import qualified Data.Vector as V
import           Data.Vector.Generic.Mutable (write)

import           Hopfield.Common
import           Hopfield.Util


-- | The learning rule used to turn training patterns into network weights.
data LearningType = Hebbian | Storkey
  deriving (Eq, Show, Read)


-- | Encapsulates the network weights together with the patterns that
-- generated them. A 'Show' instance is derived.
data HopfieldData = HopfieldData
  { weights  :: Weights   -- ^ the weights of the network
  , patterns :: [Pattern] -- ^ the patterns which were used to train it
  }
  deriving (Show)


-- | Runs a function on a weight matrix and a pattern only after both have
-- been validated; the first failed validation is raised with 'error'.
--
-- The validations are tried in this order: the weight matrix on its own, the
-- pattern on its own, and finally that their sizes agree.
--
-- Usage: `checkWsPat (functionTakingWeightsAndPattern)`.
checkWsPat :: (Weights -> Pattern -> a) -> Weights -> Pattern -> a
checkWsPat f ws pat = maybe (f ws pat) error firstViolation
  where
    -- 'msum' over 'Maybe' yields the first 'Just', so a later check is only
    -- evaluated when all earlier ones have passed.
    firstViolation = msum
      [ validWeights ws
      , validPattern pat
      , validWeightsPatternSize ws pat
      ]


-- | @update weights pattern@: Applies the update rule on @pattern@ for the
-- first updatable neuron given the Hopfield network (represented by @weights@).
-- The arguments are validated before the update is attempted.
--
-- Pre: @length weights == length pattern@
update :: MonadRandom m => Weights -> Pattern -> m (Maybe Pattern)
update ws pat = checkWsPat update_ ws pat


-- | @repeatedUpdate weights pattern@: Performs repeated updates on the given
-- pattern until it reaches a stable state with respect to the Hopfield network
-- (represented by @weights@).
-- Pre: @length weights == length pattern@
repeatedUpdate :: (MonadRandom m) => Weights -> Pattern -> m Pattern
repeatedUpdate ws pat = checkWsPat repeatedUpdate_ ws pat


-- | Computes the weighted sum of the current neuron values for neuron @i@ and
-- returns its sign, which is the value that neuron should take.
-- The weights and the pattern are validated first.
computeH :: Weights -> Pattern -> Int -> Int
computeH ws pat i = checkWsPat signAtNeuron ws pat
  where
    signAtNeuron w p = computeH_ w p i


-- | @energy weights pattern@: Computes the energy of a pattern given a Hopfield
-- network (represented by @weights@).
-- Pre: @length weights == length pattern@
energy :: Weights -> Pattern -> Double
energy ws pat = checkWsPat energy_ ws pat


-- | @buildHopfieldData learningType patterns@: Takes a list of patterns and
-- builds a Hopfield network (by training with the given learning rule) in
-- which these patterns are stable states. The result of this function can be
-- used to run a pattern against the network, by using 'matchPattern'.
--
-- Errors out when the list is empty, when the first pattern has length zero,
-- or when the patterns do not all have the same length.
buildHopfieldData :: LearningType -> [Pattern] -> HopfieldData
buildHopfieldData learningType pats = case pats of
  [] -> error "Train patterns are empty"
  firstPat : _
    | patLen == 0         -> error "Cannot have empty patterns"
    | any mismatched pats -> error "All training patterns must have the same length"
    | otherwise           -> HopfieldData (trainer pats) pats
    where
      patLen = V.length firstPat
      mismatched p = V.length p /= patLen
      trainer = case learningType of
        Hebbian -> train
        Storkey -> trainStorkey


-- | @train patterns@: Trains and constructs network given a list of patterns
-- which are used to build the weight matrix. As a consequence, they will be
-- stable points in the network (by construction).
--
-- The diagonal is zero; every other entry is the sum over all patterns of the
-- product of the two neuron values, divided by the number of neurons.
train :: [Pattern] -> Weights
-- No need to check pats ws size, buildHopfieldData does it
train pats = vector2D [ [ correlation r c ./. size | c <- neurons ] | r <- neurons ]
  where
    size    = V.length (head pats)
    neurons = [0 .. size - 1]
    correlation r c
      | r == c    = 0
      | otherwise = sum [ (p ! r) * (p ! c) | p <- pats ]


-- | See `computeH`.
computeH_ :: Weights -> Pattern -> Int -> Int
computeH_ ws pat i = {-# SCC "computeHall" #-} if total >= 0 then 1 else -1
  where
    row = ws ! i
    len = {-# SCC "computeHvlength" #-} V.length pat

    total :: Double
    total = accumulate 0 0.0

    -- Strict left-to-right accumulation over row @i@ of the weights: a weight
    -- is added when the matching neuron is positive, subtracted otherwise.
    -- Indexing is unchecked, so the row must be at least as long as @pat@.
    accumulate :: Int -> Double -> Double
    accumulate !k !acc
      | k == len  = acc
      | otherwise = accumulate (k + 1) (acc + contribution)
      where
        weight = row `V.unsafeIndex` k
        contribution = if pat `V.unsafeIndex` k > 0 then weight else -weight


-- | See `update`.
-- The update is done by finding a neuron that will change its value given the
-- current state. The neurons are visited in a random order: the first one
-- whose value differs from what 'computeH_' yields is flipped, and the
-- resulting pattern is returned. If no neuron differs, the result is
-- 'Nothing'.
-- (Note: Initially the update was performed by obtaining a list of all
-- updatable neurons and then picking a random one. The current method is 2 times
-- faster)
update_ :: MonadRandom m => Weights -> Pattern -> m (Maybe Pattern)
update_ ws pat = do
  -- TODO avoid Array -> List -> Vector conversion
  shuffled <- shuffle . toArray $ [0 .. V.length pat - 1]
  let visitOrder = V.fromList shuffled
  return (fmap (flipAtIndex pat) (firstUpdatable visitOrder))
  where
    size = V.length pat

    -- First neuron, in visiting order, whose value disagrees with the network.
    firstUpdatable order = scan 0
      where
        scan pos
          | pos == size                             = Nothing
          | pat ! neuron /= computeH_ ws pat neuron = Just neuron
          | otherwise                               = scan (pos + 1)
          where
            neuron = order ! pos

    -- Negates the value stored at @index@.
    -- seq only brings small saving here
    flipAtIndex vec index = val `seq` V.modify (\mv -> write mv index (-val)) vec
      where
        val = vec ! index


-- | See `repeatedUpdate`.
repeatedUpdate_ :: (MonadRandom m) => Weights -> Pattern -> m Pattern
repeatedUpdate_ ws = repeatUntilNothing (update_ ws)


-- | @matchPattern hopfieldData pattern@:
-- Computes the stable state of a pattern given a Hopfield network (represented
-- by @weights@) and tries to find a match in a list of patterns which are
-- stored in @hopfieldData@.
-- Returns:
--
-- The index of the matching pattern in @patterns@, if a match exists
-- The converged pattern (the stable state), otherwise
--
-- Pre: @length weights == length pattern@
--
-- Note: the arguments are passed to the unchecked update loop as they are;
-- upholding the precondition is the caller's responsibility.
matchPattern :: MonadRandom m => HopfieldData -> Pattern -> m (Either Pattern Int)
matchPattern (HopfieldData ws pats) pat = do
  converged_pattern <- repeatedUpdate_ ws pat
  return $ findInList pats converged_pattern


-- | Like `repeatedUpdate`, but collecting all patterns until convergence.
-- The last pattern in the list is the converged pattern.
-- The argument pattern is the first element of the result list.
--
-- Errors out if the pattern contains values other than 1 and -1, or if its
-- size does not match the size of the network.
--
-- POST: The returned list is not empty.
updateChain :: (MonadRandom m) => HopfieldData -> Pattern -> m [Pattern]
updateChain (HopfieldData ws _pats) pat
  | Just e <- validPattern pat = error e
  -- The update step indexes the weights without bounds checks, so a pattern
  -- of the wrong size has to be rejected before it gets there.
  | Just e <- validWeightsPatternSize ws pat = error e
  | otherwise = (pat:) `liftM` unfoldrSelfM (update_ ws) pat


-- | Stores patterns in an already trained network. One has to ensure that this
-- function is not over used, as this will decrease the capacity of the network.
--
-- Errors out if any added pattern contains values other than 1 and -1, or if
-- its size does not match the existing weight matrix. The added patterns are
-- appended to the patterns already recorded in the network.
addPatterns :: LearningType -> HopfieldData -> [Pattern] -> HopfieldData
addPatterns learning (HopfieldData ws pats) addedPats
  | any (isJust . validPattern) addedPats
      = error "invalid patterns in addMultiplePatterns"
  | any (isJust . validWeightsPatternSize ws) addedPats
      = error "pattern does not match weights in addMultiplePatterns"
  | otherwise = HopfieldData new_ws (pats ++ addedPats)
  where
    new_ws = foldl (updateWeightsGivenNewPattern learning) ws addedPats


-- | Updates the weight matrix when a new pattern is stored in the network.
--
-- For the Hebbian rule every off-diagonal weight grows by the product of the
-- two neuron values divided by the number of neurons, which is the same
-- scaling 'train' uses. The diagonal is kept at zero, as 'train' does and as
-- 'validWeights' requires; otherwise the resulting weights would be rejected
-- by every checked entry point ('update', 'energy', ...).
updateWeightsGivenNewPattern :: LearningType -> Weights -> Pattern -> Weights
updateWeightsGivenNewPattern Storkey ws pat = updateWeightsStorkey ws pat
updateWeightsGivenNewPattern Hebbian ws pat = vector2D updated_ws
  where
    updated_ws = [ [ newWeight i j | j <- neurons ] | i <- neurons ]
    newWeight i j
      | i == j    = 0
      | otherwise = ws ! i ! j + (pat ! i * pat ! j) ./. n
    -- Number of neurons: used as the divisor, NOT as the last index.
    n = V.length ws
    neurons = [0 .. n - 1]


-- | See `energy`.
energy_ :: Weights -> Pattern -> Double
energy_ ws pat = s / (-2.0)
  where
    p = V.length pat
    w i j = ws ! i ! j
    x i = pat ! i
    -- Sum over all ordered pairs (i, j) of weight times both neuron values.
    s = sum [ w i j *. (x i * x j) | i <- [0 .. p-1], j <- [0 .. p-1] ]


-- | Checks if a pattern consists of only 1s and -1s.
-- Returns @Nothing@ on success, an error string on failure. The error string
-- names the first offending value.
validPattern :: Pattern -> Maybe String
validPattern pat =
  case [ x | x <- V.toList pat, not (x == 1 || x == -1) ] of
    []  -> Nothing
    x:_ -> Just $ "Pattern contains invalid value " ++ show x


-- | @validWeightsPatternSize weights pattern@
-- Returns an error string in a Just if the @pattern@ is not compatible
-- with @weights@ and Nothing otherwise. Compatible means that the pattern has
-- as many entries as the weight matrix has rows.
validWeightsPatternSize :: Weights -> Pattern -> Maybe String
validWeightsPatternSize ws pat
  | V.length ws /= V.length pat = Just "Pattern size must match network size"
  | otherwise                   = Nothing


-- Checks the validity of a weight matrix by ensuring:
--  * It is non-empty
--
--  * It is square
--
--  * All diagonal elements must be zero
--
--  * It is symmetric (up to a tolerance of 0.0001)
--
--  * No element has an absolute value greater than 1
--
-- The checks are tried in the order listed above. Returns @Nothing@ when all
-- of them pass, otherwise the error string of the first one that fails.
-- These checks hold for both Hebbian and Storkey.
validWeights :: Weights -> Maybe String
validWeights ws
  | n == 0
      = Just "Weight matrix must be non-empty"
  | any (\x -> V.length x /= n) $ V.toList ws
      = Just "Weight matrix has to be a square matrix"
  | any (/= 0) [ ws ! i ! i | i <- [0..n-1] ]
      = Just "Weight matrix first diagonal must be zero"
  | not $ and [ abs( (ws ! i ! j) - (ws ! j ! i) ) < 0.0001 | i <- [0..n-1], j <- [0..n-1] ]
      = Just "Weight matrix must be symmetric"
  -- This has to be 'or': the list holds one Bool per matrix entry and is
  -- therefore never empty here, so testing it with 'null' could never fire.
  -- Entries of exactly 1 or -1 are still accepted.
  | or [ abs (ws ! i ! j) > 1 | i <- [0..n-1], j <- [0..n-1] ]
      = Just "Weights should be between (-1, 1)"
  | otherwise = Nothing
  where
    n = V.length ws


-- Storkey training provides advantages for the Hopfield network as
-- it gives it bigger capacity and higher basins of attraction.
-- For more details see:
-- http://homepages.inf.ed.ac.uk/amos/publications/Storkey1997IncreasingtheCapacityoftheHopfieldNetworkwithoutSacrificingFunctionality.pdf

-- | @storkeyHiddenSum ws pat i j@ computes the value at indices @i@ @j@ in the
-- hidden matrix which is used for updating the weight matrix during training
-- given the training pattern @pat@: the sum, over every neuron @k@ other than
-- @i@ and @j@, of the weight between @i@ and @k@ times the value of @k@.
storkeyHiddenSum :: Weights -> Pattern -> Int -> Int -> Double
storkeyHiddenSum ws pat i j = sum (map contribution others)
  where
    size = V.length ws
    others = filter (\k -> k /= i && k /= j) [0 .. size - 1]
    contribution k = (ws ! i ! k) *. (pat ! k)


-- | @updateWeightsGivenIndicesStorkey ws pat i j@ computes the new value at
-- indices @i@ @j@ of the weights matrix for the training iteration of
-- pattern @pat@. Diagonal entries are always zero.
updateWeightsGivenIndicesStorkey :: Weights -> Pattern -> Int -> Int -> Double
updateWeightsGivenIndicesStorkey ws pat i j
  | i == j    = 0.0
  | otherwise = ws ! i ! j + scale * (hebbTerm - (h j i *. xi) - (h i j *. xj))
  where
    -- One over the number of neurons.
    scale :: Double
    scale = (1 :: Int) ./. V.length ws
    xi = pat ! i
    xj = pat ! j
    hebbTerm :: Double
    hebbTerm = fromIntegral (xi * xj)
    h = storkeyHiddenSum ws pat


-- | @updateWeightsStorkey ws pat@ updates the weights matrix, given training
-- instance @pat@, by recomputing every entry from the previous weights.
updateWeightsStorkey :: Weights -> Pattern -> Weights
updateWeightsStorkey ws pat = vector2D (map row indices)
  where
    indices = [0 .. V.length ws - 1]
    row i = map (updateWeightsGivenIndicesStorkey ws pat i) indices


-- | @trainStorkey pats@ trains the Hopfield network by computing the weights
-- matrix by iterating through all training instances (@pats@), starting from
-- an all-zero matrix, and updating the weights according to the Storkey
-- learning rule.
trainStorkey :: [Pattern] -> Weights
-- No need to check pats ws size, buildHopfieldData does it
trainStorkey pats = foldl updateWeightsStorkey zeroWeights pats
  where
    size = V.length (head pats)
    zeroWeights = vector2D (replicate size (replicate size 0))